nn.Linear

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
This module performs a linear (strictly, affine) mapping from the input space to the output space.

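  • in_features: dimensionality of the input data (the size of the last dimension of each input sample)
  • out_features: dimensionality of the output data (the size of the last dimension of each output sample)
  • bias: if False, the layer learns no additive bias b (default True)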
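A minimal sketch of how these parameters show up in the layer's tensors, using only the standard torch.nn.Linear API. It also shows that the layer acts on the last dimension only, so leading dimensions (here a hypothetical (4, 5) prefix chosen for illustration) are carried through unchanged.

```python
import torch

layer = torch.nn.Linear(in_features=20, out_features=30)

# weight is stored as (out_features, in_features), bias as (out_features,)
print(layer.weight.shape)  # torch.Size([30, 20])
print(layer.bias.shape)    # torch.Size([30])

# Only the last dimension must equal in_features; leading dims are kept
x = torch.randn(4, 5, 20)
y = layer(x)
print(y.shape)             # torch.Size([4, 5, 30])

# With bias=False the layer has no bias parameter (it is None)
no_bias = torch.nn.Linear(20, 30, bias=False)
print(no_bias.bias)        # None
```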

How the computation works:

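Suppose we have a batch of data x containing 128 samples, each of which is 20-dimensional, and we want to map x from 20-dimensional space to y in 30-dimensional space. The computation is as follows, where W is the weight of the Linear layer and b is its bias: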

$$y = xW^{T} + b$$

where

$$
x = \begin{pmatrix} x_{11} & x_{12} & \cdots & x_{1,20} \\ x_{21} & x_{22} & \cdots & x_{2,20} \\ \vdots & \vdots & \ddots & \vdots \\ x_{128,1} & x_{128,2} & \cdots & x_{128,20} \end{pmatrix}_{128\times 20}, \qquad
W = \begin{pmatrix} w_{11} & w_{12} & \cdots & w_{1,20} \\ w_{21} & w_{22} & \cdots & w_{2,20} \\ \vdots & \vdots & \ddots & \vdots \\ w_{30,1} & w_{30,2} & \cdots & w_{30,20} \end{pmatrix}_{30\times 20}, \qquad
b = (b_1, b_2, \dots, b_{30})
$$
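Read entry by entry, y = xW^T + b says that output feature j of sample i is the dot product of row i of x with row j of W (whose entries are w_{jk}), plus the bias entry b_j; the same b is added for every sample in the batch:

$$
y_{ij} = \sum_{k=1}^{20} x_{ik}\, w_{jk} + b_j, \qquad i = 1, \dots, 128,\; j = 1, \dots, 30
$$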

Written out in full (b, a vector of length 30, is added to every one of the 128 rows of the product):

$$
\begin{pmatrix} x_{11} & x_{12} & \cdots & x_{1,20} \\ x_{21} & x_{22} & \cdots & x_{2,20} \\ \vdots & \vdots & \ddots & \vdots \\ x_{128,1} & x_{128,2} & \cdots & x_{128,20} \end{pmatrix}_{128\times 20}
\begin{pmatrix} w_{11} & w_{21} & \cdots & w_{30,1} \\ w_{12} & w_{22} & \cdots & w_{30,2} \\ \vdots & \vdots & \ddots & \vdots \\ w_{1,20} & w_{2,20} & \cdots & w_{30,20} \end{pmatrix}_{20\times 30}
+ b
= \begin{pmatrix} y_{11} & y_{12} & \cdots & y_{1,30} \\ y_{21} & y_{22} & \cdots & y_{2,30} \\ \vdots & \vdots & \ddots & \vdots \\ y_{128,1} & y_{128,2} & \cdots & y_{128,30} \end{pmatrix}_{128\times 30}
$$
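To connect the matrix picture with the element-wise sum y_{ij} = Σ_k x_{ik} w_{jk} + b_j, here is a small sketch that recomputes every output entry with explicit loops and compares the result with the layer's own output. It assumes a smaller hypothetical batch of 4 samples (instead of 128) purely to keep the loops quick.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 20)                 # 4 samples, 20 features each
linear = torch.nn.Linear(20, 30)
W, b = linear.weight, linear.bias      # W: (30, 20), b: (30,)

# y[i, j] = sum_k x[i, k] * W[j, k] + b[j]
y_loop = torch.empty(4, 30)
with torch.no_grad():
    for i in range(4):
        for j in range(30):
            y_loop[i, j] = (x[i] * W[j]).sum() + b[j]
    y_layer = linear(x)

# Summation order differs from the BLAS kernel, so compare with a tolerance
print(torch.allclose(y_loop, y_layer, atol=1e-5))  # True
```

Note that W[j] is row j of the (30, 20) weight, so the transpose in xW^T shows up here simply as indexing W by output feature first.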

A simple example

import torch


x = torch.randn(128, 20)  # input shape is (128, 20)
linear = torch.nn.Linear(20, 30)  # in_features=20, out_features=30
output = linear(x)

print('linear.weight.shape:   ', linear.weight.shape)
print('linear.bias.shape:     ', linear.bias.shape)
print('output.shape:          ', output.shape)

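# Equivalent to torch.mm(input, torch.t(m.weight)) + m.bias, written with our variable names
# .t() transposes the weight, giving W^T with shape (20, 30)
ans = torch.mm(x, linear.weight.t()) + linear.bias
print('ans.shape:             ', ans.shape)
# allclose tolerates last-bit rounding differences between kernels; torch.equal requires bit-identical results
print(torch.allclose(ans, output))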


'''output:
linear.weight.shape:    torch.Size([30, 20])
linear.bias.shape:      torch.Size([30])
output.shape:           torch.Size([128, 30])
ans.shape:              torch.Size([128, 30])
True
'''
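The same computation is also available as a plain function, torch.nn.functional.linear(input, weight, bias), which is what nn.Linear's forward calls on its own weight and bias. A short sketch comparing the module call, the functional form, and a single fused torch.addmm call (bias + x @ W^T); it uses torch.allclose rather than torch.equal because different kernels are not guaranteed to round identically.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(128, 20)
linear = torch.nn.Linear(20, 30)

out_module = linear(x)                                      # module call
out_func = F.linear(x, linear.weight, linear.bias)          # functional form
out_addmm = torch.addmm(linear.bias, x, linear.weight.t())  # bias broadcast over rows + x @ W^T

print(out_func.shape)                                    # torch.Size([128, 30])
print(torch.allclose(out_module, out_func))              # True
print(torch.allclose(out_module, out_addmm, atol=1e-5))  # True
```

The functional form is handy when the weight is produced by some other computation rather than stored in a module.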
