pytorch版本UNet训练报错:1only batchs if spatial targets supported (non-empty 3D tensor) but got targets ..

目录

  • 1. 问题分析
  • 2. 解决方法
  • 3. 其他问题

1. 问题分析

参考的pytorch版本的UNet github地址:https://github.com/milesial/Pytorch-UNet

现在的需求是:

  1. 输入单通道的图像,大小为:512X512;
  2. 输出是8个类别的语义分割结果,每个类别占用一个通道,值为0或1;
  3. 设置batch_size=1;
  4. 因此,输入为:1x1x512x512,输出为:1x8x512x512。 (batch_size x channel x W x H)。

根据代码里的提示,设置n_channels=1, n_classes=8,训练过程发生如下报错:
pytorch版本UNet训练报错:1only batchs if spatial targets supported (non-empty 3D tensor) but got targets .._第1张图片

2. 解决方法

选择loss function为criterion = nn.BCEWithLogitsLoss(),如下图所示:
pytorch版本UNet训练报错:1only batchs if spatial targets supported (non-empty 3D tensor) but got targets .._第2张图片
原始代码中,如果n_class>1 会选择nn.CrossEntropyLoss()。

这里其实相当于一个multi-label的任务,输出多个通道代表多个类别,每个通道的值输出是0或者1。

如果是输出在一个通道上,每个类别的值用一个数字表示,例如我有8个类别,分别用0,1,2,3,4,5,6,7的像素值表示,则应该选择用nn.CrossEntropyLoss()的loss function。

3. 其他问题

在train和val过程中,计算loss的时候,可能会出现type类型的报错,即传入的masks_pred和true_masks一个是float type,另一个是Long type,根据提示在变量后面添加.float()或者.long(),让两者类型一致即可。
pytorch版本UNet训练报错:1only batchs if spatial targets supported (non-empty 3D tensor) but got targets .._第3张图片

结束。

你可能感兴趣的:(pytorch版本UNet训练报错:1only batchs if spatial targets supported (non-empty 3D tensor) but got targets ..)