import torch
from torch import nn
from torch.nn import functional as F
class Lenet5(nn.Module):
'''
for cifar10 dataset
'''
def __init__(self):
super(Lenet5, self).__init__()
self.conv_unit = nn.Sequential(
# x: [b,3,32,32]
nn.Conv2d(3,6,kernel_size=5, stride=1,padding=0),
# =>[b,6,
nn.AvgPool2d(kernel_size=2, stride=2,padding=0),
#
nn.Conv2d(6,16,kernel_size=5, stride=1,padding=0),
#
nn.AvgPool2d(kernel_size=2, stride=2,padding=0),
# 打平
)
# flatten
# fc unit
self.fc_unit = nn.Sequential(
nn.Linear(16*5*5, 120),
nn.ReLU(),
nn.Linear(120,84),
nn.ReLU(),
nn.Linear(84,10)
)
# [b, 3,32,32]
tmp = torch.randn(2, 3, 32,32)
out = self.conv_unit(tmp)
# [b,16,5,5]
print("conv out:", out.shape)
# use Cross Entropy Loss
# nn.MSELoss
self.criteon = nn.CrossEntropyLoss()
def forward(self, x):
'''
:param x: [b,3,32,32]
:return:
'''
batchsz = x.size(0) # x.shape[0]
# [b,3,32,32]=>[b,16,5,5]
x = self.conv_unit(x)
# [b,16,5,5]=>[b,16*5*5]
# 打平
x = x.view(batchsz, -1)
# [b,16*5*5]=>[b,10]
logits = self.fc_unit(x)
# pred = F.softmax(logits, dim=1)
# loss = self.criteon(logits,y) nn 和F 的区别, 一个要初始化,另一个直接运行
return logits
def main():
net = Lenet5()
# [b, 3,32,32]
tmp = torch.randn(2, 3, 32, 32)
out = net(tmp)
# [b,16,5,5]
print("conv out:", out.shape)
# use Cross Entropy Loss
# nn.MSELoss
if __name__ == '__main__':
main()