Pytorch常见函数总结

0. 介绍

持续更新
主要是总结pytorch中有关Tensor的常规操作。
其他博客
pytorch: Tensor 常用操作
https://blog.csdn.net/xholes/article/details/81667211
pytorch核心初探:
https://www.douban.com/note/720875470/

1. torch.squeeze(), torch.unsqueeze()

1.1 unsqueeze

Pytorch常见函数总结_第1张图片
unsqueeze会在input的第dim维中添加一个维度。假如是四维Tensor的话,在第二维扩展(dim从0开始),则原先的第(i, j, k, l)个元素映射到第(i, j, 0, k, l)个元素,例子如下:

A  = torch.rand([3,3,4, 5])
print(torch.unsqueeze(A, 2).shape)
# 结果为torch.Size([3, 3, 1, 4, 5])

其他版本的使用:
unsqueeze(dim) → Tensor:See torch.unsqueeze()
unsqueeze_(dim) → Tensor:In-place version of unsqueeze()

1.1 squeeze

Pytorch常见函数总结_第2张图片
同理

2. torch.Tensor.expand()

Pytorch常见函数总结_第3张图片
将Tensor的某个维度为1的指标扩展为多维,比方说上面的x.expand(3, 4)。
Note:这里的扩展并非实际分配内存,如果令y = x.expand(3, 4),y[0, 0] = 1000,那么y的第一行都会变为1000

3. torch.gather(), torch.Tensor.gather()

Pytorch常见函数总结_第4张图片
outputindex的shape相同,并且inputindex只有第dim个指标的维数不同。

4. torch.clamp

将输入input张量每个元素的夹紧到区间 [min,max][min,max],并返回结果到一个新张量。
https://blog.csdn.net/u013230189/article/details/82627375

你可能感兴趣的:(总结)