PyTorch入门3——张量索引、张量拼接

PyTorch入门 3 —— 张量索引、张量拼接

  • 张量索引
    • 简单行、列索引
    • 列表索引
    • 范围索引
    • 布尔索引
    • 多维索引
  • 张量拼接
    • torch.cat 函数的使用
    • torch.stack 函数的使用

张量索引

在操作张量时,经常需要进行获取或者修改张量元素值的操作,这时候各种张量的花式索引操作就派上大用场了。Pytorch 中对张量进行索引有多种方法,比如:简单行列索引、列表索引、范围索引、布尔索引、多维索引等等,可以根据需要灵活使用。

简单行、列索引

import torch

data = torch.randint(0, 10, [4, 5])  # 四行五列的二维张量
print(data)
print(data[2])     # 获取第三行数据,返回一维张量
print(data[:, 1])  # 获取第二列数据,返回一维张量
print(data[1, 2])  # 获取第二行的第三列数据,返回零维张量
print(data[1][2])  # 同上

列表索引

import torch

data = torch.randint(0, 10, [4, 5])  # 四行五列的二维张量
print(data)
print(data[[1,0,2])                # 返回下标为1行、0行、2行共三行数据组成的3行5列的二维张量
print(data[[0,1,3], [3,2,4]])      # 返回下标为0行3列、1行2列、3行4列三个数据组成的一维张量
print(data[[[0],[1]], [[3],[4]]])  # 返回下标为0行3列、1行4列两个数据组成的2行1列的二维张量
print(data[[0,1], [[3],[4]]])      # 返回下标为0行3列、1行3列、0行4列、1行4列四个数据组成的2行2列的二维张量
print(data[[0,1], [[1,2],[0,4]]])  # 返回下标为0行1列、1行2列、0行0列、1行4列四个数据组成的2行2列的二维张量
print(data[[[1],[0]], [3,4]])      # 返回下标为1行3列、1行4列、0行3列、0行4列四个数据组成的2行2列的二维张量
print(data[[[1,3],[0,2]], [3,4]])  # 返回下标为1行3列、3行4列、0行3列、2行4列四个数据组成的2行2列的二维张量

范围索引

import torch

data = torch.randint(0, 10, [4, 5])  # 四行五列的二维张量
print(data)
print(data[:3, 4])   # 返回前三行的第五列数据组成的一维张量
print(data[:3, [0,2,4]])  # 返回前三行的第一三五列数据组成的二维张量
print(data[:3, :4])  # 返回前三行的前四列数据组成的二维张量
print(data[2:, :4])  # 返回第三行到末行的前四列数据组成的二维张量

布尔索引

import torch

data = torch.randint(0, 10, [4, 5])  # 四行五列的二维张量
print(data)
print(data[data > 5])  # 返回所有大于5的元素组成的一维张量
print(data[[True,False,True,False]])  # 返回第一行与第三行数据组成的二维张量
print(data[1:, [True,False,True,False,True]])  # 返回第二行到末行的第一三五列数据组成的二维张量
print(data[data[:, 2] > 5])  # 返回第三列大于5的行数据组成的二维张量
print(data[:, data[1] > 5])  # 返回第二行大于5的列数据组成的二维张量

多维索引

data = torch.randint(0, 10, [3, 4, 5])  # 三片四行五列的三维张量
print(data)
print(data[0, :, :])  # 返回第一片所有数据,四行五列的二维张量
print(data[:, 0, :])  # 返回所有片的第一行数据,三行五列的二维张量
print(data[:, :, 0])  # 返回所有片的第一列数据,三行四列的二维张量

张量拼接

张量的拼接操作在神经网络搭建过程中是非常常用的方法,比如:残差网络、注意力机制中都使用到了张量拼接。使用 cat 函数可以将张量按照指定维度拼接起来,并不会升维;而 stack 函数可以将张量在指定维度叠加起来,会升维!

torch.cat 函数的使用

import torch

data1 = torch.randint(0, 10, [3, 5, 4])
data2 = torch.randint(0, 10, [3, 5, 4])
print(data1)
print(data2)

new_data = torch.cat([data1, data2], dim=0)  # 1. 按0维度拼接
print(new_data)  # shape:torch.Size([6, 5, 4])

new_data = torch.cat([data1, data2], dim=1)  # 2. 按1维度拼接
print(new_data)  # shape:torch.Size([3, 10, 4])

new_data = torch.cat([data1, data2], dim=2)  # 3. 按2维度拼接
print(new_data)  # shape:torch.Size([3, 5, 8])

torch.stack 函数的使用

import torch

data1= torch.randint(0, 10, [4, 5])
data2= torch.randint(0, 10, [4, 5])
print(data1)
print(data2)

new_data = torch.stack([data1, data2], dim=0)  # 在0维度叠加,升维!
print(new_data)  # shape:torch.Size([2, 4, 5])

new_data = torch.stack([data1, data2], dim=1)  # 在1维度叠加,升维!
print(new_data)  # shape:torch.Size([4, 2, 5])

new_data = torch.stack([data1, data2], dim=2)  # 在2维度叠加,升维!
print(new_data)  # shape:torch.Size([4, 5, 2])

以上。

你可能感兴趣的:(Pytorch,深度学习,人工智能,pytorch)