import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches

一些零散的知识

图像上的

Figure:可以看成整张画布。
Axis:坐标轴,分XAxis和YAxis。
Subplot:子图。
Axes:也可以看成子图,与Subplot的区别见具体代码。

创建图像

fig = plt.figure() # 创建figure

具体参数如下:
num:图像的编号,如果没有指定,则会自动分配一个编号。这个参数在创建多个图像时很有用。
figsize:图像窗口的大小,以英寸为单位。它是一个包含两个元素的元组,第一个元素表示图像的宽度,第二个元素表示图像的高度。
dpi:图像的分辨率(每英寸点数)。默认值为 100。
facecolor:图像的背景颜色。
edgecolor:图像边界的颜色。
frameon:是否显示图像边界。默认值为 True。
clear:是否清除图像窗口内容。默认值为 False。
subplotpars:子图参数配置。
tight_layout:是否自动调整子图布局。默认值为 False。

ax = fig.add_subplot(*args)

参数格式:int 或 (int, int, index) 或 int, int, int。default: (1, 1, 1)。
如果只有一个int时要为three-digit integer。
第一个第二个参数分别为column、row分几份,第三个参数为画subplot的索引,左上角为1往右递增。
返回一个创建好的空白子图

ax = fig.add_axes(*args)

sequence of float
The dimensions [left, bottom, width, height] of the new Axes.
left和width是相对于figure的宽,bottom和height是相对于figure的高

fig, axes = plt.subplots() # 创建多个subplot,返回包含图像窗口对象(<class ‘matplotlib.figure.Figure’>)和子图对象数组(<class ‘numpy.ndarray’>)的元组。

参数:
nrows:子图的行数。
ncols:子图的列数。
sharex:是否共享x轴范围。默认值为 False。
sharey:是否共享y轴范围。默认值为 False。
squeeze:如果设置为 True,当子图的行数和列数都为1时,将返回一个单一的Axes对象,而不是一个Axes对象数组。默认值为 True。
num:图像的编号,如果没有指定,则会自动分配一个编号。这个参数在创建多个图像时很有用。
figsize:图像窗口的大小,以英寸为单位。它是一个包含两个元素的元组,第一个元素表示图像的宽度,第二个元素表示图像的高度。
dpi:图像的分辨率(每英寸点数)。默认值为 100。
facecolor:图像的背景颜色。
edgecolor:图像边界的颜色。
frameon:是否显示图像边界。默认值为 True。
clear:是否清除图像窗口内容。默认值为 False。
subplotpars:子图参数配置。
tight_layout:是否自动调整子图布局。默认值为 False。

开关画板

plt.show()
plt.close()

画图逻辑

先创建figure,再设置subplot或者axies,再在子图上面进行画图。

具体操作

画点与线

在subplot和axes上画图的方法函数几乎一样。

ax.plot() # 用于在坐标轴上绘制线图(Line Plot)的函数。它可以绘制一组数据点之间的连接线,从而显示数据的趋势和关系。

参数:
x:用于表示横坐标的数据序列。
y:用于表示纵坐标的数据序列。
fmt:线的格式字符串,用于指定线的样式、颜色和标记。
color:线的颜色。
linestyle:线的样式,如 ‘-’ 表示实线,‘–’ 表示虚线等。
marker:数据点的标记样式。
markersize:数据点的标记大小。
label:数据序列的标签,用于图例。
其他参数:你还可以使用其他参数来进一步定制线图,如 linewidth、alpha(透明度)等。

ax.scatter() # 在图形中绘制散点图(Scatter Plot)

参数:
x:用于表示横坐标的数据序列。
y:用于表示纵坐标的数据序列。
marker:用于指定散点的标记样式,如’o’表示圆圈。
color:用于指定散点的颜色。
label:用于指定数据集的标签,方便添加图例。
其他参数:你还可以使用其他参数来设置散点的大小、透明度等。

ax.set_xlabel(str) # 设置X轴标签
ax.set_ylabel(str) # 设置y轴标签
ax.axis(‘off’) # 关闭坐标轴

linestyle 参数用于指定线条的样式。

