# Project GitHub page: https://github.com/orobix/retina-unet
# Reference paper: "Retina blood vessel segmentation with a convolution neural network (U-net)"
import os
import h5py
import numpy as np
from PIL import Image
def write_hdf5(arr, outfile):
    """Save a numpy array to an HDF5 file under the dataset key "image".

    arr     -- numpy array to persist (dtype is preserved)
    outfile -- path of the HDF5 file to create (overwritten if it exists)
    """
    with h5py.File(outfile, "w") as f:
        f.create_dataset("image", data=arr, dtype=arr.dtype)
# Training data locations: images, ground truth (manual segmentation), FOV masks
original_imgs_train = "./DRIVE/training/images/"
groundTruth_imgs_train = "./DRIVE/training/1st_manual/"
borderMasks_imgs_train = "./DRIVE/training/mask/"
# Test data locations: images, ground truth (manual segmentation), FOV masks
original_imgs_test = "./DRIVE/test/images/"
groundTruth_imgs_test = "./DRIVE/test/1st_manual/"
borderMasks_imgs_test = "./DRIVE/test/mask/"
# Output directory for the packed HDF5 datasets
dataset_path = "./datasets_training_testing/"
Nimgs = 20      # number of images in each DRIVE split
channels = 3    # RGB input images
height = 584    # DRIVE image height in pixels
width = 565     # DRIVE image width in pixels
def get_datasets(imgs_dir, groundTruth_dir, borderMasks_dir, train_test="null"):
    """Load the DRIVE images, ground-truth vessel maps and FOV border masks.

    imgs_dir        -- directory with the original retina images
    groundTruth_dir -- directory with the "*_manual1.gif" vessel annotations
    borderMasks_dir -- directory with the FOV border masks
    train_test      -- "train" or "test"; selects the mask filename pattern

    Returns (imgs, groundTruth, border_masks) as float arrays shaped
    [Nimgs, channels, height, width] / [Nimgs, 1, height, width].
    Exits the process if train_test is neither "train" nor "test".
    """
    imgs = np.empty((Nimgs, height, width, channels))
    groundTruth = np.empty((Nimgs, height, width))    # binary image, 1 channel
    border_masks = np.empty((Nimgs, height, width))   # binary image, 1 channel
    # os.walk yields (current path, sub-directories, files in the directory)
    for path, subdirs, files in os.walk(imgs_dir):
        for i in range(len(files)):  # len(files): number of images
            print("original image: " + files[i])
            img = Image.open(imgs_dir + files[i])      # load the image
            imgs[i] = np.asarray(img)                  # convert to numpy array
            # DRIVE filenames start with a 2-digit image id, e.g. "21_training.tif"
            groundTruth_name = files[i][0:2] + "_manual1.gif"
            print("ground truth name: " + groundTruth_name)
            g_truth = Image.open(groundTruth_dir + groundTruth_name)
            groundTruth[i] = np.asarray(g_truth)
            border_masks_name = ""
            if train_test == "train":
                border_masks_name = files[i][0:2] + "_training_mask.gif"
            elif train_test == "test":
                border_masks_name = files[i][0:2] + "_test_mask.gif"
            else:
                # fixed: was a Python-2 print statement, a SyntaxError on Python 3
                print("please specify if train or test!!")
                exit()
            print("border masks name: " + border_masks_name)
            b_mask = Image.open(borderMasks_dir + border_masks_name)
            border_masks[i] = np.asarray(b_mask)
    print("imgs max: " + str(np.max(imgs)))
    print("imgs min: " + str(np.min(imgs)))
    # sanity checks: the binary maps must use the full 0..255 range
    assert np.max(groundTruth) == 255 and np.max(border_masks) == 255
    assert np.min(groundTruth) == 0 and np.min(border_masks) == 0
    # reorder to channels-first: [Nimgs, channels, height, width]
    imgs = np.transpose(imgs, (0, 3, 1, 2))
    groundTruth = np.reshape(groundTruth, (Nimgs, 1, height, width))
    border_masks = np.reshape(border_masks, (Nimgs, 1, height, width))
    # verify tensor shapes
    assert imgs.shape == (Nimgs, channels, height, width)
    assert groundTruth.shape == (Nimgs, 1, height, width)
    assert border_masks.shape == (Nimgs, 1, height, width)
    return imgs, groundTruth, border_masks
