LeNet网络参数注释

LeNet网络参数注释

from torch import nn

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()  # 例:输入数据为(1,32,32)
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, 5),  # 输入通道数1,输出通道数6,卷积核大小5 ==> (6, 28, 28) 步长默认为1
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2),  # 经过池化输出为 (6,14,14)
            nn.Conv2d(6, 16, 5),  # 输入通道数6,输出通道数16, 卷积核大小5 ==>(16, 10, 10)
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2)  # 经过池化输出为(16, 5, 5)
        )
        self.fc = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),  # 输入为是上边 16*5*5
            nn.Sigmoid(),
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, 10)
        )

    def forward(self, img):
        feature = self.conv(img)
        output = self.fc(feature.view(img.shape[0], -1))  # 全连接需要进行平铺
        return output

LeNet网络参数注释_第1张图片
图片.png

你可能感兴趣的:(LeNet网络参数注释)