PyTorch:切片函数index_select()

index_select()函数有两种用法。
第一种是将被切片的函数作为参数传入index_select()中

torch.index_select(input, dim, index, out=None)

还有一种是调用张量内置的index_select()函数。

input.index_select(dim, index)

index_select()函数的作用是针对张量input,在它的dim维度上切取index指定的范围切片。

参数:
input:被操作的张量
dim:维度
index:一维Tensor,表示索引下标的范围

例如

import torch
a = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7]])

b = torch.index_select(a, 0, torch.tensor([1]))
print(b)

c = torch.index_select(a, 1, torch.tensor([1,3]))
print(c)

输出为


这里维度dim从0开始算,则b表示在第0维(即行)上,切下下标为1的行;c表示在第1维(即列)上,切下下标为1和3的列。

你可能感兴趣的:(PyTorch:切片函数index_select())