论文里通常要放分割可视化对比图。如果每种方法的分割可视化配色不同,则不利于对比,也不美观。
期望的效果:不同方法中与同一个 label 实例相匹配的分割区域显示为相同的颜色,便于直观对比。
原理:将每种分割方法的结果分别与 label 作比较,对每个分割实例统计它与各个 label 实例的重叠像素数,取重叠最多的 label id 作为该实例的新 id,再用同一套配色渲染。
# Tally how many pixels each id occupies in the segmentation `seg`.
# np.unique already yields distinct ids, so zipping straight into a dict
# gives the same mapping as inserting pairs one by one.
ids, count = np.unique(seg, return_counts=True)
id_dict = dict(zip(ids, count))
以上代码用于统计分割结果中出现了哪些 id,以及每个 id 出现的像素数。
然后对每一个 id,在 label 中寻找与它重叠像素最多的 id 进行匹配(见 find_id)。
import os
import h5py
from skimage import io
import numpy as np
from PIL import Image
def gen_colormap(seed=666, num_colors=1000000):
    """Generate a reproducible random RGB lookup table.

    Row ``i`` is the color used for instance id ``i`` by ``show_color``.

    Args:
        seed: Seed for the generator. The same seed always yields the same
            table, so results rendered with the same seed share colors for
            equal ids.
        num_colors: Number of rows; the largest id that can be colored is
            ``num_colors - 1``.

    Returns:
        Integer array of shape ``(num_colors, 3)`` with values in ``[0, 254]``
        (the upper bound of ``randint`` is exclusive).
    """
    # A private RandomState produces exactly the same table as reseeding the
    # legacy global generator, but does not silently reset np.random's global
    # state for every other piece of code in the process.
    rng = np.random.RandomState(seed)
    color_val = rng.randint(0, 255, (num_colors, 3))
    print('the shape of color map:', color_val.shape)
    return color_val
def sort_gt(path='../data/snemi3d/AC3_labels.h5', out='ac3.h5', num_slices=50):
    """Relabel ground-truth ids by instance size and save the result.

    Reads dataset ``main`` from ``path`` and keeps the last ``num_slices``
    slices along the first axis (``None`` keeps all). Nonzero ids are then
    renumbered so the largest instance (by voxel count) becomes 1, the next
    2, and so on; ties go to the larger original id. 0 stays background.
    The result is cast to ``uint16`` and written gzip-compressed to
    ``./<out>`` under dataset ``main``.

    Args:
        path: Input HDF5 file containing dataset ``main``.
        out: Output file name, written in the current directory.
        num_slices: Number of trailing slices to keep (default 50).

    Returns:
        The relabeled ``uint16`` array.

    NOTE(review): more than 65535 instances would wrap around in the
    ``uint16`` cast, exactly as before; widen the dtype if that can happen.
    """
    # Context managers guarantee the files are closed even on errors.
    with h5py.File(path, 'r') as f_gt:
        labels = f_gt['main'][:]
    if num_slices is not None:
        labels = labels[-num_slices:]
    print('min value=%d, max value=%d' % (np.min(labels), np.max(labels)))

    ids, inverse, count = np.unique(labels, return_inverse=True, return_counts=True)
    # Descending by (count, id): lexsort sorts by the LAST key first, so
    # (ids, count) gives count-major, id-minor ascending order; reverse it.
    order = np.lexsort((ids, count))[::-1]
    print(list(zip(ids[order[:10]], count[order[:10]])))

    # Rank nonzero ids 1..K in that order; background keeps 0. One lookup
    # replaces a full-volume comparison per id.
    ranked = order[ids[order] != 0]
    mapping = np.zeros(ids.size, dtype=np.int64)
    mapping[ranked] = np.arange(1, ranked.size + 1)
    # reshape: np.unique's inverse is 1-D or input-shaped depending on the
    # NumPy version.
    new_labels = mapping[inverse].reshape(labels.shape).astype(np.uint16)
    print('min value=%d, max value=%d' % (np.min(new_labels), np.max(new_labels)))

    with h5py.File('./' + out, 'w') as f_out:
        f_out.create_dataset('main', data=new_labels, dtype=new_labels.dtype, compression='gzip')
    return new_labels
