PyTorch入门实战教程笔记(九):基础张量操作5

PyTorch入门实战教程笔记(九):基础张量操作5:高阶操作 where&gather

1. Where

torch.where(condition, x, y) -->tensor,一共3个参数,第一个参数条件,判断其Ture or False。 如下:
PyTorch入门实战教程笔记(九):基础张量操作5_第1张图片
比如,现在将A, B的值根据condition传递给C,可以从C[0]=A[0],C[1]=B[1], 此外,设计到具体的数值,也可以进行替换,比如C[0, 0, 0]=A[0, 0, 0], C[0, 0, 1]=B[0, 0, 1],C[0, 0, 2]=A[0, 0, 2],…具体例子中,如下tensor,将大于0.5的全部替换为a(全为0),将剩余的全部替换为b(全为1),操作后结果如下:
PyTorch入门实战教程笔记(九):基础张量操作5_第2张图片

2. Gather

gather可以称为收集操作,torch.gether(input, dim, index, out=None) ->tensor。第一个参数为数据,比如【dog, cat, whale】等,第二个参数为dim,在该维度上查找收集,第三个参数为index索引,比如第一个为1,第二个为0,第三个为1,第四个为2,则对应的数据为【cat,dog,cat,whale】。对于图像处理,比如最后的概率为[0.1 0.8 0.05 0.05 …] top1为0.8,对应的索引号为1,而1对应的label为dog,那么通过gether之后就能够知道top1为dog这个数据。再看一个具体的实例,比如现在有数据的shape为[4,10], 通过topk来求出top3的数据及索引值,那么idx[1]的shape为[4,3],假设数据的标签为对应的0对应100, 1对应101, … ,9对应109,那么通过torch.gather( label.expend(4,10), dim=1,index=idx.long() ), 第一个参数将label的shape扩展为[4, 10], 第二个参数表示在idx的 dim=1上操作,第三个参数表示index的长度(shape)为[4, 3]。即可求出每一个bach对应的top的标签数据,具体如下:
PyTorch入门实战教程笔记(九):基础张量操作5_第3张图片

你可能感兴趣的:(PyTorch实战学习笔记,pytorch,深度学习,神经网络)