pyTorch.randn()、rand()、cat()、pow()、scatter_()、.squeeze() 、.unsqueeze()、gather()

文章目录

  • torch.randn()
  • torch.rand()
  • torch.cat()
  • torch.pow()
  • .item()
  • .scatter_()
  • .squeeze()
  • .unsqueeze()
  • .gather()

torch.randn()

torch.randn(*sizes, out=None) → Tensor

返回一个张量,包含了从标准正态分布(均值为0,方差为 1,即高斯白噪声)中抽取一组随机数,形状由可变参数sizes定义。
参数:

  • sizes(int…): 整数序列,定义了输出形状
  • out (Tensor, optinal) : 结果张量

例子:

>>> torch.randn(4)

-0.1145
 0.0094
-1.1717
 0.9846
[torch.FloatTensor of size 4]

>>> torch.randn(2, 3)

 1.4339  0.3351 -1.0999
 1.5458 -0.9643 -0.3558
[torch.FloatTensor of size 2x3]

torch.rand()

torch.rand(*sizes, out=None) → Tensor

返回一个张量,包含了从区间[0,1)的均匀分布中抽取的一组随机数,形状由可变参数sizes 定义。
参数:

  • sizes (int…) : 整数序列,定义了输出形状
  • out (Tensor, optinal) : 结果张量
    例子:
>>> torch.rand(4)

 0.9193
 0.3347
 0.3232
 0.7715
[torch.FloatTensor of size 4]

>>> torch.rand(2, 3)

 0.5010  0.5140  0.0719
 0.1435  0.5636  0.0538
[torch.FloatTensor of size 2x3]

torch.cat()

torch.cat(inputs, dimension=0) → Tensor

在给定维度上对输入的张量序列seq 进行连接操作
参数:

  • inputs (sequence of Tensors) – 可以是任意相同Tensor 类型的python 序列
  • dimension (int, optional) – 沿着此维连接张量序列。

例子:

x = torch.randn(2, 3)
print('x',x,'\n',x.shape)
y = torch.cat((x,x,x),0)
print('y',y,'\n',y.shape)
z = torch.cat((x,x,x),1)
print('z',z,'\n',z.shape)

pyTorch.randn()、rand()、cat()、pow()、scatter_()、.squeeze() 、.unsqueeze()、gather()_第1张图片

torch.pow()

torch.pow(input, exponent, out=None)

对输入的 input 按元素求 exponent 次幂值,并返回结果张量。
幂值exponent 可以为单一 float 数或者与input相同元素数的张量。
当幂值为标量时,执行操作:
o u t i = x e x p o n e n t out_i=x^{exponent} outi=xexponent
当幂值为张量时,执行操作:
o u t i = x e x p o n e n t i out_i=x^{exponent_i} outi=xexponenti
参数:

  • input (Tensor) – 输入张量
  • exponent (float or Tensor) – 幂值
  • out (Tensor, optional) – 输出张量参数:
    例子:
x = torch.arange(1, 5)
y = torch.arange(1, 5)
z = torch.pow(x,2)
k = torch.pow(x,y)
print(x,'\n',z)
print(y,'\n',k)

在这里插入图片描述

torch.pow(base, input, out=None)

base标量浮点值,input张量, 返回的输出张量 out 与输入张量相同形状。
执行操作为:
o u t i = b a s e i n p u t i out_i=base^{input_i} outi=baseinputi
参数:

  • base (float) – 标量值,指数的底
  • input ( Tensor) – 幂值
  • out (Tensor, optional) – 输出张量

例子:

exp = torch.arange(1, 4)
print(exp)
base = 2
a = torch.pow(base,exp)
print(a)

在这里插入图片描述

.item()

只含一个元素的张量可以用item得到元素值,请注意这里的print(x)和print(x.item())值是不一样的,一个是打印张量,一个是打印元素,例子如下:

y = torch.randn(1)
print(y)
print(y.item())

在这里插入图片描述
如果 x 不是只含一个元素的张量可以吗?不行的!但是可以用这种方法访问特定位置的元素~

