Setting input_features for the first fully connected layer of a PyTorch CNN

  When writing a network in PyTorch, you may find you don't know what value to use for in_features of the first fully connected layer after the convolutions.

  One way around this: since PyTorch builds its computation graph dynamically, we can construct the fully connected layer inside the forward function, once the shape of the convolutional output is known. This neatly solves the awkward problem.

The code is as follows:


import torch
from torch import nn

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.CNN1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3),
                                  nn.ReLU(),
                                  nn.MaxPool2d(2, 1))
        self.CNN2 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3),
                                  nn.ReLU(),
                                  nn.MaxPool2d(2, 1))
        self.CNN3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3),
                                  nn.ReLU(),
                                  nn.MaxPool2d(2, 1))
        self.CNN4 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
                                  nn.ReLU(),
                                  nn.MaxPool2d(2, 1))
        self.CNN5 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
                                  nn.ReLU(),
                                  nn.MaxPool2d(2, 1))

    def forward(self, x, built_FC):
        x = self.CNN1(x)
        x = self.CNN2(x)
        x = self.CNN3(x)
        x = self.CNN4(x)
        x = self.CNN5(x)
        print(x.shape)
        x = x.view(x.shape[0], -1)  # flatten
        print(x.shape)

        if built_FC:
            (b, in_f) = x.shape            # shape of the flattened conv output
            self.FC = nn.Linear(in_f, 10)  # fully connected layer

        x = self.FC(x)

        return x
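Where does in_f come from in this setup? Each Conv2d with kernel_size=3 and no padding shrinks the height and width by 2, and each MaxPool2d(2, 1) shrinks them by 1 more, so a 28x28 input comes out of the five blocks at 13x13 with 64 channels. A quick sanity check of the flattened size (a sketch, mirroring the layers above):

# Trace the spatial size of a 28x28 input through the five conv+pool blocks.
size = 28
for _ in range(5):
    size -= 2   # Conv2d(kernel_size=3, padding=0): out = in - 2
    size -= 1   # MaxPool2d(kernel_size=2, stride=1): out = in - 1
print(64 * size * size)   # 10816, matching in_f in the output below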




net = CNN()
# In actual training, run one forward pass (with built_FC=True) before
# constructing the optimizer: self.FC does not exist until that first
# pass, so an optimizer built earlier from net.parameters() would not
# include FC's weights, and they would never be updated.

data_input = torch.randn([1, 1, 28, 28])  # assume the input images are 28x28
print(data_input.size())
net(data_input, True)
print(net)
data_input = torch.randn([1, 1, 28, 28])
net(data_input, False)  # no error
print(net)
data_input = torch.randn([1, 1, 96, 96])
net(data_input, False)  # error
print(net)
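To make that ordering concrete before looking at the output, here is a minimal training-setup sketch (the SGD optimizer and learning rate are illustrative choices, not from the original code):

net = CNN()
net(torch.randn(1, 1, 28, 28), built_FC=True)            # materializes self.FC
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)   # now includes FC's parameters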

Output:

torch.Size([1, 1, 28, 28])
torch.Size([1, 64, 13, 13])
torch.Size([1, 10816])
CNN(
  (CNN1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  )
  (CNN2): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  )
  (CNN3): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  )
  (CNN4): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  )
  (CNN5): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  )
  (FC): Linear(in_features=10816, out_features=10, bias=True)
)
torch.Size([1, 64, 13, 13])
torch.Size([1, 10816])
CNN(
  (CNN1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  )
  (CNN2): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  )
  (CNN3): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  )
  (CNN4): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  )
  (CNN5): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  )
  (FC): Linear(in_features=10816, out_features=10, bias=True)
)
torch.Size([1, 64, 81, 81])
torch.Size([1, 419904])







---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
 in <module>()
     56 print(net)
     57 data_input = torch.randn([1, 1, 96, 96])
---> 58 net(data_input, False)  # error
     59 print(net)

/root/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

 in forward(self, x, built_FC)
     39             self.FC = nn.Linear(in_f, 10)  # fully connected layer
     40 
---> 41         x = self.FC(x)
     42 
     43         return x

/root/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

/root/anaconda3/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
     53 
     54     def forward(self, input):
---> 55         return F.linear(input, self.weight, self.bias)
     56 
     57     def extra_repr(self):

/root/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1022     if input.dim() == 2 and bias is not None:
   1023         # fused op is marginally faster
-> 1024         return torch.addmm(bias, input, weight.t())
   1025 
   1026     output = input.matmul(weight.t())

RuntimeError: size mismatch, m1: [1 x 419904], m2: [10816 x 10] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:2070
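
The RuntimeError is expected: self.FC was built for in_features=10816 (the flattened size of a 28x28 input), while a 96x96 input flattens to 419904 features, so the matrix multiplication shapes no longer match.

As an aside beyond the original post: PyTorch 1.8+ ships nn.LazyLinear, which infers in_features automatically on the first forward pass, so the same idea works without the built_FC flag. A minimal sketch (the LazyCNN class and its two-block stack are illustrative, not the network above):

import torch
from torch import nn

class LazyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3), nn.ReLU(), nn.MaxPool2d(2, 1),
            nn.Conv2d(32, 64, kernel_size=3), nn.ReLU(), nn.MaxPool2d(2, 1),
        )
        self.FC = nn.LazyLinear(10)   # in_features inferred on the first call

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.shape[0], -1)    # flatten
        return self.FC(x)

net = LazyCNN()
net(torch.randn(1, 1, 28, 28))                           # dry run materializes FC's weights
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)   # create the optimizer afterwards

The same caveat applies: after the first call FC's shape is fixed, so a different input size would fail with the same kind of size mismatch.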
