前段时间在用VoxelMorph框架做二维图像的配准,在数据准备和读取一块花了不少的时间,这里我把自己的核心代码分享一下,以供参考。
关于VoxelMorph的源代码请参考https://github.com/voxelmorph/voxelmorph
这个框架是用Keras写的,网络部分的输入是待配准的图像对:浮动图像和固定图像。因此,我需要读取的就是对应的图像对。
版本1:
定义一个数据生成器,直接从原图像所在的文件夹中读取数据。固定图像和浮动图像保存在不同的文件夹中,对应的图像以相同的数字命名。
特点:操作简单,但是效率低,尤其是对于预处理复杂的情况,训练网络时每迭代一次都需要对图像进行一遍预处理操作,导致训练时间大大加长。
def data_gen(datadir, shape, batch_size=32, rotate=False):
    """Yield batches of (moving, fixed) image pairs read straight from disk.

    Images live in ``datadir/fix`` and ``datadir/move``; paired images share
    the same numeric filename (1.jpg, 2.jpg, ...).

    :param datadir: root directory containing the "fix" and "move" folders
    :param shape: (rows, cols) of the images fed to the network
    :param batch_size: number of image pairs per batch
    :param rotate: unused placeholder, kept for interface compatibility
    :yield: ([X_move, X_fix], [X_fix, zeros]) -- X_fix is the label for the
            warped image, ``zeros`` is the dummy label for the flow field
    """
    zeros = np.zeros((batch_size, shape[0], shape[1], 2))
    image_names = glob.glob(datadir + "/fix/*.jpg")
    length = len(image_names)
    # cv2.resize expects dsize as (width, height), i.e. (cols, rows) --
    # the original passed `shape` directly, which is wrong for
    # non-square shapes.
    dsize = (shape[1], shape[0])
    while True:
        # Allocate fresh arrays each batch so a yielded batch is never
        # silently overwritten by the next iteration (the original reused
        # the same X1/X2 buffers for every yield).
        X1 = np.zeros((batch_size, shape[0], shape[1], 1))
        X2 = np.zeros((batch_size, shape[0], shape[1], 1))
        # pick batch_size random pairs; files are named 1..length
        n = np.random.randint(1, length + 1, batch_size)
        for i in range(batch_size):
            # paths of the fixed / moving image of pair n[i]
            image_name1 = datadir + "/fix/%d.jpg" % n[i]
            image_name2 = datadir + "/move/%d.jpg" % n[i]
            # read as grayscale, resize, scale intensities to [0, 1]
            img1 = cv2.imread(image_name1, 0)
            img1 = cv2.resize(img1, dsize, interpolation=cv2.INTER_CUBIC)
            X1[i] = np.expand_dims(img1 / 255.0, axis=-1)
            img2 = cv2.imread(image_name2, 0)
            img2 = cv2.resize(img2, dsize, interpolation=cv2.INTER_CUBIC)
            X2[i] = np.expand_dims(img2 / 255.0, axis=-1)
        # inputs and labels: X1 supervises the warped image, zeros the flow
        yield ([X2, X1], [X1, zeros])
这里有个比较严重的问题是把预处理部分写到了数据生成器中。后期我的预处理越来越复杂,导致网络训练速度慢到无法忍受。
版本2:
将已经预处理好的固定图像和浮动图像各自按顺序保存为两个npy文件(两个数组),数据生成器只需要加载该封装好的数据就好,通过数组的索引可以将固定图像和浮动图像对应起来。
def data_gen1(dataFix, dataAffined, shape, batch_size=32):
    """Yield batches of preprocessed (moving, fixed) pairs from two .npy files.

    Both files hold arrays of shape (N, rows, cols, 1) saved in the same
    order, so index i in one file corresponds to index i in the other.

    :param dataFix: path to the .npy file holding the fixed images
    :param dataAffined: path to the .npy file holding the moving (affined) images
    :param shape: (rows, cols) of each image
    :param batch_size: number of image pairs per batch
    :yield: ([X_move, X_fix], [X_fix, zeros])
    """
    # Load and normalise ONCE, not per sample per batch as the original did.
    data_fix = np.load(dataFix) / 255.0
    data_affined = np.load(dataAffined) / 255.0
    length = data_fix.shape[0]
    zeros = np.zeros((batch_size, shape[0], shape[1], 2))
    while True:
        # random indices of this batch's pairs
        n = np.random.randint(0, length, batch_size)
        # Fancy indexing returns fresh copies, so yielded batches never
        # alias a buffer that the next iteration overwrites (the original
        # reused the same X1/X2 arrays for every yield).
        X1 = data_fix[n]
        X2 = data_affined[n]
        yield ([X2, X1], [X1, zeros])