x = torch.randn(2,1)
print(x)
print(x[1,0].item())

在这里插入图片描述

.scatter_()

函数的参数为:

scatter_(dim,index,src)
dim:维度,表示在第几维上操作;
index:索引(数组),表示位置;
src:原数组tensor,即用来填充的tensor。

src和index的维度应该是一致的!!
这个函数的主要作用是按照一定规则用src中的值去填充/替换 要操作tensor的值。即把 input 数组中的数据进行重新分配。index 中表示了要把原数组中的数据分配到 output 数组中的位置,如果未指定,则填充0
例子1:

input = torch.tensor([[1.0],[2.0]])
index = torch.tensor([[2],[3]])
output = torch.zeros(3,4)
output.scatter_(1,index,input)

tensor([[0., 0., 1., 0.],[0., 0., 0., 2.],[0., 0., 0., 0.]])
首先,src和index的维度应该是一致的。由于dim=1,所以行不变,只变化列。

  • 对于output的第一行,要改变的列索引是2,所以output索引为2处的值由0替换为input的第一个值1;
  • 对于output的第二行,要改变的列索引是3,所以output索引为3处的值由0替换为input的第二个值2;

例子2:

import torch 
input = torch.randn(2, 4)
print(input)
output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output = output.scatter(1, index, input)
print(output)

pyTorch.randn()、rand()、cat()、pow()、scatter_()、.squeeze() 、.unsqueeze()、gather()_第2张图片
从结果可以看出,把 input 的第一行元素 按照 [3, 1, 2, 0] 的顺序重新排列,作为output第一行的前四列数据;同样地,把 input 的第二行元素 按照 [1, 2, 0, 3] 的顺序从新排列为 output 第二行的前四个元素。

例子3:

用scatter_( )函数生成one-hot向量

import torch 
output = torch.zeros(4, 5)
index = torch.tensor([[3],[1],[2],[0]])
output = output.scatter(1, index, 1)
print(output)

pyTorch.randn()、rand()、cat()、pow()、scatter_()、.squeeze() 、.unsqueeze()、gather()_第3张图片

.squeeze()

主要对数据的维度进行压缩,去掉维数为1的的维度

  • a.squeeze():将a中所有维数为1的维度删掉。维数不为1的维度没有影响
  • a.squeeze(N):将a中指定位置处(维度序号为N处)的维数为1的维度删掉。等价于 b = torch.squeeze(a,N)
import torch
a = torch.arange(10).view(2,1,5)
print(a)
print(a.shape)
b = a.squeeze()
print(b)
print(b.shape)

pyTorch.randn()、rand()、cat()、pow()、scatter_()、.squeeze() 、.unsqueeze()、gather()_第4张图片

c = a.view(2,1,5,1)
print(c)
d = c.squeeze(3)
print('d',d)
f = torch.squeeze(c,3) # 和d = c.squeeze(3)等价
print('f',f)
print('f.shape',f.shape)
h = c.squeeze()
print('h',h)
print('h.shape',h.shape)

pyTorch.randn()、rand()、cat()、pow()、scatter_()、.squeeze() 、.unsqueeze()、gather()_第5张图片

.unsqueeze()

与.squeeze()作用相反,.unsqueeze()主要是对数据维度进行扩充,在指定位置加上维数为一的维度

  • a.unsqueeze(N):将a中指定位置处(维度索引为N)增加维数为1的维度
  • b = torch.unsqueeze(a,N):作用同上行
a = torch.randn(2,3)
print('a',a)
print('a.shape',a.shape)
b = torch.unsqueeze(a,2)
print(b)
print('b.shape',b.shape)
c = a.unsqueeze(2)  # 等价于b = torch.unsqueeze(a,2)
print('c',c)
print('c.shape',c.shape)

pyTorch.randn()、rand()、cat()、pow()、scatter_()、.squeeze() 、.unsqueeze()、gather()_第6张图片

.gather()

欢迎参考本小菜的另一篇博客的详细介绍:
PyTorch中gather()函数的用法

你可能感兴趣的:(PyTorch)