在使用pytorch框架复现模型的时候,我们需要再forward()函数中定义模型的逻辑,这时就要对模型参数使用一些运算,这里简单介绍一下pytorch框架下的两种常用的乘法运算。
按元素乘,即张量的对应元素相乘,将每个位置上相乘的结果作为返回值,使用“*”实现。看一下例子:
import torch
a = torch.Tensor([[1, 2], [3, 4]])
print('a', a)
b = torch.Tensor([[5, 6], [7, 8]])
print('b', b)
z = a * b # 按元素乘
print('z', z)
a tensor([[1., 2.],
[3., 4.]])
b tensor([[5., 6.],
[7., 8.]])
z tensor([[ 5., 12.],
[21., 32.]])
pytorch中的按元素乘支持广播,看一下例子:
import torch
a = torch.rand(2, 10)
print('a', a)
print(a.size())
b = torch.Tensor([[0.2], [0.8]])
print('b', b)
print(b.size())
print('a * b', a * b) # 按元素乘,广播
a tensor([[0.8498, 0.3875, 0.8588, 0.6866, 0.9692, 0.9897, 0.3961, 0.2562, 0.0801,
0.9533],
[0.1715, 0.7883, 0.7576, 0.1322, 0.8059, 0.7220, 0.1398, 0.0333, 0.4146,
0.4013]])
torch.Size([2, 10])
b tensor([[0.2000],
[0.8000]])
torch.Size([2, 1])
a * b tensor([[0.1700, 0.0775, 0.1718, 0.1373, 0.1938, 0.1979, 0.0792, 0.0512, 0.0160,
0.1907],
[0.1372, 0.6306, 0.6061, 0.1057, 0.6447, 0.5776, 0.1119, 0.0266, 0.3317,
0.3210]])
矩阵乘法就是我们熟悉的矩阵的乘法,可以使用“@”实现。看一下例子:
import torch
a = torch.Tensor([[1, 2], [3, 4]])
print('a', a)
b = torch.Tensor([[5, 6], [7, 8]])
print('b', b)
f = a @ b # 矩阵乘法
print('f', f)
a tensor([[1., 2.],
[3., 4.]])
b tensor([[5., 6.],
[7., 8.]])
f tensor([[19., 22.],
[43., 50.]])
最近在学习pytorch框架,可以加QQ3408649893交流。