U-net网络用于多分类——坑点

一共有data.py、main.py、model.py三个文件
参数修改:
flag_multi_class=True
num_class=类数
一、data.py
1、adjustData函数改为:

def adjustData(img,mask,flag_multi_class,num_class):
    if(flag_multi_class):#多类情况
        img = img / 255
        mask = mask[:,:,:,0] if(len(mask.shape) == 4) else mask[:,:,0]
        new_mask = np.zeros(mask.shape + (num_class,))
        new_mask[mask == 0,0] = 1
        new_mask[mask == 50,1] = 1
        new_mask[mask == 150,2] = 1
        new_mask[mask == 255,3] = 1
        mask = new_mask
    elif(np.max(img) > 1):
        img = img / 255
        mask = mask /255
        mask[mask > 0.5] = 1
        mask[mask <= 0.5] = 0
    return (img,mask)

2、未完待续
二、model.py
1、修改如下:
在这里插入图片描述
其中最后一行因为所做的是四分类,所以第一个数字是4
2、修改如下:
在这里插入图片描述
loss也可以是自己定义的损失函数
三、main.py
1、未完待续

总结:
U-net网络用于多分类——坑点_第1张图片

你可能感兴趣的:(U-net网络用于多分类——坑点)