only batches of spatial targets supported (3D tensors) but got targets of dimension

问题产生的原因是使用nn.CrossEntropyLoss()来计算损失的时候,target的维度超过4

import torch
import torch.nn as nn

logit = torch.ones(size=(4, 32, 256, 256))  # b,c,h,w
target = torch.ones(size=(4, 1, 256, 256))

criterion = nn.CrossEntropyLoss()
loss = criterion(logit, target)

only batches of spatial targets supported (3D tensors) but got targets of dimension_第1张图片

如实target中的C不是1,则可以:

import torch
import torch.nn as nn

logit = torch.ones(size=(4, 32, 256, 256))  # b,c,h,w
target = torch.ones(size=(4, 2, 256, 256))

criterion = nn.CrossEntropyLoss()
losses = 0
for i in range(2):
    loss = criterion(logit, target[:, i, ...].long())
    losses += loss

 可以看到代码里面有个.long(),如果不用的话则会报错:

RuntimeError: expected scalar type Long but found Float

你可能感兴趣的:(解决bug,深度学习,计算机视觉,人工智能)