vnet(Github)学习总结

Tensorflow可以使用feed_dict的方式输入数据,但是效率比较低。Tensorflow提供了一个内置函数可以利用输入管道的方式输入数据。

tf.data.Dataset()接收numpy和tensor类型的数据


Dataset

Dataset()可以接收多个输入,当数据由特征和标签组成时,使用起来及其方便。

image_paths = ['特征路径']
label_paths = ['标签路径']
dataset = tf.data.Dataset.from_tensor_slices((image_paths, label_paths))

结果:

>>b'('特征路径', '标签路径')'

当输入为string时,使用form_tensor_slices()得到的结果是bytes类型,可能需要decode('utf-8')


除了加载数据方便外,dataset还可以做数据转换。
dataset.map()接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset

dataset = tf.data.Dataset.from_tensor_slices((image_paths, label_paths))
dataset = dataset.map(lambda image_path, label_path:
                              tuple(tf.py_func(input_parser, [image_path, label_path], [tf.float32, tf.float32])))

使用tf.py_func()input_parser()变为一个tensorflow内置函数,第二个参数表示输入数据,第三个参数表示输出数据。


在使用dataset时,先要创建一个迭代器,然后使用get_next()获取数据。

iter = dataset.make_initializable_iterator()
el = iter.get_next()
with tf.Session() as sess:
    sess.run(iter.initializer)
    print(sess.run(el))

如果使用多个print()时,iter可以自动进行迭代。


加载数据时的小技巧

对于V-Net而言,当训练网络时,必须要提供一个和输入大小相等的tensor作为标签,这个可以直接加载特征和标签来完成。当为非训练状态时,可以生成一个和原特征大小相同的label进行占位。

if train:
    label = read_image(label_path.decode("utf-8"))
else:
    label = sitk.Image(image.GetSize(),sitk.sitkUInt32)
    label.SetOrigin(image.GetOrigin())
    label.SetSpacing(image.GetSpacing())

在SimpleITK中,图像作为物理对象占据一个空间有界区域,通过上述方法生成一个和image相同大小的label。

关于SimpleITk可参:http://insightsoftwareconsortium.github.io/SimpleITK-Notebooks/Python_html/03_Image_Details.html


数据增强的方法

按照作者的说法:医学图像通常比较耗费内存,可以对图像进行0-255的标准化,对于较小的输入image可以进行Padding,还可以从3D图形中随机选择一个区域作为网络输入,还可以对图像添加噪声。


Normalization
resacleFilter = sitk.RescaleIntensityImageFilter()
resacleFilter.SetOutputMaximum(255)
resacleFilter.SetOutputMinimum(0)
image = resacleFilter.Execute(image)

RandomCrop

随机从输入图像中采集一个zone,通常可以用来进行数据增强(一般只用于训练阶段)。
先判断zone和image的大小,如果zone的size小于image的size,就将下标置为0~image_size-zone_size。这里要注意的一点就是在对label进行randomCrop时,每次必须保证将包含标签的zone提取出来。

   while not contain_label: 
      # get the start crop coordinate in ijk
      if size_old[0] <= size_new[0]:
        start_i = 0
      else:
        start_i = np.random.randint(0, size_old[0]-size_new[0])

      if size_old[1] <= size_new[1]:
        start_j = 0
      else:
        start_j = np.random.randint(0, size_old[1]-size_new[1])

      if size_old[2] <= size_new[2]:
        start_k = 0
      else:
        start_k = np.random.randint(0, size_old[2]-size_new[2])

      roiFilter.SetIndex([start_i,start_j,start_k])

      label_crop = roiFilter.Execute(label)
      statFilter = sitk.StatisticsImageFilter()
      statFilter.Execute(label_crop)

      # will iterate until a sub volume containing label is extracted
      # pixel_count = seg_crop.GetHeight()*seg_crop.GetWidth()*seg_crop.GetDepth()
      # if statFilter.GetSum()/pixel_count

训练


此处作者使用了PReLU,也就是Parametric Leaky Relu,是何凯明提出的一种改进ReLU。表达式:
y = max(0, x) + a * min(0, x)
其中的a是可学习参数,当a为非零较小数时,相当于LeakyReLU;当a为零时,等价于ReLU。

你可能感兴趣的:(vnet(Github)学习总结)