pytorch学习笔记(六)——pytorch进阶教程之tensor高阶操作

pytorch学习笔记(六)——pytorch进阶教程之高阶操作

  • 目录
    • where
    • gather

目录

where

pytorch学习笔记(六)——pytorch进阶教程之tensor高阶操作_第1张图片
新生成的tensor取决与输入x,y和条件condition。condition是一个矩阵,如果元素是1对应x,元素是0对应y。
示例代码:
pytorch学习笔记(六)——pytorch进阶教程之tensor高阶操作_第2张图片
cond>0.5返回的是一个[2,2]size的矩阵,大于0.5对应元素为1,否则为0。
只有最右下角元素不成立返回0,所以where函数得到的tensor最右下角元素值和b一致,其余和a一致。
引入where操作的用途:对于一些不规则的赋值方式,直接使用复制语句是非常不方便的,假如没有引入where要实现根据条件进行赋值,则需要多个for循环来实现,极大的阻碍了并行,代码运行过程几乎都是在cpu上进行的,引入where是为了方便该类操作可以在gpu上并行工作从而提高性能。

gather

pytorch学习笔记(六)——pytorch进阶教程之tensor高阶操作_第3张图片
gather函数主要是用于通过索引收集数值,参数一是要索引的数据表格,参数二是要进行索引收集的dim,参数三是用来进行索引收集的索引表格。
如有数据[dog,cat,whale]作为数据输入,dim=0,索引表格为[1,0,1,2],则输出为[cat,dog,cat,whale]。
引入gather操作的用途:用where类似,增加代码的并行度,运行于gpu,提高性能。
示例代码:
pytorch学习笔记(六)——pytorch进阶教程之tensor高阶操作_第4张图片
randn生成[4,10]size的tensor,意义是代表四张图片,总共十个类别,每张图片分属于每个类别的概率。topk得到最有可能的三类别,输出为[4,3]size的数据和对应索引,将label表格扩展为[4,10],在dim1上利用索引查label表,得到输出

你可能感兴趣的:(pytorch)