torch.nn.
Linear
(in_features, out_features, bias=True) 函数是一个线性变换函数:
其中,in_features为输入样本的大小,out_features为输出样本的大小,bias默认为true。如果设置bias = false那么该层将不会学习一个加性偏差。
Linear()
函数通常用于设置网络中的全连接层。
用例:
import torch
x = torch.randn(128, 20) # 输入样本
fc = torch.nn.Linear(20, 30) # 20为输入样本大小,30为输出样本大小
output = fc(x)
print('fc.weight.shape:\n ', fc.weight.shape)
print('fc.bias.shape:\n', fc.bias.shape)
print('output.shape:\n', output.shape)
ans = torch.mm(x,torch.t(fc.weight))+fc.bias # 计算结果与fc(x)相同
print('ans.shape:\n', ans.shape)
print(torch.equal(ans, output))
运行结果:
m.weight.shape:
torch.Size([30, 20])
m.bias.shape:
torch.Size([30])
output.shape:
torch.Size([128, 30])
ans.shape:
torch.Size([128, 30])
true