pytorch报RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x360 and 40x128) 编译代码

class Net(nn.Module):
    """Small CNN: three conv/ReLU/max-pool stages plus a two-layer classifier.

    Args:
        flat_features: Number of values per sample once the output of the
            third conv stage is flattened (channels * height * width). The
            default of 40 fits a 40x1x1 feature map, which is what 28x28
            single-channel inputs produce. Larger inputs yield a larger map,
            e.g. 48x48 inputs give 40x3x3, so ``flat_features=360`` is
            required; with a mismatched value the first ``Linear`` raises
            ``RuntimeError: mat1 and mat2 shapes cannot be multiplied``.
    """

    def __init__(self, flat_features=40):
        super().__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 10, 5), torch.nn.ReLU(), torch.nn.MaxPool2d(2)
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(10, 20, 5), torch.nn.ReLU(), torch.nn.MaxPool2d(2)
        )
        self.conv3 = torch.nn.Sequential(
            torch.nn.Conv2d(20, 40, 3), torch.nn.ReLU(), torch.nn.MaxPool2d(2)
        )
        # in_features of the first Linear must equal the flattened size of
        # conv3's output, not just its channel count.
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(flat_features, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 10),
        )

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10)."""
        conv1_out = self.conv1(x)
        conv2_out = self.conv2(conv1_out)
        conv3_out = self.conv3(conv2_out)
        # Keep the batch dimension and collapse (C, H, W) into one axis.
        res = conv3_out.view(conv3_out.size(0), -1)
        # Debug aid: shows the pre-flatten shape so the required
        # flat_features (C * H * W) can be read off directly.
        print('out', conv3_out.shape)  # e.g. torch.Size([64, 40, 3, 3])
        out = self.dense(res)
        return F.log_softmax(out, dim=1)

运行代码报错:RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x360 and 40x128)

 print('out', conv3_out.shape)  # torch.Size([64, 40, 3, 3])

([64, 40, 3, 3]) 其中的 64 是设置的 batch_size,后三维 (40, 3, 3) 才是每个样本特征图的真正形状。全连接层的输入是一维特征,因此需要先进行压平(flatten)操作——代码中的 conv3_out.view(conv3_out.size(0), -1) 已经完成了这一步。压平后的形状如下:

[64, 40*3*3] = [64, 360]


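也就是说,压平后每个样本有 360 个特征,而报错信息中的 40x128 说明第一个全连接层的 in_features 只有 40,两者不匹配。因此第一个全连接层的输入特征数应为 40*3*3 = 360,而不是 40:

 self.dense = torch.nn.Sequential(torch.nn.Linear(360, 128), torch.nn.ReLU(), torch.nn.Linear(128, 10))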



你可能感兴趣的:(pytorch报RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x360 and 40x128))