在操作张量时,经常需要进行获取或者修改张量元素值的操作,这时候各种张量的花式索引操作就派上大用场了。Pytorch 中对张量进行索引有多种方法,比如:简单行列索引、列表索引、范围索引、布尔索引、多维索引等等,可以根据需要灵活使用。
import torch
data = torch.randint(0, 10, [4, 5]) # 四行五列的二维张量
print(data)
print(data[2]) # 获取第三行数据,返回一维张量
print(data[:, 1]) # 获取第二列数据,返回一维张量
print(data[1, 2]) # 获取第二行的第三列数据,返回零维张量
print(data[1][2]) # 同上
import torch
data = torch.randint(0, 10, [4, 5]) # 四行五列的二维张量
print(data)
print(data[[1,0,2]) # 返回下标为1行、0行、2行共三行数据组成的3行5列的二维张量
print(data[[0,1,3], [3,2,4]]) # 返回下标为0行3列、1行2列、3行4列三个数据组成的一维张量
print(data[[[0],[1]], [[3],[4]]]) # 返回下标为0行3列、1行4列两个数据组成的2行1列的二维张量
print(data[[0,1], [[3],[4]]]) # 返回下标为0行3列、1行3列、0行4列、1行4列四个数据组成的2行2列的二维张量
print(data[[0,1], [[1,2],[0,4]]]) # 返回下标为0行1列、1行2列、0行0列、1行4列四个数据组成的2行2列的二维张量
print(data[[[1],[0]], [3,4]]) # 返回下标为1行3列、1行4列、0行3列、0行4列四个数据组成的2行2列的二维张量
print(data[[[1,3],[0,2]], [3,4]]) # 返回下标为1行3列、3行4列、0行3列、2行4列四个数据组成的2行2列的二维张量
import torch
data = torch.randint(0, 10, [4, 5]) # 四行五列的二维张量
print(data)
print(data[:3, 4]) # 返回前三行的第五列数据组成的一维张量
print(data[:3, [0,2,4]]) # 返回前三行的第一三五列数据组成的二维张量
print(data[:3, :4]) # 返回前三行的前四列数据组成的二维张量
print(data[2:, :4]) # 返回第三行到末行的前四列数据组成的二维张量
import torch
data = torch.randint(0, 10, [4, 5]) # 四行五列的二维张量
print(data)
print(data[data > 5]) # 返回所有大于5的元素组成的一维张量
print(data[[True,False,True,False]]) # 返回第一行与第三行数据组成的二维张量
print(data[1:, [True,False,True,False,True]]) # 返回第二行到末行的第一三五列数据组成的二维张量
print(data[data[:, 2] > 5]) # 返回第三列大于5的行数据组成的二维张量
print(data[:, data[1] > 5]) # 返回第二行大于5的列数据组成的二维张量
data = torch.randint(0, 10, [3, 4, 5]) # 三片四行五列的三维张量
print(data)
print(data[0, :, :]) # 返回第一片所有数据,四行五列的二维张量
print(data[:, 0, :]) # 返回所有片的第一行数据,三行五列的二维张量
print(data[:, :, 0]) # 返回所有片的第一列数据,三行四列的二维张量
张量的拼接操作在神经网络搭建过程中是非常常用的方法,比如:残差网络、注意力机制中都使用到了张量拼接。使用 cat 函数可以将张量按照指定维度拼接起来,并不会升维;而 stack 函数可以将张量在指定维度叠加起来,会升维!
import torch
data1 = torch.randint(0, 10, [3, 5, 4])
data2 = torch.randint(0, 10, [3, 5, 4])
print(data1)
print(data2)
new_data = torch.cat([data1, data2], dim=0) # 1. 按0维度拼接
print(new_data) # shape:torch.Size([6, 5, 4])
new_data = torch.cat([data1, data2], dim=1) # 2. 按1维度拼接
print(new_data) # shape:torch.Size([3, 10, 4])
new_data = torch.cat([data1, data2], dim=2) # 3. 按2维度拼接
print(new_data) # shape:torch.Size([3, 5, 8])
import torch
data1= torch.randint(0, 10, [4, 5])
data2= torch.randint(0, 10, [4, 5])
print(data1)
print(data2)
new_data = torch.stack([data1, data2], dim=0) # 在0维度叠加,升维!
print(new_data) # shape:torch.Size([2, 4, 5])
new_data = torch.stack([data1, data2], dim=1) # 在1维度叠加,升维!
print(new_data) # shape:torch.Size([4, 2, 5])
new_data = torch.stack([data1, data2], dim=2) # 在2维度叠加,升维!
print(new_data) # shape:torch.Size([4, 5, 2])
以上。