以下是一些常用的线条样式(linestyle):
‘-’ 或 ‘solid’:实线(默认样式)。
‘–’ 或 ‘dashed’:虚线。
‘:’ 或 ‘dotted’:点线。
‘-.’ 或 ‘dashdot’:点划线,一种介于虚线和点线之间的样式。
‘None’ 或 ‘’:无线条,通常用于取消线条绘制。
’ '(空格):同样会绘制无线条,但是会在数据点之间插入空格,从而在绘制时留下间隙。

画图形

返回类型 <class ‘matplotlib.patches.Rectangle’>

ax.add_patch(patch)

Add a “.Patch” to the Axes; return the patch.

矩形

patches.Rectangle()

参数:
xy:矩形的左下角坐标,是一个表示 (x, y) 的元组。
width:矩形的宽度。
height:矩形的高度。
angle:矩形的旋转角度(以度为单位)。

edgecolor:矩形的边界颜色。
facecolor:矩形的填充颜色。
linewidth:矩形边界线的宽度。
linestyle:矩形边界线的样式。
alpha:矩形的透明度。
fill:是否填充,默认为True。

椭圆形

patches.Ellipse()

参数:
xy:椭圆的中心坐标,是一个表示 (x, y) 的元组。
width:椭圆的宽度(主轴的长度)。
height:椭圆的高度(次轴的长度)。
angle:椭圆的旋转角度(以度为单位)。默认为 0。
**kwargs:其他参数,用于设置椭圆的属性,如颜色、填充等。

扇形

patches.Wedge()

center:扇形的中心坐标,是一个表示 (x, y) 的元组。
r:扇形的半径。
theta1:扇形的起始角度(以度为单位)。
theta2:扇形的终止角度(以度为单位)。
**kwargs:其他参数,用于设置扇形的属性,如颜色、填充等。

打开图片

mpimg.imread(image_path) # 返回ndarray,范围为0-225,shape为(H, W, C),torch的tensor为(B, C, H, W)

参数:
fname:图像文件路径。
format:要解释图像的格式。如果不指定,将根据文件扩展名自动确定格式。

image_shape = image.shape[:2] # 获取图片的尺寸

mpimg.imsave() # 存入图片

参数:
fname:要保存的图像文件路径。
arr:要保存的图像数据,通常是一个 NumPy 数组。如果是浮点数数组,则像素值范围应在 [0, 1]。
cmap:颜色映射,用于将数据数组映射为颜色。默认为 None。
format:要保存的图像格式。如果不指定,根据文件扩展名自动确定格式。
vmin:数据数组中的最小值,用于映射颜色。如果不指定,根据数据数组的最小值确定。
vmax:数据数组中的最大值,用于映射颜色。如果不指定,根据数据数组的最大值确定。
origin:图像坐标原点的位置。默认为 None,表示左上角。
dpi:图像的分辨率(每英寸点数)。默认值为 100。

其他

np.linspace() # 生成指定的ndarray

参数:
start:序列的起始值。
stop:序列的结束值。
num:要生成的样本数量(默认为 50)。
endpoint:如果为 True,则在序列中包含结束值;如果为 False,则不包含(默认为 True)。
retstep:如果为 True,返回 (数组, 步长) 的元组,其中步长是序列中相邻值之间的间隔(默认为 False)。
dtype:输出数组的数据类型。
axis:要填充的轴。

ndarray.flatten() # 转化为一维数组

torch.stack(tensors, dim=0, out=None) # 堆叠tensor

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch

# 创建几个示例张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])
tensor3 = torch.tensor([7, 8, 9])

# 在默认维度上堆叠张量
stacked_tensor = torch.stack([tensor1, tensor2, tensor3])
print(stacked_tensor)
# 输出:
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])

# 沿着不同的维度堆叠
stacked_tensor_dim1 = torch.stack([tensor1, tensor2, tensor3], dim=1)
print(stacked_tensor_dim1)
# 输出:
# tensor([[1, 4, 7],
# [2, 5, 8],
# [3, 6, 9]])
1
2
3
4
5
>>> torch.arange(8).reshape(2, 4)[:,0]
tensor([0, 4])
>>> torch.arange(8).reshape(2, 4)[:]
tensor([[0, 1, 2, 3],
[4, 5, 6, 7]])