1.使用 index_select 函数
t=torch.arange(24).reshape(2,6,2)
index = torch.tensor([3,4]) # 选择第index=3和4的元素
ptf_tensor(t,'t')
ptf_tensor(t.index_select(1,index)) #在 dim=1 上选取
结果如下,可仔细观察索引方法
tensor([[[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15],
[16, 17],
[18, 19],
[20, 21],
[22, 23]]])
##############################################################
tensor([[[ 6, 7],
[ 8, 9]],
[[18, 19],
[20, 21]]])
2.np.random.choice()方法
import torch
import numpy as np
a = torch.rand(32,256,4,4,4).cuda()
# print("a.size():",a.size())
b = a.view(32,-1,2,2,2).cuda()
# print("b.size():",b.size())
B,N,C,D,E = b.shape
t = [i for i in range(N)]
index = torch.LongTensor(np.random.choice(t, 512, replace=True)).cuda()
## 这里np.random.choice()从t中选取512个随机数构成索引tensor
c = torch.index_select(b, 1, index).cuda() # 在b的dim=1维上进行索引
print(c.size())
c的结果如下
torch.Size([32, 512, 2, 2, 2])
可以看到索引成功
其中参数replace
用来设置是否可以取相同元素:True表示可以取相同数字;False表示不可以取相同数字。默认是True
这种错一般是在GPU上出现索引越界导致的,建议在CPU上测试寻找越界点,因为GPU上debug不太直观,容易报你也看不懂的错误。
一个直观的例子如下,在cuda上会让你CUDA_LAUNCH_BLOCKING=1去debug,所以考虑直接在CPU上测试。
data = [[1, 2, 14],[3, 4, 5]]
x_data = torch.tensor(data).cuda()
rand_data = torch.rand(13).cuda()
rand_data[x_data[0]]
>>>RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.