下面是数据生成器一个简单的测试代码,用于检查生成的数据是否正确:
if __name__ == '__main__':
    # Pull one batch from the generator and dump it to disk so the
    # fixed/moving pairing can be checked by eye.
    gen = data_gen1('../data/train/fix.npy', '../data/train/affined.npy', shape=(512, 512))
    ([X1, X2], [X2, zeros]) = next(gen)
    # n.jpg / n_.jpg hold the two images of pair n.
    for idx, (img_a, img_b) in enumerate(zip(X1, X2)):
        cv2.imwrite('%d.jpg' % idx, img_a * 255)
        cv2.imwrite('%d_.jpg' % idx, img_b * 255)
    print("写入结束!")
最后放一下在训练文件中数据生成器调用的代码:
# Train the registration model from the generator; one epoch consists of
# steps_per_epoch batches. NOTE(review): fit_generator is deprecated in
# modern Keras -- model.fit(...) accepts generators directly.
history_ft = model.fit_generator( data_gen1(dataFix,dataAffined,shape=(512,512)),
epochs=nb_epochs,
callbacks=[checkpoint],
steps_per_epoch=steps_per_epoch,
# Validation is disabled; re-enable the two lines below to monitor it:
#validation_data=data_gen(data_dir_validation,input_size,batch_size),
#validation_steps=200,
verbose=1)
版本3:
一个二维图像配准通用的数据读取代码,可在线数据增强,可用于弱监督配准架构。思路是把固定图像、浮动图像以及对应标签封装成一个npz文件,有多少图像对就有多少npz文件,所有文件通通以数字命名便于查找。
图像封装成npz文件的代码(省略预处理部分,有问题可留言)
class mydata():
    """Data-preparation helper.

    Preprocesses image/label pairs and bundles each registration sample
    (fixed image, moving image and their two labels) into one .npz file.
    """

    def __init__(self, size_resize=512, size_crop=360):
        """
        :param size_resize: edge length images are resized to (default 512)
        :param size_crop: edge length of the crop (default 360)
        """
        self.size_resize = size_resize
        self.size_crop = size_crop

    def preprocess(self, path, isImage):
        """Load and preprocess one image (augmentation, resize and crop).

        :param path: path of the image file
        :param isImage: True for an intensity image, False for a label mask
        :return: np array, shape (size_crop, size_crop)
        """
        # The preprocessing body was intentionally omitted in the original
        # post; fail loudly instead of returning an undefined variable
        # (the original `return img` raised NameError).
        raise NotImplementedError("implement the preprocessing pipeline here")

    def savenpz(self, dir, dir_save, N, ext="jpg"):
        """Preprocess pairs 1..N and save each one as `<dir_save>/<i>.npz`.

        :param dir: root folder, e.g. "../data/train/data_ori", containing the
                    four sub-folders "image_fix", "image_move", "label_fix",
                    "label_move", whose files are named 1.<ext>, 2.<ext>, ...
        :param dir_save: destination folder for the .npz files
        :param N: number of image pairs to process
        :param ext: image file extension (default "jpg")
        :return: None
        """
        list_dir = ["image_fix", "image_move", "label_fix", "label_move"]
        for i in tqdm(range(1, N + 1)):
            path_img_f = f"{dir}/{list_dir[0]}/{i}.{ext}"
            path_img_m = f"{dir}/{list_dir[1]}/{i}.{ext}"
            path_label_f = f"{dir}/{list_dir[2]}/{i}.{ext}"
            path_label_m = f"{dir}/{list_dir[3]}/{i}.{ext}"
            img_f = self.preprocess(path_img_f, True)
            img_m = self.preprocess(path_img_m, True)
            label_f = self.preprocess(path_label_f, False)
            label_m = self.preprocess(path_label_m, False)
            # The original line carried a stray extra ')' -- a SyntaxError.
            np.savez(f"{dir_save}/{i}.npz", image_fix=img_f, image_move=img_m,
                     label_fix=label_f, label_move=label_m)
