torch.nn.Linear()

功能是定义一个线性变换(连同偏置),即定义一个这样的运算:

                                                                 y=xW^T+b

例:

import torch
import torch.nn as nn
linear=nn.Linear(5,3,bias=True)
x=torch.randn(10,5)
out=linear(x)
print(out)
print('weight.shape:\n ', linear.weight.shape)
print('bias.shape:\n', linear.bias.shape)

其实是定义了一个权重矩阵和一个偏置向量,权重矩阵的size正好是我们输入的行列转置。但这不重要,我们记住定义是第一个参数的值等于输入特征的维度即可。

你可能感兴趣的:(工具类)