# Determining convolution output dimensions (卷积维度的判断)

import torch

# Walk a single 28x28 input through a conv -> pool -> conv -> pool -> linear
# stack and print the tensor shape after every layer.
#
# Shape rules used below:
#   * Conv2d with the default stride=1 and padding=0 shrinks each spatial side
#     by kernel_size - 1 (e.g. 28 -> 28 - 5 + 1 = 24).
#   * MaxPool2d(kernel_size=2) uses stride=kernel_size by default, so each
#     spatial side is floor-divided by 2 (e.g. 24 -> 12).

width, height = 28, 28

in_channel = 1

batch_size = 1

# PyTorch image tensors are laid out as (N, C, H, W): height precedes width.
inputs = torch.randn(batch_size, in_channel, height, width)

print(inputs.shape)  # (1, 1, 28, 28)

conv_lay1 = torch.nn.Conv2d(in_channels=in_channel,
                            out_channels=10,
                            kernel_size=5)

output1 = conv_lay1(inputs)

print(output1.shape)  # (1, 10, 24, 24): 28 - 5 + 1 = 24

maxpool_lay = torch.nn.MaxPool2d(kernel_size=2)

output2 = maxpool_lay(output1)

print(output2.shape)  # (1, 10, 12, 12): 24 // 2 = 12

# The second conv must accept exactly as many channels as the first emits.
conv_lay2 = torch.nn.Conv2d(in_channels=conv_lay1.out_channels,
                            out_channels=20,
                            kernel_size=5)

output3 = conv_lay2(output2)

print(output3.shape)  # (1, 20, 8, 8): 12 - 5 + 1 = 8

# MaxPool2d has no learnable parameters, so reusing the same module is safe.
output4 = maxpool_lay(output3)

print(output4.shape)  # (1, 20, 4, 4): 8 // 2 = 4

# Flatten every dimension except the batch one. Taking the batch size from
# the tensor (instead of a hard-coded 1) keeps samples separate for any
# batch_size.
output5 = output4.view(output4.size(0), -1)

# in_features must equal the flattened size: 20 * 4 * 4 = 320 for a 28x28
# input. Reading it from the tensor keeps the layer consistent if width or
# height change.
linear_lay = torch.nn.Linear(output5.size(1), 10)

output6 = linear_lay(output5)

print(output6.shape)  # (1, 10)

# Source page footer: 你可能感兴趣的:(分类) ("You may also be interested in: (category)")