【Keras + Registration】Preparing and Loading Training Data (2D Image Pairs) for Deep-Learning-Based Registration (updated 2019.11.15)

   A while ago I was using the VoxelMorph framework for 2D image registration and spent quite a bit of time on data preparation and loading. Here I share my core code for reference.

  For the VoxelMorph source code, see https://github.com/voxelmorph/voxelmorph

  The framework is written in Keras, and the network takes as input the image pair to be registered: the moving image and the fixed image. So what I need to load is exactly these corresponding image pairs.

Version 1:

  Define a data generator that reads data directly from the folders containing the original images. Fixed and moving images are stored in separate folders, and corresponding images share the same numeric filename.

  Pros and cons: simple to implement, but inefficient, especially when preprocessing is complex. Every training iteration re-runs the preprocessing on the images, which greatly lengthens training time.

import glob
import cv2
import numpy as np

def data_gen(datadir, shape, batch_size=32, rotate=False):
    '''
    Data generator
    '''
    X1 = np.zeros((batch_size, shape[0], shape[1], 1))
    X2 = np.zeros((batch_size, shape[0], shape[1], 1))
    zeros = np.zeros((batch_size, shape[0], shape[1], 2))
    image_names = glob.glob(datadir + "/fix/*.jpg")
    length = len(image_names)
    # randomly pick batch_size image pairs from the dataset
    while True:
        n = np.random.randint(1,length+1,batch_size)
        for i in range(batch_size):
            # paths of the fixed image and the moving image
            image_name1 = datadir + "/fix/%d.jpg" % n[i]
            image_name2 = datadir + "/move/%d.jpg" % n[i]
            # read and preprocess (note: cv2.resize expects (width, height))
            img1 = cv2.imread(image_name1, 0)
            img1 = cv2.resize(img1, shape, interpolation=cv2.INTER_CUBIC)
            img1 = img1 / 255.0
            img1 = np.expand_dims(img1, axis=-1)
            X1[i] = img1
         
            img2 = cv2.imread(image_name2, 0)
            img2 = cv2.resize(img2, shape, interpolation=cv2.INTER_CUBIC)
            img2 = img2 / 255.0
            img2 = np.expand_dims(img2, axis=-1)
            X2[i] = img2
        # inputs and labels: X1 is the label for the warped image, zeros is the label for the deformation field
        yield ([X2, X1], [X1, zeros])

      The serious problem here is that the preprocessing lives inside the data generator. As my preprocessing grew more and more complex, training became unbearably slow.

Version 2:

Save the already-preprocessed fixed and moving images, in order, as two npy files (two arrays). The data generator then only needs to load these packed files, and the array index links each fixed image to its corresponding moving image.
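The packing itself is just stacking the preprocessed arrays in pair order. A minimal sketch, assuming the images have already been read and preprocessed into equally sized 2D arrays (`pack_pairs` is an illustrative helper, not part of VoxelMorph):

```python
import numpy as np

def pack_pairs(fix_images, move_images, out_fix="fix.npy", out_move="move.npy"):
    '''Stack index-aligned lists of 2D images into (N, H, W, 1) arrays and save as npy.'''
    fix = np.stack(fix_images)[..., np.newaxis]    # (N, H, W, 1)
    move = np.stack(move_images)[..., np.newaxis]
    assert fix.shape == move.shape                 # keep the pairs aligned by index
    np.save(out_fix, fix)
    np.save(out_move, move)
```

The generator below then only needs `np.load` plus integer indexing to recover each pair.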

def data_gen1(dataFix,dataAffined,shape,batch_size=32):

    X1 = np.zeros((batch_size, shape[0], shape[1], 1))
    X2 = np.zeros((batch_size, shape[0], shape[1], 1))
    zeros = np.zeros((batch_size,shape[0],shape[1],2))
    
    data_fix = np.load(dataFix)
    data_affined = np.load(dataAffined)
    length = data_fix.shape[0]
    # randomly pick batch_size image pairs from the dataset
    while True:
        n = np.random.randint(0,length,batch_size)
        for i in range(batch_size):
            index = n[i]
            X1[i] = data_fix[index]/255.0
            X2[i] = data_affined[index]/255.0
        yield ([X2,X1],[X1,zeros])

Below is a simple test of the data generator, used to check that the generated data is correct:

if __name__ == '__main__':

    gen = data_gen1('../data/train/fix.npy', '../data/train/affined.npy', shape=(512, 512))
    [X2, X1], [Y1, zeros] = next(gen)   # X2: moving images, X1: fixed images
    for i in range(len(X1)):
        cv2.imwrite('%d.jpg' % i, X1[i] * 255)
        cv2.imwrite('%d_.jpg' % i, X2[i] * 255)
    print("Done writing!")

Finally, here is how the data generator is called in the training script:

history_ft = model.fit_generator(data_gen1(dataFix, dataAffined, shape=(512, 512)),
                                 epochs=nb_epochs,
                                 callbacks=[checkpoint],
                                 steps_per_epoch=steps_per_epoch,
                                 # validation_data=data_gen(data_dir_validation, input_size, batch_size),
                                 # validation_steps=200,
                                 verbose=1)

Version 3:

   A general-purpose data-loading setup for 2D image registration. It supports online data augmentation and can be used in weakly supervised registration architectures. The idea: pack the fixed image, the moving image, and their labels into one npz file per image pair (as many npz files as there are pairs), with all files named by number for easy lookup.

Code for packing the images into npz files (preprocessing omitted; feel free to ask in the comments):

import numpy as np
from tqdm import tqdm

class mydata():
    '''
    Data-handling class: preprocessing, packing into npz files, etc.
    '''
    def __init__(self, size_resize=512, size_crop=360):
        '''
        :param size_resize: resize size, default 512
        :param size_crop: crop size, default 360
        '''
        self.size_resize = size_resize
        self.size_crop = size_crop

    def preprocess(self, path, isImage):
        '''
        Preprocess an image before feeding it into the network: image enhancement, resize and crop.
        :param path: image path
        :param isImage: whether this is an image or a label
        :return: np array, shape=(size_crop, size_crop)
        '''
        # (preprocessing code omitted)
        return img

    def savenpz(self, dir, dir_save, N, ext="jpg"):
        '''
        Preprocess the images under the given directory and save them as npz files.
        :param dir: data directory, e.g. "../data/train/data_ori", containing the four folders
        "image_fix", "image_move", "label_fix", "label_move", whose files are named 1, 2, 3, ...
        :param dir_save: directory where the npz files are saved
        :param N: number of image pairs
        :param ext: image file extension
        :return: None
        '''
        list_dir = ["image_fix", "image_move", "label_fix", "label_move"]
        for i in tqdm(range(1, N + 1)):
            path_img_f = f"{dir}/{list_dir[0]}/{i}.{ext}"
            path_img_m = f"{dir}/{list_dir[1]}/{i}.{ext}"
            path_label_f = f"{dir}/{list_dir[2]}/{i}.{ext}"
            path_label_m = f"{dir}/{list_dir[3]}/{i}.{ext}"
            img_f = self.preprocess(path_img_f, True)
            img_m = self.preprocess(path_img_m, True)
            label_f = self.preprocess(path_label_f, False)
            label_m = self.preprocess(path_label_m, False)
            np.savez(f"{dir_save}/{i}.npz", image_fix=img_f, image_move=img_m, label_fix=label_f,
                     label_move=label_m)

Calling code:

data = mydata(size_resize=512, size_crop=512)
data.savenpz("../data/trainP/data_ori", "../data/trainP/npz_nocrop", 41)

Data generator code:

Data augmentation uses the albumentations library, which can also be plugged into PyTorch pipelines; highly recommended.

import glob
import cv2
import numpy as np
import keras
import albumentations as albu
import matplotlib.pyplot as plt

class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'

    def __init__(self, mode='fit',
                 base_path='../data/train/npz',
                 batch_size=32, dim=(360, 360),
                 augment=False, random_state=2019, shuffle=True):
        self.dim = dim
        self.batch_size = batch_size
        self.mode = mode
        self.base_path = base_path
        self.augment = augment
        self.shuffle = shuffle
        self.random_state = random_state
        self.paths = glob.glob(base_path+"/*.npz")
        self.length_dataset = len(self.paths)

        np.random.seed(self.random_state)
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(self.length_dataset / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]  # indexes of this batch (np array)

        if self.mode == 'fit':

            X_fix, X_move, Y_fix, Y_move = self.__generate_XY(indexes)
            if self.augment:
                X_fix, X_move, Y_fix, Y_move = self.__augment_batch(X_fix, X_move, Y_fix, Y_move)
            return [X_fix, X_move, Y_move], Y_fix  # adjust this to match your loss functions!

        elif self.mode == 'predict':
            raise NotImplementedError

        else:
            raise AttributeError('The mode parameter should be set to "fit" or "predict".')

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        # do not reseed here, or every epoch would get the identical shuffle
        self.indexes = np.arange(self.length_dataset)
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __generate_XY(self, indexes):
        'Generates data containing batch_size samples'
        # Initialization
        X_fix = np.empty((self.batch_size, *self.dim, 1))
        X_move = np.empty((self.batch_size, *self.dim, 1))
        Y_fix = np.empty((self.batch_size, *self.dim, 1))
        Y_move = np.empty((self.batch_size, *self.dim, 1))
        # Generate data
        for i, ID in enumerate(indexes):
            npz_path = self.paths[ID]
            data = np.load(npz_path)
            image_fix = data["image_fix"] / 255.0
            image_move = data["image_move"] / 255.0
            mask_fix = data["label_fix"] / 255
            mask_move = data["label_move"] / 255
            # Store samples
            X_fix[i, :, :, 0] = image_fix
            X_move[i, :, :, 0] = image_move
            Y_fix[i, :, :, 0] = mask_fix
            Y_move[i, :, :, 0] = mask_move
        return X_fix, X_move, Y_fix, Y_move

    def __augment_batch(self, X_fix, X_move, Y_fix, Y_move):
        for i in range(self.batch_size):
            X_fix[i], X_move[i], Y_fix[i], Y_move[i] = self.__random_transform(
                X_fix[i], X_move[i], Y_fix[i], Y_move[i])

        return X_fix, X_move, Y_fix, Y_move

    def __random_transform(self, x_fix, x_move, y_fix, y_move):
        '''
        Flip: the fixed and moving images get the same flip
        Elastic deformation: applied to fixed and moving separately
        Affine transform: applied to fixed and moving separately
        :return:
        '''
        aug = albu.Flip(p=0.5)
        # passing all four arrays as masks guarantees they all receive the same flip
        augmented = aug(image=x_fix, masks=[x_fix, x_move, y_fix, y_move])
        x_fix_aug, x_move_aug, y_fix_aug, y_move_aug = augmented["masks"]
   
        composition = albu.Compose(
            [albu.ElasticTransform(p=0.8, alpha_affine=5, border_mode=cv2.BORDER_CONSTANT, value=0),
             albu.ShiftScaleRotate(rotate_limit=3, scale_limit=0, shift_limit=0.03, border_mode=cv2.BORDER_CONSTANT,
                                   value=0)
             ])
        composed1 = composition(image=x_fix_aug, mask=y_fix_aug)
        composed2 = composition(image=x_move_aug, mask=y_move_aug)

        x_fix_aug = composed1['image']
        y_fix_aug = composed1['mask']
  
        x_move_aug = composed2['image']
        y_move_aug = composed2['mask']
      

        return x_fix_aug, x_move_aug, y_fix_aug, y_move_aug
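One thing to keep in mind: `__len__` floors the division, so with the 41 npz files packed above and batch_size=8 there are only 5 batches per epoch, and one sample never reaches a batch that epoch. A quick numpy check of the index-slicing logic used in `__getitem__`:

```python
import numpy as np

length_dataset, batch_size = 41, 8      # matches the savenpz call above
indexes = np.arange(length_dataset)
np.random.shuffle(indexes)
n_batches = int(np.floor(length_dataset / batch_size))   # floor -> 5 batches
batches = [indexes[b * batch_size:(b + 1) * batch_size] for b in range(n_batches)]
# the 5 batches cover 40 distinct samples; the leftover one waits for the next epoch
```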

Test code:

if __name__ == '__main__':

    generator = DataGenerator(base_path="../data/testP/npz_nocrop", batch_size=8, dim=(512, 512), augment=True)
    [X1, X2, Y2], Y1 = generator.__getitem__(0)
    print(X1.shape, X2.shape, Y2.shape, Y1.shape)
    # check that an arbitrary fixed image matches its label
    x1 = X1[5, :, :, 0]
    x2 = X2[5, :, :, 0]
    y1 = Y1[5, :, :, 0]
    y2 = Y2[5, :, :, 0]

    fig, axes = plt.subplots(2, 2, sharex='col', figsize=(20, 12))
    axes[0, 0].imshow(x1, cmap='gray')
    axes[0, 0].set_title("image_fix")
    axes[0, 1].imshow(y1, cmap='gray')
    axes[0, 1].set_title("mask_fix")
    axes[1, 0].imshow(x2, cmap='gray')
    axes[1, 0].set_title("image_move")
    axes[1, 1].imshow(y2, cmap='gray')
    axes[1, 1].set_title("mask_move")
    plt.show()

 
