pytorch下index数据

1. 通过下标取数据:index_select

torch.index_select(input, dim, index)
input: 待index的tensor
dim: 在哪个维度上index
index: torch.longtensor, 在dim维上根据index中的每个元素取数据,组成新的tensor。

a = torch.randn(5, 10, 20)
b = torch.tensor([2, 4, 6]).to(torch.long)
c = a.index_select(a, dim=1, index=b)
c.shape: [5, 3, 20],分别由a的dim1的第2、4、6行构成

2. 通过mask取数据

a = torch.randn(5, 10, 20)
mask = [True, True, False, True, False]
c = a[mask, :, :]
c.shape: [3, 10, 20], 由a的第0维的第0、1、3行组成

有时想在一个更密集的step上取一个稀疏的子集出来,手动输入mask很慢,此时可以用numpy的isin:

x = torch.randn(6, 10)
a = [5, 10, 15, 20, 25, 30]
b = [5, 15, 25]
index = np.isin(a, b)   # [True, False, True, False, True, False]
x_sample = x[index, :] 

isin中,比较第一个参数中的每个元素是否在第二个参数中,返回一个跟第一个参数一样大的布尔张量。

3. 双线性差值取数据

torch.functional.grid_sample,可以根据下标到原tensor中index数据,用最近邻或其他差值方法差值得到新的数据。

你可能感兴趣的:(pytorch下index数据)