def find_id(label, mask):
    """Return the label id that overlaps a region the most.

    Args:
        label: Integer array of ground-truth ids; 0 is background.
        mask: Array of the same shape as ``label``; nonzero (or True) marks
            the region of interest.

    Returns:
        The nonzero label id with the largest pixel count inside ``mask``.
        Ties go to the larger id. Returns 0 if the region covers only
        background.
    """
    vals = np.asarray(label)[np.asarray(mask).astype(bool)]
    vals = vals[vals != 0]
    if vals.size == 0:
        # The original indexed sorted_results[1] here and raised IndexError.
        return 0
    ids, count = np.unique(vals, return_counts=True)
    # ids is ascending, so the last position holding the max count is the
    # largest id among the ties -- the same tie-break as sorting by
    # (count, id) in descending order.
    best = np.flatnonzero(count == count.max())[-1]
    return ids[best]
def show_color(label, color_val):
    """Render an integer id map as an RGB image using a color lookup table.

    Args:
        label: Integer array of instance ids (typically 2-D ``(H, W)``);
            id 0 is background.
        color_val: ``(N, 3)`` lookup table, e.g. from ``gen_colormap``;
            id ``i`` is drawn with ``color_val[i]``.

    Returns:
        ``uint8`` array of shape ``label.shape + (3,)``. Background pixels
        are black.

    Raises:
        IndexError: if an id in ``label`` lies outside ``color_val``.
    """
    label = np.asarray(label)
    # One vectorized table lookup replaces a full-image pass per id. It also
    # removes the signed int8 accumulator, which only produced correct colors
    # through two's-complement wraparound.
    gt_color = np.asarray(color_val)[label].astype(np.uint8)
    gt_color[label == 0] = 0
    return gt_color
def match_id(seg, label, mask=True, seed=0):
    """Recolor a segmentation so matched instances share the label's colors.

    Each segment id is mapped to the label id it overlaps most (``find_id``).
    Segments are processed from largest to smallest. When several segments
    pick the same label id, the largest keeps it and the others receive
    fresh ids that cannot collide with any label id.

    Args:
        seg: 2-D integer segmentation. It is not modified.
        label: 2-D integer ground-truth ids of the same shape; 0 is background.
        mask: If True, segment pixels lying on label background are ignored
            when ranking segments and matching ids.
        seed: Seed passed to ``gen_colormap``. Use the same seed for every
            method so the colors are comparable.

    Returns:
        ``uint8`` image of shape ``(H, 2 * W, 3)``: the recolored segmentation
        on the left and the colored label on the right.
    """
    label = np.asarray(label)
    # Work on a copy. The original zeroed background pixels in the caller's
    # array, which in __main__ is a view into the whole volume.
    seg = np.array(seg, copy=True)
    if mask:
        seg[label == 0] = 0
    ids, count = np.unique(seg, return_counts=True)
    # Largest segments pick their label id first; ties go to the larger id.
    sorted_results = sorted(zip(ids, count), key=lambda kv: (kv[1], kv[0]), reverse=True)
    print(sorted_results[:10])
    # Fresh ids start above every label id as well as every seg id, so a
    # segment that loses its match never reuses a matched segment's color.
    # int() avoids scalar wraparound when the max sits at the dtype limit.
    max_id = int(max(np.max(seg), np.max(label))) + 1
    used_id = set()
    # int64 so both label ids and fresh ids fit, whatever seg's dtype is.
    new_seg = np.zeros(seg.shape, dtype=np.int64)
    for tmp_id, _ in sorted_results:
        if tmp_id == 0:
            continue
        region = seg == tmp_id  # computed once, reused for every write below
        new_id = find_id(label, region)
        if new_id not in used_id:
            used_id.add(new_id)
            new_seg[region] = new_id
        else:
            new_seg[region] = max_id
            max_id += 1
    new_seg[label == 0] = 0
    color_val = gen_colormap(seed=seed)
    label_color = show_color(label, color_val)
    seg_color = show_color(new_seg, color_val)
    return np.concatenate([seg_color, label_color], axis=1)
