nn.linear()

import torch
import torch.nn

nn.linear()是用来设置网络中的全连接层的,而在全连接层中的输入与输出都是二维张量,一般形状为[batch_size, size],与卷积层要求输入输出是4维张量不同。
用法与形参见说明如下:


nn.Linear

in_features指的是输入的二维张量的大小,即输入的[batch_size, size]中的size。
batch_size指的是每次训练(batch)的时候样本的大小。比如CNN train的样张图片是60张,设置batch_size=15,那么iteration=4。如果想多训练几次(因为可以每次的batch不是相同的数据),那么就是epoch。
所以nn.Linear()中的输入包括有输入的图片数量,同时还有每张图片的维度。
out_features指的是输出的二维张量的大小,即输出[batch_size,size]中的size是输出的张量维度,而batch_size与输入中的一致。

参考:PyTorch的nn.Linear()详解

你可能感兴趣的:(nn.linear())