分割可视化图——控制相同实例用相同颜色表示

论文里通常要放分割可视化对比图。如果每种方法的分割可视化配色不同,则不利于对比,也不美观。

期望的效果:
[图 1:分割可视化对比示例——各方法结果中与标签对应的同一实例使用相同颜色]
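原理:将每种分割方法的结果分别与 label 比较。对分割结果中的每个实例,统计它与 label 中各实例重叠的像素数,取重叠像素最多的 label 实例的 id(从而使用该 id 的颜色);若该 label id 已被一个更大的分割实例占用,则为当前实例分配一个新的 id。这样,同一实例在所有方法的可视化图中颜色一致。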

ids, count = np.unique(seg, return_counts=True)
# Map every id present in `seg` to its pixel count. np.unique already
# returns distinct ids, so zipping them straight into a dict is enough.
id_dict = dict(zip(ids, count))

以上代码用于统计分割结果中出现了哪些 id,以及每个 id 出现了多少次(像素数)。
然后对每一个 id,在 label 中寻找与之匹配的 id(见下文的 find_id 函数)。

import os
import h5py
from skimage import  io
import numpy as np
from PIL import Image

def gen_colormap(seed=666, num_colors=1000000):
    """Return a reproducible random RGB lookup table indexed by instance id.

    Row ``i`` is the colour used for instance id ``i``. Because the table is
    deterministic for a given seed, the same id gets the same colour in every
    image rendered with it.

    Args:
        seed: Seed for the random generator; equal seeds give equal tables.
        num_colors: Number of rows, i.e. the largest usable id is
            ``num_colors - 1``.

    Returns:
        Integer array of shape ``(num_colors, 3)`` with values in [0, 254].
    """
    # A private RandomState yields exactly the numbers that
    # ``np.random.seed(seed); np.random.randint(...)`` produced before, but
    # without resetting the global RNG behind the caller's back.
    rng = np.random.RandomState(seed)
    # randint's upper bound is exclusive, so channels lie in [0, 254]; kept
    # unchanged so previously generated figures keep their colours.
    color_val = rng.randint(0, 255, (num_colors, 3))
    print('the shape of color map:', color_val.shape)
    return color_val

def sort_gt(path='../data/snemi3d/AC3_labels.h5', out='ac3.h5'):
    """Relabel the last 50 slices of a label volume by instance size.

    Non-zero ids are renumbered 1, 2, 3, ... in order of decreasing pixel
    count (ties: larger original id first); 0 stays 0. The result is
    written to ``./<out>`` as dataset ``'main'`` (gzip) and returned.

    Args:
        path: HDF5 file whose ``'main'`` dataset holds the label volume.
        out: Output file name, relative to the current directory (an
            absolute path is used as-is).

    Returns:
        The relabelled volume, ``uint16`` when at most 65535 instances are
        present, otherwise ``uint32``.
    """
    with h5py.File(path, 'r') as f_gt:
        labels = f_gt['main'][:]
    labels = labels[-50:]
    print('min value=%d, max value=%d' % (np.min(labels), np.max(labels)))

    ids, inverse, count = np.unique(labels, return_inverse=True, return_counts=True)
    # lexsort uses its last key as the primary one: ascending by (count, id).
    # Reversing gives the same order as sorted(key=(count, id), reverse=True).
    order = np.lexsort((ids, count))[::-1]
    top = order[:10]
    print(list(zip(ids[top], count[top])))

    # Background keeps id 0; every other id gets its 1-based rank.
    order = order[ids[order] != 0]
    new_ids = np.zeros(len(ids), dtype=np.int64)
    new_ids[order] = np.arange(1, len(order) + 1)
    # reshape(-1) copes with NumPy versions that return `inverse` flat or
    # in the input's shape.
    new_labels = new_ids[inverse.reshape(-1)].reshape(labels.shape)

    # uint16 keeps the file small but would silently wrap past 65535 ids.
    dtype = np.uint16 if len(order) <= np.iinfo(np.uint16).max else np.uint32
    new_labels = new_labels.astype(dtype)
    print('min value=%d, max value=%d' % (np.min(new_labels), np.max(new_labels)))

    with h5py.File(os.path.join('.', out), 'w') as f_out:
        f_out.create_dataset('main', data=new_labels, dtype=new_labels.dtype, compression='gzip')

    return new_labels

def find_id(label, mask):
    """Return the label id that overlaps a masked region the most.

    Args:
        label: Integer label image (ground truth).
        mask: Array of the same shape; non-zero (or True) entries mark the
            region of one predicted instance.

    Returns:
        The non-zero label id with the largest pixel overlap inside ``mask``;
        among equal overlaps the larger id wins. Returns 0 when the region
        overlaps only background or the mask is empty (this previously
        raised IndexError).
    """
    # Only pixels inside the region matter; this avoids a full-image product.
    overlap = np.asarray(label)[np.asarray(mask) != 0]
    overlap = overlap[overlap != 0]
    if overlap.size == 0:
        return 0
    ids, count = np.unique(overlap, return_counts=True)
    # `ids` is ascending, so taking argmax over the reversed counts selects
    # the largest id among the maximal counts.
    best = len(ids) - 1 - int(np.argmax(count[::-1]))
    return ids[best]

def show_color(label, color_val):
    """Render an integer label map as an RGB image using a colour table.

    Args:
        label: Integer label array of any shape; 0 is background.
        color_val: Colour table where row ``i`` is the RGB colour of id ``i``
            (e.g. the output of ``gen_colormap``).

    Returns:
        ``uint8`` array of shape ``label.shape + (3,)``. Background pixels are
        black, and every other pixel takes the colour of its id.

    Raises:
        IndexError: if an id is out of range for ``color_val``.
    """
    label = np.asarray(label)
    # One vectorised lookup replaces a full-image pass per id. The old code
    # accumulated into an int8 buffer and relied on wrap-around to recover
    # channel values above 127; casting to uint8 gives the same pixels directly.
    gt_color = np.asarray(color_val)[label].astype(np.uint8)
    gt_color[label == 0] = 0
    return gt_color

def match_id(seg, label, mask=True, seed=0, color_val=None):
    """Recolour ``seg`` so each predicted instance reuses its matching label colour.

    Segments are processed from largest to smallest. Each one takes the
    label id it overlaps most (``find_id``). If that id was already claimed
    by a larger segment, it receives a fresh id instead, which cannot
    collide with any label id.

    Args:
        seg: Predicted instance map (not modified).
        label: Ground-truth instance map of the same shape; 0 is background.
        mask: If truthy, segment pixels on label background are ignored.
        seed: Seed for ``gen_colormap`` when ``color_val`` is not given.
        color_val: Optional precomputed colour table. Pass one when calling
            this function repeatedly to avoid rebuilding the table each time.

    Returns:
        ``uint8`` image of the recoloured segmentation (left) concatenated
        with the coloured label (right) along axis 1 — i.e. ``(H, 2*W, 3)``
        for 2-D inputs.
    """
    seg = np.array(seg, copy=True)  # callers pass views such as vol[i]; don't mutate them
    if mask:
        seg[label == 0] = 0
    ids, count = np.unique(seg, return_counts=True)
    sorted_results = sorted(zip(ids, count), key=lambda kv: (kv[1], kv[0]), reverse=True)
    print(sorted_results[:10])

    # Fresh ids must exceed every label id; otherwise an unmatched segment
    # could share a colour with an unrelated matched instance.
    next_id = int(max(np.max(seg), np.max(label))) + 1
    used_id = set()
    # int64 so label ids / fresh ids never overflow a narrow `seg` dtype.
    new_seg = np.zeros(seg.shape, dtype=np.int64)
    for tmp_id, _ in sorted_results:
        if tmp_id == 0:
            continue
        region = seg == tmp_id
        new_id = find_id(label, region)
        if new_id not in used_id:
            used_id.add(new_id)
            new_seg[region] = new_id
        else:
            new_seg[region] = next_id
            next_id += 1
    new_seg[label == 0] = 0

    if color_val is None:
        color_val = gen_colormap(seed=seed)
    label_color = show_color(label, color_val)
    seg_color = show_color(new_seg, color_val)
    im_cat = np.concatenate([seg_color, label_color], axis=1)
    return im_cat

def randomlabel(segmentation, rng=None):
    """Randomly renumber the instances of a segmentation, keeping 0 as background.

    The k distinct non-zero ids are mapped one-to-one onto a random
    permutation of 1..k. The old version permuted 0..n-1 over all ids, so
    one foreground segment could be mapped to 0 and vanish into the background.

    Args:
        segmentation: Integer instance map; converted to ``uint32``.
        rng: Optional ``np.random.Generator`` or ``RandomState`` supplying
            ``permutation``; the global NumPy RNG is used when omitted.

    Returns:
        ``uint32`` array of the same shape with relabelled instances.
    """
    seg = np.asarray(segmentation).astype(np.uint32)
    # return_inverse avoids a lookup table sized by the maximum id.
    uid, inverse = np.unique(seg, return_inverse=True)
    nonzero = uid != 0
    permute = np.random.permutation if rng is None else rng.permutation
    new_ids = np.zeros(len(uid), dtype=np.uint32)
    new_ids[nonzero] = permute(int(nonzero.sum())) + 1
    return new_ids[inverse.reshape(-1)].reshape(seg.shape)

if __name__ == '__main__':
    path1 = 'set1_mala/'

    # Prediction file of each method and the sub-folder that receives its
    # colour-matched PNGs; the two lists are paired by position.
    subs = ['BasicVSR.hdf', 'Bicubic.hdf', 'EDVR.hdf', 'ESRGAN.hdf', 'Ours.hdf',
            'RCAN.hdf', 'RealESRGAN.hdf', 'SwinIR.hdf']
    folders = ['BasicVSR', 'Bicubic', 'EDVR', 'ESRGAN', 'Ours',
               'RCAN', 'RealESRGAN', 'SwinIR']

    # label
    in_path2 = os.path.join(path1, 'cremiC_labels.h5')
    gt_folder = os.path.join(path1, 'cremiC_labels')
    os.makedirs(gt_folder, exist_ok=True)

    # Either an integer slice index (e.g. 15) or 'all' for every slice.
    read_nth_img = 'all'
    seed_num = 666

    # The label volume is identical for every method: read it once.
    with h5py.File(in_path2, 'r') as f2:
        print(list(f2.keys()))
        labels = f2['main'][:, :, :]
    labels = labels[-50:]
    print(labels.shape)

    frames = range(labels.shape[0]) if read_nth_img == 'all' else [read_nth_img]

    for sub_id, (sub, folder) in enumerate(zip(subs, folders)):
        in_path1 = os.path.join(path1, sub)
        print(in_path1)
        save_folder = os.path.join(path1, folder)
        os.makedirs(save_folder, exist_ok=True)

        with h5py.File(in_path1, 'r') as f:
            print(list(f.keys()))
            seg = f['main'][:, :, :]
        print(seg.shape)
        # NOTE(review): seg[i] is paired with labels[i] (the last 50 label
        # slices); this assumes each prediction file holds exactly those
        # slices — confirm.

        for i in frames:
            out = match_id(seg[i], labels[i], seed=seed_num)  # 2D match of the i-th frame
            half = out.shape[1] // 2
            seg_out = out[:, :half]
            labels_out = out[:, half:]
            name = str(i).zfill(4) + '.png'

            io.imsave(os.path.join(save_folder, name), seg_out)
            # The coloured label depends only on the label and colour map,
            # so it is the same for every method; write it once.
            if sub_id == 0:
                io.imsave(os.path.join(gt_folder, name), labels_out)


你可能感兴趣的:(python, 人工智能, 机器学习)