def randomlabel(segmentation):
    """Randomly permute the nonzero ids of a segmentation.

    Background (0) stays 0. Every distinct nonzero id receives a distinct
    new id in ``1..K``, where K is the number of distinct nonzero ids, so no
    segment is ever merged into the background. Draws from NumPy's global
    random generator.

    Args:
        segmentation: Integer id array; it is cast to ``uint32``.

    Returns:
        ``uint32`` array of the same shape with permuted ids.
    """
    segmentation = segmentation.astype(np.uint32)
    uid = np.unique(segmentation)
    nonzero = uid[uid != 0]
    if nonzero.size == 0:
        # Nothing to permute; this also avoids uid.max() on an empty array.
        return segmentation
    mapping = np.zeros(int(nonzero.max()) + 1, dtype=segmentation.dtype)
    # Permute 1..K rather than 0..K-1. The original could give the value 0
    # to a real segment, silently turning it into background.
    mapping[nonzero] = (np.random.permutation(nonzero.size) + 1).astype(segmentation.dtype)
    return mapping[segmentation]
if __name__ == '__main__':
    path1 = 'set1_mala/'
    # Parallel lists: segmentation file of each method and its output folder.
    subs = ['BasicVSR.hdf', 'Bicubic.hdf', 'EDVR.hdf', 'ESRGAN.hdf', 'Ours.hdf',
            'RCAN.hdf', 'RealESRGAN.hdf', 'SwinIR.hdf']
    folders = ['BasicVSR', 'Bicubic', 'EDVR', 'ESRGAN', 'Ours',
               'RCAN', 'RealESRGAN', 'SwinIR']
    # label
    in_path2 = path1 + 'cremiC_labels.h5'
    # Frame to render: an integer index (e.g. 15), or 'all' for every frame.
    read_nth_img = 'all'
    # Same seed for every method so equal ids get equal colors everywhere.
    seed_num = 666

    # The label volume is the same for every method, so read it only once.
    with h5py.File(in_path2, 'r') as f2:
        for k in f2.keys():
            print(k)
        labels = f2['main'][:, :, :]
    labels = labels[-50:]
    print(labels.shape)

    frame_ids = range(labels.shape[0]) if read_nth_img == 'all' else [read_nth_img]

    # os.path.join instead of hard-coded '\\' so paths work on any OS.
    gt_folder = os.path.join(path1, 'cremiC_labels')
    os.makedirs(gt_folder, exist_ok=True)

    for sub, folder in zip(subs, folders):
        in_path1 = path1 + sub
        print(in_path1)
        save_folder = os.path.join(path1, folder)
        os.makedirs(save_folder, exist_ok=True)
        with h5py.File(in_path1, 'r') as f:
            for k in f.keys():
                print(k)
            seg = f['main'][:, :, :]
        print(seg.shape)
        # NOTE(review): labels keeps only the last 50 slices while seg is used
        # whole, so seg[i] and labels[i] line up only if each method's volume
        # holds exactly those 50 slices -- confirm.
        for i in frame_ids:
            out = match_id(seg[i], labels[i], seed=seed_num)  # 2D match i-th frame
            # Left half: recolored segmentation; right half: colored label.
            w = out.shape[1]
            seg_out = out[:, :w // 2]
            labels_out = out[:, w // 2:]
            name = str(i).zfill(4) + '.png'
            io.imsave(os.path.join(save_folder, name), seg_out)
            io.imsave(os.path.join(gt_folder, name), labels_out)