RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of s

在用自己的数据集训练unet时,碰到了这样的问题。

RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [1, 640, 959, 3]

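损失函数 nn.CrossEntropyLoss() 的输入应该是一个四维张量(网络的输出,形状为 [batch size, 类别数, height, width])和一个三维张量(target,形状为 [batch size, height, width],每个元素是类别索引),而读取的数据集中的标签是 RGB 三通道的图片,形状为 [batch size, height, width, RGB],比要求的 target 多出了一个通道维度。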

需要将该四维张量的RGB图片输入转为单值的类别信息。

重新将标签制作为单值灰度图。

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2


# Class index -> colour of that class in the label image, given as channel
# values on a 0-255 scale. Edit this table to match your own dataset.
color2class_dict = {
    0: [64.0, 140.0, 214.0, 255.0],
    1: [2.0, 4.0, 244.0, 255.0],
    2: [210.0, 21.0, 28.0, 255],
    3: [9.0, 243.0, 25.0, 255.0],
}


def get_keys1(value):
    """Return the class index whose colour in ``color2class_dict`` matches ``value``.

    Args:
        value: Sequence of channel values for one pixel, on a 0-255 scale.

    Returns:
        The key of the first matching table entry, or 0 when nothing matches
        (so an unknown colour is indistinguishable from class 0).

    Channels are compared with a tolerance of half a grey level instead of
    exact equality: the caller produces ``value`` by scaling floats by 255,
    which may not land exactly on the integer stored in the table. Only the
    channels present in both sequences are compared, so a 3-channel RGB pixel
    can still match a 4-value RGBA table entry.
    """
    # An empty pixel would vacuously match every entry below; treat it as
    # "no match" like the original exact comparison did.
    if len(value) == 0:
        return 0
    for class_id, color in color2class_dict.items():
        # zip() stops at the shorter sequence, i.e. the shared channels.
        if all(abs(float(a) - float(b)) < 0.5 for a, b in zip(value, color)):
            return class_id
    return 0
def get_keys2(value):
    """Classify one pixel with hand-tuned channel thresholds.

    Edit the thresholds below to suit your own label colours.

    Args:
        value: Indexable sequence of channel values for one pixel.

    Returns:
        1 when channel 0 exceeds 150; otherwise 2 when channel 1 exceeds 150
        and channel 2 is below 50; otherwise 3 when channel 1 is below 50 and
        channel 2 exceeds 150; otherwise 0.
    """
    # Rules are checked in priority order; the first one that holds wins.
    if value[0] > 150:
        return 1
    if value[1] > 150 and value[2] < 50:
        return 2
    if value[1] < 50 and value[2] > 150:
        return 3
    # No rule matched: fall back to class 0.
    return 0

def main(input_path, save_path, mode):
    """Convert every colour label image in ``input_path`` into a class-index mask.

    Each pixel is mapped to a class index by ``get_keys1`` (when ``mode`` is 0)
    or ``get_keys2`` (any other ``mode``), and the resulting single-channel
    image is written to ``save_path`` as ``<name>_mask.png``, where ``<name>``
    is the part of the file name before its first dot.

    Args:
        input_path: Directory whose entries are all read as label images.
        save_path: Directory the masks are written to.
        mode: 0 selects ``get_keys1``; anything else selects ``get_keys2``.

    Raises:
        OSError: If a mask could not be written (``cv2.imwrite`` reports
            failure by returning False rather than raising).
    """
    get_keys = get_keys1 if mode == 0 else get_keys2
    for image in os.listdir(input_path):
        img_path = os.path.join(input_path, image)
        save_path_img = os.path.join(save_path, image.split(".")[0] + "_mask.png")

        img = plt.imread(img_path)
        # Matplotlib returns floats in [0, 1] for PNG but integer arrays for
        # other formats; only the float case needs rescaling to 0-255.
        if np.issubdtype(img.dtype, np.floating):
            img = img * 255.0

        # Class indices are stored as 8-bit values, so they must fit 0-255.
        img_label = np.zeros(img.shape[:2], dtype=np.uint8)
        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                img_label[i, j] = get_keys(list(img[i, j]))

        if not cv2.imwrite(save_path_img, img_label):
            raise OSError("failed to write mask: " + save_path_img)
        print(image + " done")


# Script configuration: fill in the two directories before running.
input_path = ""  # directory of colour label images, listed by main()
save_path = ""  # directory where main() writes the "<name>_mask.png" files
mode = 0  # 0 -> get_keys1 (colour table lookup); any other value -> get_keys2 (thresholds)
main(input_path, save_path, mode)

你可能感兴趣的:(python,神经网络,人工智能,计算机视觉,cnn)