张量的索引是从第零维度开始的。让我们来创建一个四维的张量做举例说明:torch.Tensor(2,3,64,64) 此时,这个张量可以表示两张边长为64的正方形彩色图像,具体来说,张量的第零维表示图像的数量;第一维表示图像的颜色通道(3即为彩色图片,代表RGB三通道);第二维和第三维代表图像的高度和宽度。此张量的索引代码如下:
import torch
a = torch.Tensor(2, 3, 64, 64)
# 通过.shape的方法查看当前张量的形状
print(a.shape)
print(a[0].shape)
print(a[0][0].shape)
print(a[0][0][0].shape)
print(a[0][0][0][0].shape)
上述代码的输出为:
torch.Size([2, 3, 64, 64]) # 图像的形状
torch.Size([3, 64, 64]) # 取到第一张图像,形状为 [3, 64, 64]
torch.Size([64, 64]) # 取到第一张图像的第一个颜色通道, 形状为[64, 64]
torch.Size([64]) # 取到第一张图像的第一个颜色通道的第一列像素值,形状为64
torch.Size([]) # 取到第一张图像的第一个颜色通道的第一个像素值,形状为0(因为是标量)
另外,需要注意的是pytorch也支持负索引,使用方法与python中的负索引相同。
在上一小节中,维度的索引是取到某维度上的全部数据。但是,如果我们只想要某维度上的部分数据应该怎么做呢?这就是切片的作用。
切片方法的格式为:tensor[ first : last : step] first与last为切片的起始和结束位置,取值方法是按照step的间隔进行左闭右开的取值;当间隔为1时,step可以默认不写;当取到该维度的所有数据时,使用冒号即可。实例如下:
import torch
a = torch.Tensor(2, 3, 64, 64)
# 通过.shape的方法查看当前张量的形状
print(a.shape)
print(a[1:2, :, :, :].shape)
print(a[ : , : , 0:32, 0:32].shape)
print(a[ : , : , 0:32:2, 0:32:2].shape)
print(a[ : , : , : : 2, : : 2].shape)
上述代码的输出为:
torch.Size([2, 3, 64, 64]) # 图像的形状
torch.Size([1, 3, 64, 64]) # 取到第二张图像
torch.Size([2, 3, 32, 32]) # 取到两张图像1/4大小的左上角子图
torch.Size([2, 3, 16, 16]) # 取到两张图像1/4大小的左上角子图后,在子图上隔点取样
torch.Size([2, 3, 32, 32]) # 在原图上隔点取样
在之后的学习中我们会发现每个算法模型都有自己要求的输入数据维度,每个问题下的数据维度也不同。因为,为了使用各种的算法来处理各种的问题往往需要对数据进行维度的变换。例如,如果想用神经网络层来处理图像数据,我们就可以发现,图片是三维的数据维度(颜色通道,高度,宽度),但是神经网络层能接受的数据维度是二维,此时维度是不匹配的,因此需要将图像的空间维度打平成向量。下面介绍pytorch中一些常见的维度变换方法。
(1)view() 和 reshape() 变换维度
import torch
a = torch.Tensor(2, 3, 32, 32)
print(a.view(2, 3, 32*32))
print(a.reshape(2, 3, 32*32))
print(a.reshape(2, 3, -1))
torch.Size([2, 3, 1024])
torch.Size([2, 3, 1024])
torch.Size([2, 3, 1024])
view() 和reshape()都可以对某张量进行维度的变化,但是reshape()方法的鲁棒性更强,更推荐大家使用。此外,view() 和reshape()接受的参数都是变换后的维度大小,在设置变换后维度的参数时,如果只剩一个维度没有给予,可直接使用-1来代替,pytorch会根据之前已设置的维度自动推导出最后未给予的维度。最后,这里需要注意的是变换后的总维度数量必须与变换前相等,否则报错。实例如下所示:
import torch
a = torch.Tensor(2, 3, 32, 32)
print(a.reshape(2, 3, 10).shape)
# ---------------------------------------------------------------------------
# RuntimeError Traceback (most recent call last)
# Input In [15], in () |
# 1 import torch
# 3 a = torch.Tensor(2, 3, 32, 32)
# ----> 5 print(a.reshape(2, 3, 10).shape)
# RuntimeError: shape '[2, 3, 10]' is invalid for input of size 6144
(2)unsqueeze() 增加新的数据维度
有时候,我们往往因为数据的增加需要在原始张量表示的基础上扩张维度来存储新增加的数据。举个例子,我们创建一个小学年级的档案时,可以创建一个三维张量:[年级数量,每年级的班级数量,班级的人数] ,此时,我要合并另一所学校的年级档案,最好的办法是在扩充出一个学校的维度,变成四维张量:[学校的数量,年级数量,每年级的班级数量,班级的人数] 。 unsqueeze() 方法就是用来增加数据维度的,接受的参数含义是在哪个维度之前增加新维度,这个参数也支持负索引。具体实例如下:
import torch
a = torch.Tensor(2, 3, 64, 64)
print(a.unsqueeze(0).shape)
print(a.unsqueeze(1).shape)
print(a.unsqueeze(2).shape)
print(a.unsqueeze(-1).shape)
# output
# torch.Size([1, 2, 3, 64, 64])
# torch.Size([2, 1, 3, 64, 64])
# torch.Size([2, 3, 1, 64, 64])
# torch.Size([2, 3, 64, 64, 1])
(3)squeeze() 缩减数据维度
增加某张量维度的反操作是减少维度,对于pytorch中的方法是squeeze(),接受的参数是要进行维度缩减的维度索引,注意,缩减的维度值必须等于1,否则不能进行缩减,而且程序不报错,实例如下:
import torch
a = torch.Tensor(2, 1, 64, 64)
print(a.squeeze(1).shape)
print(a.squeeze(2).shape)
# output
# torch.Size([2, 64, 64])
# torch.Size([2, 1, 64, 64])
(4)expand()和 repeat()在某维度上扩展数据
expand()可以在某维度上进行数据扩展,扩展的方法是复制原始数据。需要注意的是,expand()方法不能扩展维度大于1的维度,否则报错。因为其扩展方式是复制,当维度大于1时,expand()方法不清楚应该复制哪个数据。具体实例如下:
import torch
a = torch.Tensor(2, 1, 64, 64)
print(a.shape)
print(a.expand(2,3,64,64).shape)
print(a.expand(2,3,65,65).shape)
# output
# torch.Size([2, 1, 64, 64])
# torch.Size([2, 3, 64, 64])
# ---------------------------------------------------------------------------
# RuntimeError Traceback (most recent call last)
# Input In [22], in () |
# 4 print(a.shape)
# 5 print(a.expand(2,3,64,64).shape)
# ----> 6 print(a.expand(2,3,65,65).shape)
# RuntimeError: The expanded size of the tensor (65) must match the existing size (64)
# at non-singleton dimension 3 Target sizes: [2, 3, 65, 65]. Tensor sizes: [2, 1, 64, 64]
repeat()也可以在某维度上进行数据扩展,但是其接受的参数含义与expand()函数不同。repeat()函数接受的是在该维度上复制全部数据的次数,实例如下:
import torch
a = torch.Tensor(2, 1, 64, 64)
print(a.shape)
print(a.repeat(1,3,1,1).shape)
print(a.repeat(3,3,3,3).shape)
# output
# torch.Size([2, 1, 64, 64])
# torch.Size([2, 3, 64, 64])
# torch.Size([6, 3, 192, 192])
(5)transpose()和 permute()进行张量的维度调整
transpose()可以通过指定张量中某两个维度的索引,来对这两个维度的数据进行交换维度操作,示例如下:
import torch
a = torch.Tensor(2, 3, 64, 64)
print(a.shape)
print(a.transpose(0, 1).shape)
# output
# torch.Size([2, 3, 64, 64])
# torch.Size([3, 2, 64, 64])
(6)Broadcast:pytorch对不同维度张量进行计算时的自动补全规则
注意,broadcast不是函数,而是pytorch在加减两个不同维度张量时,底层自动实现的计算逻辑。首先,一个常识是当两个张量维度不同时,是不能进行加减操作的。broadcast的主要思想是针对维度小的数据依次从最后一个维度开始匹配维度大的数据,如果没匹配上,则插入一个新的维度。举例如下:
[2, 3, 32, 32] + [3,1,1] 是不能直接相加的。
Broadcast机制会先将 [3,1,1] 增加新维度变为 [1, 3, 1, 1] (等价于unsqueeze()方法),然后再将 [1, 3, 1, 1]扩展维度为 [2, 3, 32, 32] (等价于expand()方法)
从某种程度上说,broadcast机制等价于unsqueeze()和expand()两个方法的组合。目的是在处理两个维度不同的张量时,可以显式的不做任何处理进行直接加减操作。实际上,在底层隐式的进行了unsqueeze()和expand()。
注意,broadcast机制也有限制:当维度小的数据依次从最后一个维度开始匹配维度大的数据时,小维度数据的维度值必须符合以下两种情况之一,才能进行broadcast:等于1,与大维度数据的维度值相等,否则报错。示例如下:
import torch
a = torch.Tensor(2,3,32,32)
b = torch.Tensor(1,1,1)
c = torch.Tensor(32)
d = torch.Tensor(32, 1)
e = torch.Tensor(2, 32, 32)
print((a + b).shape)
print((a + c).shape)
print((a + d).shape)
print((a + e) .shape)
# output
# torch.Size([2, 3, 32, 32])
# torch.Size([2, 3, 32, 32])
# torch.Size([2, 3, 32, 32])
# ---------------------------------------------------------------------------
# RuntimeError Traceback (most recent call last)
# Input In [29], in () |
# 9 print((a + c).shape)
# 10 print((a + d).shape)
# ---> 11 print((a + e) .shape)
# RuntimeError: The size of tensor a (3) must match the size of tensor b (2)
# at non-singleton dimension 1