# --- Script entry: pack the DRIVE training and test splits into HDF5 files ---
if not os.path.exists(dataset_path):
    os.makedirs(dataset_path)

# Pack the training dataset.
# fixed: the assignment was illegally split across two lines (SyntaxError);
# the call is now wrapped in parentheses for a legal continuation.
imgs_train, groundTruth_train, border_masks_train = get_datasets(
    original_imgs_train, groundTruth_imgs_train, borderMasks_imgs_train, "train")
print("saving train datasets ... ...")
write_hdf5(imgs_train, dataset_path + "imgs_train.hdf5")
write_hdf5(groundTruth_train, dataset_path + "groundTruth_train.hdf5")
write_hdf5(border_masks_train, dataset_path + "borderMasks_train.hdf5")

# Pack the test dataset.
imgs_test, groundTruth_test, border_masks_test = get_datasets(
    original_imgs_test, groundTruth_imgs_test, borderMasks_imgs_test, "test")
print("saving test datasets ... ...")
write_hdf5(imgs_test, dataset_path + "DRIVE_dataset_imgs_test.hdf5")
write_hdf5(groundTruth_test, dataset_path + "DRIVE_dataset_groundTruth_test.hdf5")
write_hdf5(border_masks_test, dataset_path + "DRIVE_dataset_borderMasks_test.hdf5")
# NOTE(review): this is a byte-for-byte duplicate of write_hdf5 defined near the
# top of this file; consider removing one of the two definitions.
def write_hdf5(arr, outfile):
    """Save a numpy array to an HDF5 file under the dataset key "image"."""
    with h5py.File(outfile, "w") as f:
        f.create_dataset("image", data=arr, dtype=arr.dtype)
def load_hdf5(infile):
    """Load the array stored under the "image" key of an HDF5 file.

    "image" is the dataset key used by write_hdf5 (key/value pair).
    Usage: train_imgs_original = load_hdf5(file_dir)
    """
    with h5py.File(infile, "r") as f:
        return f["image"][()]
# 将RGB图像转换为Gray图像
def rgb2gray(rgb):
assert (len(rgb.shape)==4) #[Nimgs, channels, height, width]
assert (rgb.shape[1]==3) #确定是否为RGB图像
bn_imgs = rgb[:,0,:,:]*0.299 + rgb[:,1,:,:]*0.587 + rgb[:,2,:,:]*0.114
bn_imgs = np.reshape(bn_imgs,(rgb.shape[0],1,rgb.shape[2],rgb.shape[3])) # 确保张量形式正确
return bn_imgs
# 对数据集划分,进行分组显示,totimg图像阵列
def group_images(data, per_row):
    """Tile a batch of images into one big montage image.

    data    -- array shaped [Nimgs, channels, height, width], channels 1 or 3;
               Nimgs must be divisible by per_row
    per_row -- number of images placed side by side in each montage row
    Returns a single channels-last array [rows*height, per_row*width, channels].
    """
    assert data.shape[0] % per_row == 0
    assert data.shape[1] == 1 or data.shape[1] == 3
    data = np.transpose(data, (0, 2, 3, 1))  # to channels-last for display
    all_stripe = []
    for i in range(int(data.shape[0] / per_row)):  # number of montage rows
        stripe = data[i * per_row]                 # first image of this row
        for k in range(i * per_row + 1, i * per_row + per_row):
            # append the next image horizontally (along the width axis)
            stripe = np.concatenate((stripe, data[k]), axis=1)
        all_stripe.append(stripe)
    totimg = all_stripe[0]
    for i in range(1, len(all_stripe)):
        # stack the rows vertically (along the height axis)
        totimg = np.concatenate((totimg, all_stripe[i]), axis=0)
    return totimg
def visualize(data, filename):
    """Save a channels-last image array as a PNG file and return the PIL image.

    data     -- array shaped [height, width, channels]; values either already
                in 0-255 or normalized to 0-1 (detected via the max value)
    filename -- output path without extension; ".png" is appended
    """
    assert len(data.shape) == 3  # height * width * channels
    img = None
    if data.shape[2] == 1:  # single-channel: drop the channel axis for PIL
        data = np.reshape(data, (data.shape[0], data.shape[1]))
    if np.max(data) > 1:
        img = Image.fromarray(data.astype(np.uint8))            # already 0-255
    else:
        img = Image.fromarray((data * 255).astype(np.uint8))    # scale 0-1 up
    img.save(filename + '.png')  # save to disk
    return img