pix2pix-gan医疗图像分割迁移

pix2pix-gan做医学图像合成的时候,如果把nii数据转成png格式会损失很多信息,以为png格式图像的灰度值有256阶,因此直接使用nii的医学图像做输入会更好一点。

但是Pythorch中的Dataloader是不能直接读取nii图像的,因此加一个CreateNiiDataset的类。

先来了解一下pytorch中读取数据的主要途径——Dataset类。在自己构建数据层时都要基于这个类,类似于C++中的虚基类。

自己构建的数据层包含三个部分

1

2

3

4

5

6

7

8

9

10

11

12

class Dataset(object):

"""An abstract class representing a Dataset.

All other datasets should subclass it. All subclasses should override

``__len__``, that provides the size of the dataset, and ``__getitem__``,

supporting integer indexing in range from 0 to len(self) exclusive.

"""

def __getitem__(self, index):

 raise NotImplementedError

def __len__(self):

 raise NotImplementedError

def __add__(self, other):

 return ConcatDataset([self, other])

根据自己的需要编写CreateNiiDataset子类:

因为我是基于https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

做pix2pix-gan的实验,数据包含两个部分mr 和 ct,不需要标签,因此上面的 def getitem(self, index):中不需要index这个参数了,类似地,根据需要,加入自己的参数,去掉不需要的参数。

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

class CreateNiiDataset(Dataset):

 def __init__(self, opt, transform = None, target_transform = None):

  self.path1 = opt.dataroot # parameter passing

  self.A = 'MR'

  self.B = 'CT'

  lines = os.listdir(os.path.join(self.path1, self.A))

  lines.sort()

  imgs = []

  for line in lines:

   imgs.append(line)

  self.imgs = imgs

  self.transform = transform

  self.target_transform = target_transform

 def crop(self, image, crop_size):

  shp = image.shape

  scl = [int((shp[0] - crop_size[0]) / 2), int((shp[1] - crop_size[1]) / 2)]

  image_crop = image[scl[0]:scl[0] + crop_size[0], scl[1]:scl[1] + crop_size[1]]

  return image_crop

 def __getitem__(self, item):

  file = self.imgs[item]

  img1 = sitk.ReadImage(os.path.join(self.path1, self.A, file))

  img2 = sitk.ReadImage(os.path.join(self.path1, self.B, file))

  data1 = sitk.GetArrayFromImage(img1)

  data2 = sitk.GetArrayFromImage(img2)

  if data1.shape[0] != 256:

   data1 = self.crop(data1, [256, 256])

   data2 = self.crop(data2, [256, 256])

  if self.transform is not None:

   data1 = self.transform(data1)

   data2 = self.transform(data2)

  if np.min(data1)<0:

   data1 = (data1 - np.min(data1))/(np.max(data1)-np.min(data1))

  if np.min(data2)<0:

   #data2 = data2 - np.min(data2)

   data2 = (data2 - np.min(data2))/(np.max(data2)-np.min(data2))

  data = {}

  data1 = data1[np.newaxis, np.newaxis, :, :]

  data1_tensor = torch.from_numpy(np.concatenate([data1,data1,data1], 1))

  data1_tensor = data1_tensor.type(torch.FloatTensor)

  data['A'] = data1_tensor # should be a tensor in Float Tensor Type

  data2 = data2[np.newaxis, np.newaxis, :, :]

  data2_tensor = torch.from_numpy(np.concatenate([data2,data2,data2], 1))

  data2_tensor = data2_tensor.type(torch.FloatTensor)

  data['B'] = data2_tensor # should be a tensor in Float Tensor Type

  data['A_paths'] = [os.path.join(self.path1, self.A, file)] # should be a list, with path inside

  data['B_paths'] = [os.path.join(self.path1, self.B, file)]

  return data

 def load_data(self):

  return self

 def __len__(self):

  return len(self.imgs)

注意:最后输出的data是一个字典,里面有四个keys=[‘A',‘B',‘A_paths',‘B_paths'], 一定要注意数据要转成FloatTensor。

其次是data[‘A_paths'] 接收的值是一个list,一定要加[ ] 扩起来,要不然测试存图的时候会有问题,找这个问题找了好久才发现。

然后直接在train.py的主函数里面把数据加载那行改掉就好了

data_loader = CreateNiiDataset(opt)
dataset = data_loader.load_data()

Over!

补充知识:nii格式图像存为npy格式

我就废话不多说了,大家还是直接看代码吧!

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

import nibabel as nib

import os

import numpy as np

  

img_path = '/home/lei/train/img/'

seg_path = '/home/lei/train/seg/'

saveimg_path = '/home/lei/train/npy_img/'

saveseg_path = '/home/lei/train/npy_seg/'

  

img_names = os.listdir(img_path)

seg_names = os.listdir(seg_path)

  

for img_name in img_names:

 print(img_name)

 img = nib.load(img_path + img_name).get_data() #载入

 img = np.array(img)

 np.save(saveimg_path + str(img_name).split('.')[0] + '.npy', img) #保存

  

for seg_name in seg_names:

 print(seg_name)

 seg = nib.load(seg_path + seg_name).get_data()

 seg = np.array(seg)

 np.save(saveseg_path + str(seg_name).split('.')[0] + '.npy

你可能感兴趣的:(文章杂谈,python)