pytorch nn.linear() 原理

0. 做了啥到底

在这里插入图片描述
nn.linear()是用来设置网络中的全连接层的,从源码上看到其实就是对输入的数据做一次线性变化,通过构造参数矩阵A来给输入X做维度变换。

1.怎么用

比如下例子中我们要把【64,32】–>【64,64】

import torch
from torch import nn
a = torch.randn(64,32)
# e.g., 想要让 [batch_size, input_dim] ---> [batch_size,out_dim]
# do:  nn.Linear( input_dim, out_dim, bias)
#(自注) torch中 对于linear input大于两维的,好像前面的dim视为样本数这类的,所以只改动后面的dim
linear = nn.Linear(32,64)
b = linear(a)
print(a.shape)
print(b.shape)
print(linear.weight)

在这里插入图片描述

不知道大家发现没有,这里的参数weight的shape是 out * input 而不是我们直观的理解的input * out。 这里其实nn.linear 是调用torch funcion 里的linear的,F.function 里用的就是weight.T()

也就是让input通过和参数weight的transport做矩阵乘法

你可能感兴趣的:(pytorch,深度学习,python)