Pytorch Error: ValueError: Expected input batch_size (324) to match target batch_size (4) Log In

ERROR

运行到loss = criterion(output, target)时
报错:

ValueError: Expected input batch_size (324) to match target batch_size (4) Log In

解决方法

打印

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.fc1 = nn.Linear(64, 1024)
        self.fc2 = nn.Linear(1024, 7)
        #self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        print(x.shape)
        x = x.view(-1, 64)
        print(x.shape)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

在x.view(-1,64)前后打印tensor的shape,这个时候就能发现问题出在哪里了:

torch.Size([4, 64, 9, 9])

根据这个形状,需要把view修改为:

x = x.view(-1, 64 * 9 * 9)

后面的Linear层也需要对应修改:

self.fc1 = nn.Linear(64 * 9 * 9, 1024)

你可能感兴趣的:(pytorch,python)