torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
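torch.nn.Linear is a module that maps its input from one feature space to another with a linear transformation. Strictly speaking it is an affine map, because a bias is added by default.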
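The signature above also has a bias flag. The following is a minimal sketch with random inputs and sizes chosen to match the example below; the variable name linear_no_bias is only illustrative. With bias=False no bias parameter is registered, so linear.bias is None and the layer computes only x W^T:

```python
import torch

# bias=False: the layer learns only a weight matrix of shape (out_features, in_features)
linear_no_bias = torch.nn.Linear(20, 30, bias=False)
x = torch.randn(128, 20)

print(linear_no_bias.bias)          # None: no bias parameter is registered
print(linear_no_bias.weight.shape)  # torch.Size([30, 20])

# Without a bias, the output is just x @ W^T
out = linear_no_bias(x)
print(torch.allclose(out, x @ linear_no_bias.weight.t(), atol=1e-6))  # True
```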
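Suppose we have a batch of 128 samples x, each of dimension 20, and we want to map every 20-dimensional x to a 30-dimensional vector y. The computation is shown below, where W is the weight of the Linear layer and b is its bias: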
$$y = xW^{T} + b$$
where

$$x=\begin{pmatrix} x_{11} & x_{12} & \cdots & x_{1,20} \\ x_{21} & x_{22} & \cdots & x_{2,20} \\ \vdots & \vdots & \ddots & \vdots \\ x_{128,1} & x_{128,2} & \cdots & x_{128,20} \end{pmatrix}_{128\times 20}, \qquad W=\begin{pmatrix} w_{11} & w_{12} & \cdots & w_{1,20} \\ w_{21} & w_{22} & \cdots & w_{2,20} \\ \vdots & \vdots & \ddots & \vdots \\ w_{30,1} & w_{30,2} & \cdots & w_{30,20} \end{pmatrix}_{30\times 20}$$

and b is a vector of length 30 that is added to every row of xW^T.
Written out in full, with b added to each row of the product:

$$\begin{pmatrix} x_{11} & x_{12} & \cdots & x_{1,20} \\ x_{21} & x_{22} & \cdots & x_{2,20} \\ \vdots & \vdots & \ddots & \vdots \\ x_{128,1} & x_{128,2} & \cdots & x_{128,20} \end{pmatrix}_{128\times 20} \begin{pmatrix} w_{11} & w_{21} & \cdots & w_{30,1} \\ w_{12} & w_{22} & \cdots & w_{30,2} \\ \vdots & \vdots & \ddots & \vdots \\ w_{1,20} & w_{2,20} & \cdots & w_{30,20} \end{pmatrix}_{20\times 30} + b = \begin{pmatrix} y_{11} & y_{12} & \cdots & y_{1,30} \\ y_{21} & y_{22} & \cdots & y_{2,30} \\ \vdots & \vdots & \ddots & \vdots \\ y_{128,1} & y_{128,2} & \cdots & y_{128,30} \end{pmatrix}_{128\times 30}$$
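Element-wise, each output entry is y[i, j] = sum over k = 1..20 of x[i, k] * W[j, k], plus b[j]. In other words, row i of x is dotted with row j of W (which is column j of W^T). The following is a minimal sketch that checks this against the layer's output. It uses torch.einsum for the whole matrix and an explicit sum for one entry. The indices i=3, j=7 and the fixed seed are arbitrary choices:

```python
import torch

torch.manual_seed(0)                 # fixed seed so the run is reproducible
x = torch.randn(128, 20)             # batch of 128 samples, 20 features each
linear = torch.nn.Linear(20, 30)
W, b = linear.weight, linear.bias    # W: (30, 20), b: (30,)

# einsum spells out y[i, j] = sum_k x[i, k] * W[j, k] + b[j] for all i, j at once
y_einsum = torch.einsum('ik,jk->ij', x, W) + b

# The same formula for a single output entry, written as an explicit sum over k
i, j = 3, 7
y_ij = sum(x[i, k] * W[j, k] for k in range(20)) + b[j]

y = linear(x)
print(y_einsum.shape)                            # torch.Size([128, 30])
print(torch.allclose(y_einsum, y, atol=1e-6))    # True
print(torch.allclose(y_ij, y[i, j], atol=1e-6))  # True
```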
```python
import torch

x = torch.randn(128, 20)          # input shape: (128, 20)
linear = torch.nn.Linear(20, 30)  # in_features=20, out_features=30
output = linear(x)
print('linear.weight.shape: ', linear.weight.shape)
print('linear.bias.shape: ', linear.bias.shape)
print('output.shape: ', output.shape)

# ans = torch.mm(input, torch.t(m.weight)) + m.bias is equivalent to the line below
# .t() transposes the weight W, so this computes x @ W^T + b
ans = torch.mm(x, linear.weight.t()) + linear.bias
print('ans.shape: ', ans.shape)
# allclose instead of equal: the two code paths are not guaranteed to round identically
print(torch.allclose(ans, output, atol=1e-6))
'''output:
linear.weight.shape: torch.Size([30, 20])
linear.bias.shape: torch.Size([30])
output.shape: torch.Size([128, 30])
ans.shape: torch.Size([128, 30])
True
'''
```
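Two related points are sketched below using the same 20 → 30 setup. The 3-D input shape is an illustrative choice.

- torch.nn.functional.linear(input, weight, bias) performs the same x W^T + b computation without creating a module.
- nn.Linear accepts inputs of shape (*, in_features) and transforms only the last dimension.

torch.mm only accepts 2-D matrices, so the manual check for the 3-D case uses the @ operator (torch.matmul), which broadcasts over leading dimensions:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
linear = torch.nn.Linear(20, 30)

# 1) The functional form computes x @ W^T + b without creating a module
x = torch.randn(128, 20)
out_module = linear(x)
out_functional = F.linear(x, linear.weight, linear.bias)
print(torch.allclose(out_module, out_functional))  # True

# 2) Only the last dimension must equal in_features; leading dimensions are kept
x3d = torch.randn(4, 128, 20)  # e.g. 4 groups of 128 samples each
out3d = linear(x3d)
print(out3d.shape)             # torch.Size([4, 128, 30])
print(torch.allclose(out3d, x3d @ linear.weight.t() + linear.bias, atol=1e-6))  # True
```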