When writing a network in PyTorch, you may find that you have no idea what in_features to give the first fully connected layer after the convolutional stack. Since PyTorch builds its graph dynamically, we can sidestep the problem by constructing the fully connected layer inside forward(), once the flattened shape of the convolutional output is actually known.
The code is as follows:
import torch
from torch import nn

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.CNN1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3),
                                  nn.ReLU(),
                                  nn.MaxPool2d(2, 1))
        self.CNN2 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3),
                                  nn.ReLU(),
                                  nn.MaxPool2d(2, 1))
        self.CNN3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3),
                                  nn.ReLU(),
                                  nn.MaxPool2d(2, 1))
        self.CNN4 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
                                  nn.ReLU(),
                                  nn.MaxPool2d(2, 1))
        self.CNN5 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
                                  nn.ReLU(),
                                  nn.MaxPool2d(2, 1))

    def forward(self, x, built_FC):
        x = self.CNN1(x)
        x = self.CNN2(x)
        x = self.CNN3(x)
        x = self.CNN4(x)
        x = self.CNN5(x)
        print(x.shape)
        x = x.view(x.shape[0], -1)  # flatten the conv output
        print(x.shape)
        if built_FC:
            (b, in_f) = x.shape  # flattened shape of the conv output
            self.FC = nn.Linear(in_f, 10)  # fully connected layer, built on the fly
        x = self.FC(x)
        return x
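For a 28×28 input, the convolution stack ends at 64×13×13 feature maps, so the flattened vector has 64 * 13 * 13 = 10816 features, which is exactly the in_features the FC layer receives. A quick standalone check (a sketch, using the shapes from the printed output below):

t = torch.randn(1, 64, 13, 13)       # conv output shape for a 28x28 input
print(t.view(t.shape[0], -1).shape)  # torch.Size([1, 10816])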
net = CNN()
# Run one forward pass with built_FC=True before creating the optimizer:
# self.FC does not exist until then, so an optimizer built from
# net.parameters() beforehand would not include the FC weights and
# would never update them.
data_input = torch.randn([1, 1, 28, 28])  # assume the input images are 28x28
print(data_input.size())
net(data_input, True)
print(net)
data_input = torch.randn([1, 1, 28, 28])
net(data_input, False)  # works: the flattened size still matches self.FC
print(net)
data_input = torch.randn([1, 1, 96, 96])
net(data_input, False)  # raises a size-mismatch error
print(net)
Output:
torch.Size([1, 1, 28, 28])
torch.Size([1, 64, 13, 13])
torch.Size([1, 10816])
CNN(
(CNN1): Sequential(
(0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
)
(CNN2): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
)
(CNN3): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
)
(CNN4): Sequential(
(0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
)
(CNN5): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
)
(FC): Linear(in_features=10816, out_features=10, bias=True)
)
torch.Size([1, 64, 13, 13])
torch.Size([1, 10816])
CNN(
(CNN1): Sequential(
(0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
)
(CNN2): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
)
(CNN3): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
)
(CNN4): Sequential(
(0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
)
(CNN5): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
)
(FC): Linear(in_features=10816, out_features=10, bias=True)
)
torch.Size([1, 64, 81, 81])
torch.Size([1, 419904])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
in ()
56 print(net)
57 data_input = torch.randn([1, 1, 96,96])
---> 58 net(data_input,False) # raises a size-mismatch error
59 print(net)
/root/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
475 result = self._slow_forward(*input, **kwargs)
476 else:
--> 477 result = self.forward(*input, **kwargs)
478 for hook in self._forward_hooks.values():
479 hook_result = hook(self, input, result)
in forward(self, x, built_FC)
39 self.FC=nn.Linear(in_f,10) # fully connected layer
40
---> 41 x=self.FC(x)
42
43 return x
/root/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
475 result = self._slow_forward(*input, **kwargs)
476 else:
--> 477 result = self.forward(*input, **kwargs)
478 for hook in self._forward_hooks.values():
479 hook_result = hook(self, input, result)
/root/anaconda3/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
53
54 def forward(self, input):
---> 55 return F.linear(input, self.weight, self.bias)
56
57 def extra_repr(self):
/root/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
1022 if input.dim() == 2 and bias is not None:
1023 # fused op is marginally faster
-> 1024 return torch.addmm(bias, input, weight.t())
1025
1026 output = input.matmul(weight.t())
RuntimeError: size mismatch, m1: [1 x 419904], m2: [10816 x 10] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:2070
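The last call fails because the FC layer is built only once: a 96×96 input reaches the flatten step as 64×81×81 = 419904 features, but self.FC was constructed with in_features=10816, hence the size mismatch in the traceback above.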
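To make the training setup concrete, here is a minimal sketch of the ordering this approach requires (the optimizer and loss choices below are arbitrary assumptions, not part of the original post): run one forward pass first so self.FC exists, then build the optimizer from net.parameters().

net = CNN()
net(torch.randn(1, 1, 28, 28), True)  # first forward pass creates self.FC
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)  # now sees the FC weights
criterion = nn.CrossEntropyLoss()

x = torch.randn(8, 1, 28, 28)         # dummy batch, hypothetical data
y = torch.randint(0, 10, (8,))        # dummy labels for 10 classes
loss = criterion(net(x, False), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()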
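As an aside: PyTorch 1.8 and later ship nn.LazyLinear, which infers in_features automatically on the first forward pass, so the same problem can now be solved without building the layer by hand (this module did not exist when the code above was written):

fc = nn.LazyLinear(10)             # in_features left unspecified
out = fc(torch.randn(1, 10816))    # first call materializes Linear(10816, 10)
print(fc)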