函数调用代码:
# Build the dataset: resize to 512 with crop == resize (i.e. no effective
# crop), then bundle the 41 preprocessed training pairs into .npz files.
data = mydata(size_resize=512,size_crop=512)
data.savenpz("../data/trainP/data_ori","../data/trainP/npz_nocrop",41)
数据生成器代码:
调用的albumentations这个数据增强库做的数据增强,这个库也可以嵌入到pytorch中,强烈安利。
class DataGenerator(keras.utils.Sequence):
    """Generates registration batches for Keras from per-pair .npz files.

    Each .npz file holds one sample with four arrays: ``image_fix``,
    ``image_move``, ``label_fix`` and ``label_move``.
    """

    def __init__(self, mode='fit',
                 base_path='../data/train/npz',
                 batch_size=32, dim=(360, 360),
                 augment=False, random_state=2019, shuffle=True):
        """
        :param mode: 'fit' yields (inputs, targets); 'predict' not implemented
        :param base_path: folder containing the .npz sample files
        :param batch_size: samples per batch
        :param dim: (rows, cols) of the images
        :param augment: apply random flip / elastic / affine augmentation
        :param random_state: seed for the shuffling RNG
        :param shuffle: reshuffle the sample order after every epoch
        """
        self.dim = dim
        self.batch_size = batch_size
        self.mode = mode
        self.base_path = base_path
        self.augment = augment
        self.shuffle = shuffle
        self.random_state = random_state
        self.paths = glob.glob(base_path + "/*.npz")
        self.length_dataset = len(self.paths)
        # Seed ONCE here. The original reseeded inside on_epoch_end, which
        # made every epoch reuse the exact same "shuffled" order.
        np.random.seed(self.random_state)
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (an incomplete last batch is dropped)."""
        return int(np.floor(self.length_dataset / self.batch_size))

    def __getitem__(self, index):
        """Generate one batch of data."""
        # indices (np array) of this batch's samples
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        if self.mode == 'fit':
            X_fix, X_move, Y_fix, Y_move = self.__generate_XY(indexes)
            if self.augment:
                X_fix, X_move, Y_fix, Y_move = self.__augment_batch(X_fix, X_move, Y_fix, Y_move)
            # Adapt this return layout to the loss functions of your model!
            return [X_fix, X_move, Y_move], Y_fix
        elif self.mode == 'predict':
            raise NotImplementedError
        else:
            raise AttributeError('The mode parameter should be set to "fit" or "predict".')

    def on_epoch_end(self):
        """Reshuffle the sample order after each epoch (when shuffle=True)."""
        self.indexes = np.arange(self.length_dataset)
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __generate_XY(self, indexes):
        """Load the samples for ``indexes`` into batch arrays, scaled to [0, 1]."""
        # Initialization
        X_fix = np.empty((self.batch_size, *self.dim, 1))
        X_move = np.empty((self.batch_size, *self.dim, 1))
        Y_fix = np.empty((self.batch_size, *self.dim, 1))
        Y_move = np.empty((self.batch_size, *self.dim, 1))
        # Generate data
        for i, ID in enumerate(indexes):
            npz_path = self.paths[ID]
            data = np.load(npz_path)
            image_fix = data["image_fix"] / 255.0
            image_move = data["image_move"] / 255.0
            mask_fix = data["label_fix"] / 255
            mask_move = data["label_move"] / 255
            # Store samples. NOTE: the original indexed the channel with the
            # complex literal `0j` here -- a TypeError at runtime.
            X_fix[i, :, :, 0] = image_fix
            X_move[i, :, :, 0] = image_move
            Y_fix[i, :, :, 0] = mask_fix
            Y_move[i, :, :, 0] = mask_move
        return X_fix, X_move, Y_fix, Y_move

    def __augment_batch(self, X_fix, X_move, Y_fix, Y_move):
        """Apply __random_transform to every sample of the batch in place."""
        for i in range(self.batch_size):
            X_fix[i,], X_move[i,], Y_fix[i,], Y_move[i,] = self.__random_transform(
                X_fix[i,], X_move[i,], Y_fix[i,], Y_move[i,])
        return X_fix, X_move, Y_fix, Y_move

    def __random_transform(self, x_fix, x_move, y_fix, y_move):
        """Randomly augment one sample.

        Flip: applied identically to fixed and moving images (via the masks
        argument, so all four arrays get the same flip).
        Elastic deformation and affine transform: applied independently to
        the fixed pair and the moving pair.
        :return: the four augmented arrays
        """
        aug = albu.Flip(p=0.5)
        augmented = aug(image=x_fix, masks=[x_fix, x_move, y_fix, y_move])
        x_fix_aug, x_move_aug, y_fix_aug, y_move_aug = augmented["masks"]
        composition = albu.Compose(
            [albu.ElasticTransform(p=0.8, alpha_affine=5, border_mode=cv2.BORDER_CONSTANT, value=0),
             albu.ShiftScaleRotate(rotate_limit=3, scale_limit=0, shift_limit=0.03,
                                   border_mode=cv2.BORDER_CONSTANT, value=0)
             ])
        composed1 = composition(image=x_fix_aug, mask=y_fix_aug)
        composed2 = composition(image=x_move_aug, mask=y_move_aug)
        x_fix_aug = composed1['image']
        y_fix_aug = composed1['mask']
        x_move_aug = composed2['image']
        y_move_aug = composed2['mask']
        return x_fix_aug, x_move_aug, y_fix_aug, y_move_aug
测试代码
# Sanity-check the generator: fetch one augmented batch and eyeball a sample.
generator = DataGenerator(base_path="../data/testP/npz_nocrop", batch_size=8, dim=(512, 512), augment=True)
[X1, X2, Y2], Y1 = generator[0]
print(X1.shape, X2.shape, Y2.shape, Y1.shape)
# Check that the labels of an arbitrary fixed image line up correctly.
sample = 5
panels = [
    (X1[sample, :, :, 0], "image_fix"),
    (Y1[sample, :, :, 0], "mask_fix"),
    (X2[sample, :, :, 0], "image_move"),
    (Y2[sample, :, :, 0], "mask_move"),
]
fig, axes = plt.subplots(2, 2, sharex='col', figsize=(20, 12))
for ax, (img, title) in zip(axes.ravel(), panels):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
plt.show()