pytorch中的一些函数---torch.cat()、index_select()、torch.gather()

torch.cat()

torch.cat是将两个张量(tensor)拼接在一起,cat是concatnate的意思,即拼接。
其中dim代表维度,0代表行,1代表列。
pytorch中的一些函数---torch.cat()、index_select()、torch.gather()_第1张图片
pytorch中的一些函数---torch.cat()、index_select()、torch.gather()_第2张图片

index_select()

index_select(x, 1, indices)

1代表dim(维度)为1,即列。

indices是筛选的索引序号。当dim为0时,就是筛选行的索引,当dim为1时,就是筛选列的索引。
pytorch中的一些函数---torch.cat()、index_select()、torch.gather()_第3张图片
注:torch.index_select(y, 1, indices)的结果等价于y.index_select(1, indices)

torch.gather()

torch.gather(input, dim, index, out=None)中,input是输入,dim代表维度,index表示的是所选择的维度上的索引,out表示输出。
pytorch中的一些函数---torch.cat()、index_select()、torch.gather()_第4张图片
注:torch.gather(b, dim=0, index=index_2)的结果等价于b.gather(0, index_2)
注:index的类型必须是LongTensor类型的

gather在one-hot为输出的多分类问题中,可以把最大值坐标作为index传进去,然后提取到每一行的正确预测结果。

你可能感兴趣的:(pytorch)