deeplung 代码讲解(自己的理解)(我只是稍微贴下数据增强这部分的代码,其它的借鉴wuzeyuan大神的)lung16数据

ps:一直在等wuzeyuan大神的关于deeplung这部分的理解,但是一直没更,最后只能自己上。先汇总一下大神的链接:

1.deeplung(一):https://www.cnblogs.com/wzyuan/p/9618347.html

2.deeplung(二)上:https://www.cnblogs.com/wzyuan/p/9772128.html

3.deeplung(二)下:没更

4.deeplung(三)上:https://www.cnblogs.com/wzyuan/p/9710678.html

5.deeplung(三)下:https://www.cnblogs.com/wzyuan/p/9718779.html

还有一句话上面只是检测模型没有第二步降假阳性模型(天池复赛第一名的还有第三步呢:融合全局信息再次降假阳性)。

下面开始讲数据增强部分代码:

ps:2019年3月14日添加:从来都不是把整个肺部输入模型的,立体数据太大了,显存直接爆掉。这里是clean.npy是整肺,通过DataBowl3Detector类函数的crop和label_mapping.从clean.npy的体数据中截取96*96*96的立体数据和制作对应的立体标签24*24*24*3*5。你最好还是要逐句调试一下看看输入输出的矩阵变换才好理解含义。再次说明的是推荐看DSB2017grt团队的代码和论文。如果近仅看deeplung的代码和论文你就会有很多困惑以及不理解的参数。我是先反复逐句调试deeplung中不理解的地方,还是不太明白(不堪回首花了近一周)。后来是看了DSB2017grt团队的论文回过头来看才豁然开朗。你可以看下后面几篇该论文我的理解下的翻译。

crop是96*96*96块的提取包括含有结节和不含结节(一部分是根据标签提取结节在中心,一部分是产生一个随机数【0.75,1.25】后缩放后提取,还有就是产生随机数,结节不在中心提取)。他的肺结节标签包含两部分:第一部分就是70%裁剪96*96*96的立方体,但是结节并在在正中心,随机的,结节边缘距离立方体边缘大于12体素,还有30%是在整个肺部中随机裁剪的。(jieshaoxiansen老哥指出)。

ps3月22日16:50添加。 我又看了下,你的说法还是存在一些问题。应该是第一部分就是70%裁剪96*96*96的立方体,但是结节随机的大致处在中心,而非都在正中心。其次还要分为放大和缩小的两种。根据结节大小缩放corp_size(比如96要缩放1.2被,那么corp_size=80),再截取80*80*80(如果结节在边缘不足80,用170补全),接着在缩放回96*96*96.另一种30%随机裁剪到96*96*96,这部分认同。

而crop后面的label_mapping是根据截取的方式产生不同的标签。以上是我的理解。

添加结束

1.数据增强这部分一开始我也感觉非常奇怪,都不知道怎么调用很调试的。(python2和python3的区别?)

不管怎样主要代码基本在data.py这个文件里面的。

main.py里调用的data.py一开始就是

data.DataBowl3Detector

这个函数,但这只是初始化执行下面这段代码,具体我基本做了例子的注释。

class DataBowl3Detector(Dataset):
    def __init__(self, data_dir, split_path, config, phase='train', split_comber=None):
        assert(phase == 'train' or phase == 'val' or phase == 'test')
        self.phase = phase
        self.max_stride = config['max_stride']#16
        self.stride = config['stride']#4
        sizelim = config['sizelim']/config['reso']#2.5/1
        sizelim2 = config['sizelim2']/config['reso']#10/1
        sizelim3 = config['sizelim3']/config['reso']#20/1
        self.blacklist = config['blacklist']#['868b024d9fa388b7ddab12ec1c06af38','990fbe3f0a1b53878669967b9afd1441','adc3bbc63d40f8761c59be10f1e504c3']
        self.isScale = config['aug_scale']#True
        self.r_rand = config['r_rand_crop']#0.3
        self.augtype = config['augtype']#{'flip':True,'swap':False,'scale':True,'rotate':False}
        self.pad_value = config['pad_value']#170
        self.split_comber = split_comber
        idcs = split_path # np.load(split_path)#like 'subset0/1.3.6.1.4.1.14519.5.2.1.6279.6001.277445975068759205899107114231'/'subset0/1.3.6.1.4.1.14519.5.2.1.6279.6001.105756658031515062000744821260'/......
        if phase!='test':
            idcs = [f for f in idcs if (f not in self.blacklist)]

        self.filenames = [os.path.join(data_dir, '%s_clean.npy' % idx) for idx in idcs]
        # print self.filenames
        self.kagglenames = [f for f in self.filenames]# if len(f.split('/')[-1].split('_')[0])>20]
        # self.lunanames = [f for f in self.filenames if len(f.split('/')[-1].split('_')[0])<20]
        
        labels = []
        
        print len(idcs)
        for idx in idcs:
            # print data_dir, idx
            l = np.load(data_dir+idx+'_label.npy')
            # print l, os.path.join(data_dir, '%s_label.npy' %idx)
            if np.all(l==0):
                l=np.array([])
            labels.append(l)

        self.sample_bboxes = labels
        if self.phase != 'test':
            self.bboxes = []
            for i, l in enumerate(labels):  #i=799,l=[[26.0866714 156.63848290000004 174.39610872999998 4.277202539],[71.31795929999998 115.68609266000001 109.24151874 23.80291305]]
                # print l
                if len(l) > 0 :
                    for t in l:
                        if t[3]>sizelim:
                            self.bboxes.append([np.concatenate([[i],t])])
                        if t[3]>sizelim2:
                            self.bboxes+=[[np.concatenate([[i],t])]]*2#[[np.concatenate([[i],t])]]*2=: [[array([66, 127.71765339999999, 190.72397560000002, 210.87506653, 25.23320204], dtype=object)], [array([66, 127.71765339999999, 190.72397560000002, 210.87506653, 25.23320204], dtype=object)]]
                        if t[3]>sizelim3:
                            self.bboxes+=[[np.concatenate([[i],t])]]*4
            self.bboxes = np.concatenate(self.bboxes,axis = 0)#列表转为ndarray

        self.crop = Crop(config)
        self.label_mapping = LabelMapping(config, self.phase)
        print(1)

其中比较重要的也只是定义了

self.crop = Crop(config)与self.label_mapping = LabelMapping(config, self.phase)

的初始化代码:

class Crop(object):
    def __init__(self, config):#__init__()方法,在创建一个对象时默认被调用不需要手动调用。__init__(self)中,默认有1个参数名字为self,如果在创建对象时传递了2个实参,那么__init__(self)中出了self作为第一个形参外还需要2个形参,例如__init__(self,x,y),__init__(self)中的self参数,不需要开发者传递,python解释器会自动把当前的对象引用传递进去

        self.crop_size = config['crop_size']#[96, 96, 96]
        self.bound_size = config['bound_size']#12
        self.stride = config['stride']#4
        self.pad_value = config['pad_value']#170

class LabelMapping(object):
    def __init__(self, config, phase):
        self.stride = np.array(config['stride'])#4
        self.num_neg = int(config['num_neg'])#800
        self.th_neg = config['th_neg']#0.02
        self.anchors = np.asarray(config['anchors'])#[5., 10., 20.]
        self.phase = phase
        if phase == 'train':
            self.th_pos = config['th_pos_train']#0.5
        elif phase == 'val':
            self.th_pos = config['th_pos_val']

2.之后就是调用这几个主要类里面的这些def __getitem__、def __len__、def __call__等。反正我觉得很麻烦。

调用发生在了类似for i, (data, target, coord) in enumerate(train_loader):这里所有in enumerate里面。

调用顺序如下代码如下,首先是class DataBowl3Detector(Dataset):里的

    def __getitem__(self, idx,split=None):#我们处理数据时,可能会将数据用类的方式取出来,并且我们在使用这些数据时会使用dataset['key'] 这样的方式取出,此时如果在_getitem_()方法里加入一些需要临时进行的操作(比如图像降采样)可以使数据处理变得灵活。
        t = time.time()
        np.random.seed(int(str(t%1)[2:7]))#seed according to time

        isRandomImg  = False
        if self.phase !='test':
            if idx>=len(self.bboxes):#idx哪里来的是关键
                isRandom = True
                idx = idx%len(self.bboxes)
                isRandomImg = np.random.randint(2)#随机产生0和1
            else:
                isRandom = False
        else:
            isRandom = False
        
        if self.phase != 'test':
            if not isRandomImg:
                bbox = self.bboxes[idx]#self.bboxes第一个参数代表的是第几个CT序列(trainlist排列好的)后面代表的是xyz和直径。当直径大于2.5写1行,大于10写1+2行,大于20写1+2+4行
                filename = self.filenames[int(bbox[0])]#第int(bbox[0])序列的路径(包含名称_clearn.npy)
                imgs = np.load(filename)
                bboxes = self.sample_bboxes[int(bbox[0])]#self.sample_bboxes就是800CT对应的结节的xyz和直径,是一个列表list(可能是空,也可能是多行)
                isScale = self.augtype['scale'] and (self.phase=='train')#True and true
                #imgs是_clearn.npy导入的3d矩阵(仅仅包含肺部并且分辨率为1×1*1mm)/bbox代表[556 74.7895776 249.75681577 267.72357450000004 10.57235285],bbox[1:]代表[74.7895776 249.75681577 267.72357450000004 10.57235285]
                # bboxes代表[[74.7895776 249.75681577 267.72357450000004 10.57235285],[165.8177988 172.30634478 79.42561491000001 20.484109399999998]]等于或包含bbox[1:].
                sample, target, bboxes, coord = self.crop(imgs, bbox[1:], bboxes,isScale,isRandom)
                if self.phase=='train' and not isRandom:
                     sample, target, bboxes, coord = augment(sample, target, bboxes, coord,
                        ifflip = self.augtype['flip'], ifrotate=self.augtype['rotate'], ifswap = self.augtype['swap'])
            else:
                randimid = np.random.randint(len(self.kagglenames))
                filename = self.kagglenames[randimid]
                imgs = np.load(filename)
                bboxes = self.sample_bboxes[randimid]
                isScale = self.augtype['scale'] and (self.phase=='train')
                sample, target, bboxes, coord = self.crop(imgs, [], bboxes,isScale=False,isRand=True)
            #print np.random.randint(2)
            #print sample.shape, target.shape, bboxes.shape
            #print "t:",target
            #print "b:",bboxes
            label = self.label_mapping(sample.shape[1:], target, bboxes, filename)#sample.shape[1:]=: (96, 96, 96)
            sample = (sample.astype(np.float32)-128)/128
            #if filename in self.kagglenames and self.phase=='train':
            #    label[label==-1]=0
            return torch.from_numpy(sample), torch.from_numpy(label), coord
        else:
            imgs = np.load(self.filenames[idx])
            bboxes = self.sample_bboxes[idx]
            nz, nh, nw = imgs.shape[1:]
            pz = int(np.ceil(float(nz) / self.stride)) * self.stride
            ph = int(np.ceil(float(nh) / self.stride)) * self.stride
            pw = int(np.ceil(float(nw) / self.stride)) * self.stride
            imgs = np.pad(imgs, [[0,0],[0, pz - nz], [0, ph - nh], [0, pw - nw]], 'constant',constant_values = self.pad_value)
            
            xx,yy,zz = np.meshgrid(np.linspace(-0.5,0.5,imgs.shape[1]/self.stride),
                                   np.linspace(-0.5,0.5,imgs.shape[2]/self.stride),
                                   np.linspace(-0.5,0.5,imgs.shape[3]/self.stride),indexing ='ij')
            coord = np.concatenate([xx[np.newaxis,...], yy[np.newaxis,...],zz[np.newaxis,:]],0).astype('float32')
            imgs, nzhw = self.split_comber.split(imgs)
            coord2, nzhw2 = self.split_comber.split(coord,
                                                   side_len = self.split_comber.side_len/self.stride,
                                                   max_stride = self.split_comber.max_stride/self.stride,
                                                   margin = self.split_comber.margin/self.stride)
            assert np.all(nzhw==nzhw2)
            imgs = (imgs.astype(np.float32)-128)/128
            return torch.from_numpy(imgs), bboxes, torch.from_numpy(coord2), np.array(nzhw)

上面比较重要的几个点我都做了注释。上很重要的就是这两句:

sample, target, bboxes, coord = self.crop(imgs, [], bboxes,isScale=False,isRand=True)
label = self.label_mapping(sample.shape[1:], target, bboxes, filename)#sample.shape[1:]=: (96, 96, 96)

而1中初始化定义过了self.crop = Crop(config)与self.label_mapping = LabelMapping(config, self.phase)故分别一次调用了class Crop(object):中的:

    def __call__(self, imgs, target, bboxes,isScale=False,isRand=False):#Python中,如果在创建class的时候写了call()方法, 那么该class实例化出实例后, 实例名()就是调用call()方法。
        if isScale:
            radiusLim = [8.,120.]
            scaleLim = [0.75,1.25]
            scaleRange = [np.min([np.max([(radiusLim[0]/target[3]),scaleLim[0]]),1])
                         ,np.max([np.min([(radiusLim[1]/target[3]),scaleLim[1]]),1])]#target代表bbox[1:]代表[74.7895776 249.75681577 267.72357450000004 10.57235285]
            scale = np.random.rand()*(scaleRange[1]-scaleRange[0])+scaleRange[0]#np.random.rand(d0, d1, …, dn)的随机样本位于[0, 1)中.
            crop_size = (np.array(self.crop_size).astype('float')/scale).astype('int')#根据实际结节直径大小调整crop_size大小,target[3]小crop_size变大.target[3]大crop_size变小
            #print(crop_size)
        else:
            crop_size=self.crop_size
        bound_size = self.bound_size
        target = np.copy(target)#目标结节
        bboxes = np.copy(bboxes)#这个ct含有的所以结节
        
        start = []
        for i in range(3):
            if not isRand:
                r = target[3] / 2
                s = np.floor(target[i] - r)+ 1 - bound_size
                e = np.ceil (target[i] + r)+ 1 + bound_size - crop_size[i] 
            else:
                s = np.max([imgs.shape[i+1]-crop_size[i]/2,imgs.shape[i+1]/2+bound_size])
                e = np.min([crop_size[i]/2,              imgs.shape[i+1]/2-bound_size])
                target = np.array([np.nan,np.nan,np.nan,np.nan])
            if s>e:
                start.append(np.random.randint(e,s))#产生一个e到s之间的随机数
            else:
                start.append(int(target[i])-crop_size[i]/2+np.random.randint(-bound_size/2,bound_size/2))#求取结节的3d立方矩阵最靠近原点的点
                
                
        normstart = np.array(start).astype('float32')/np.array(imgs.shape[1:])-0.5#将normstart放到3d块[-0.5:0.5,-0.5:0.5,-0.5:0.5]对应实际[-imgs.shape[1]/2:imgs.shape[1]/2,-imgs.shape[2]/2:imgs.shape[2]/2,-imgs.shape[1]/2:imgs.shape[1]/2,-imgs.shape[3]/2:imgs.shape[3]/2]
        normsize = np.array(crop_size).astype('float32')/np.array(imgs.shape[1:])
        xx,yy,zz = np.meshgrid(np.linspace(normstart[0],normstart[0]+normsize[0],self.crop_size[0]/self.stride),#np.linspace创建等差数列
                           np.linspace(normstart[1],normstart[1]+normsize[1],self.crop_size[1]/self.stride),
                           np.linspace(normstart[2],normstart[2]+normsize[2],self.crop_size[2]/self.stride),indexing ='ij')#可以这么理解,meshgrid函数用两个坐标轴上的点在平面上画网格(3d也是如此).
        coord = np.concatenate([xx[np.newaxis,...], yy[np.newaxis,...],zz[np.newaxis,:]],0).astype('float32')

        pad = []
        pad.append([0,0])
        for i in range(3):
            leftpad = max(0,-start[i])
            rightpad = max(0,start[i]+crop_size[i]-imgs.shape[i+1])
            pad.append([leftpad,rightpad])
        crop = imgs[:,
            max(start[0],0):min(start[0] + crop_size[0],imgs.shape[1]),
            max(start[1],0):min(start[1] + crop_size[1],imgs.shape[2]),
            max(start[2],0):min(start[2] + crop_size[2],imgs.shape[3])]#以结节位置为中心来截取
        crop = np.pad(crop,pad,'constant',constant_values =self.pad_value)#越界进行填充
        for i in range(3):
            target[i] = target[i] - start[i] #目标结节位置已经减去目标结节3d立方矩阵最靠近原点的开始点
        for i in range(len(bboxes)):
            for j in range(3):
                bboxes[i][j] = bboxes[i][j] - start[j] #同一ct的所有结节位置已经减去该目标结节3d立方矩阵最靠近原点的开始点
                
        if isScale:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                crop = zoom(crop,[1,scale,scale,scale],order=1)
            newpad = self.crop_size[0]-crop.shape[1:][0]
            if newpad<0:
                crop = crop[:,:-newpad,:-newpad,:-newpad]#A[:3]就是A[0:3](前闭后开)
            elif newpad>0:
                pad2 = [[0,0],[0,newpad],[0,newpad],[0,newpad]]#补充到1×self.crop_size
                crop = np.pad(crop, pad2, 'constant', constant_values=self.pad_value)
            for i in range(4):
                target[i] = target[i]*scale
            for i in range(len(bboxes)):
                for j in range(4):
                    bboxes[i][j] = bboxes[i][j]*scale
        return crop, target, bboxes, coord#1*96*96*96/可能是[nan,nan,nan,nan]/可能是空/3*24*24*24

和class LabelMapping(object):中的:

    def __call__(self, input_size, target, bboxes, filename):
        stride = self.stride#4
        num_neg = self.num_neg
        th_neg = self.th_neg
        anchors = self.anchors
        th_pos = self.th_pos
        
        output_size = []
        for i in range(3):
            if input_size[i] % stride != 0:
                print filename
            # assert(input_size[i] % stride == 0) 
            output_size.append(input_size[i] / stride)
        
        label = -1 * np.ones(output_size + [len(anchors), 5], np.float32)#: (24, 24, 24, 3, 5)(全是-1的多维矩阵)
        offset = ((stride.astype('float')) - 1) / 2#1.5
        oz = np.arange(offset, offset + stride * (output_size[0] - 1) + 1, stride)#oz = np.arange(1.5,1.5+4*(24 -1)+1,4)#最后的加1只是为了取到93.5这个值.不然一般状态都是前闭后开的状态
        oh = np.arange(offset, offset + stride * (output_size[1] - 1) + 1, stride)
        ow = np.arange(offset, offset + stride * (output_size[2] - 1) + 1, stride)

        for bbox in bboxes:
            for i, anchor in enumerate(anchors):
                iz, ih, iw = select_samples(bbox, anchor, th_neg, oz, oh, ow)#选择iou大于0.02的框
                label[iz, ih, iw, i, 0] = 0#把相关类别标签标签你从-1改为0(最后那维0是类别)最后的3*5相当于(当i=0时,  [[ 0. -1. -1. -1. -1.],[-1. -1. -1. -1. -1.],[-1. -1. -1. -1. -1.]])可以理解为24*24*24个3*5的矩阵罗列起来,不同的只是外面包的中括号层数的不同

        if self.phase == 'train' and self.num_neg > 0:
            neg_z, neg_h, neg_w, neg_a = np.where(label[:, :, :, :, 0] == -1)
            neg_idcs = random.sample(range(len(neg_z)), min(num_neg, len(neg_z)))#slice = random.sample(list, 5)从list中随机获取5个元素,作为一个片断返回
            neg_z, neg_h, neg_w, neg_a = neg_z[neg_idcs], neg_h[neg_idcs], neg_w[neg_idcs], neg_a[neg_idcs]
            label[:, :, :, :, 0] = 0#类别全改为0
            label[neg_z, neg_h, neg_w, neg_a, 0] = -1#产生800个负标签

        if np.isnan(target[0]):
            return label
        iz, ih, iw, ia = [], [], [], []
        for i, anchor in enumerate(anchors):
            iiz, iih, iiw = select_samples(target, anchor, th_pos, oz, oh, ow)#选择iou大于0.5的框
            iz.append(iiz)#可能是[[1, 2, 3], [2, 3, 5, 6]]
            ih.append(iih)
            iw.append(iiw)
            ia.append(i * np.ones((len(iiz),), np.int64))
        iz = np.concatenate(iz, 0)#最后[1 2 3 2 3 5 6]
        ih = np.concatenate(ih, 0)
        iw = np.concatenate(iw, 0)
        ia = np.concatenate(ia, 0)
        flag = True 
        if len(iz) == 0:
            pos = []
            for i in range(3):
                pos.append(max(0, int(np.round((target[i] - offset) / stride))))
            idx = np.argmin(np.abs(np.log(target[3] / anchors)))
            pos.append(idx)
            flag = False
        else:
            idx = random.sample(range(len(iz)), 1)[0]#从中随机的选取一个
            pos = [iz[idx], ih[idx], iw[idx], ia[idx]]
        dz = (target[0] - oz[pos[0]]) / anchors[pos[3]]
        dh = (target[1] - oh[pos[1]]) / anchors[pos[3]]
        dw = (target[2] - ow[pos[2]]) / anchors[pos[3]]
        dd = np.log(target[3] / anchors[pos[3]])
        label[pos[0], pos[1], pos[2], pos[3], :] = [1, dz, dh, dw, dd]#产生一个正标签
        return label        

这其中我在把几个用到的函数贴下:

def augment(sample, target, bboxes, coord, ifflip = True, ifrotate=True, ifswap = True):
    #                     angle1 = np.random.rand()*180
    if ifrotate:
        validrot = False
        counter = 0
        while not validrot:
            newtarget = np.copy(target)
            angle1 = np.random.rand()*180
            size = np.array(sample.shape[2:4]).astype('float')
            rotmat = np.array([[np.cos(angle1/180*np.pi),-np.sin(angle1/180*np.pi)],[np.sin(angle1/180*np.pi),np.cos(angle1/180*np.pi)]])
            newtarget[1:3] = np.dot(rotmat,target[1:3]-size/2)+size/2
            if np.all(newtarget[:3]>target[3]) and np.all(newtarget[:3]< np.array(sample.shape[1:4])-newtarget[3]):
                validrot = True
                target = newtarget
                sample = rotate(sample,angle1,axes=(2,3),reshape=False)
                coord = rotate(coord,angle1,axes=(2,3),reshape=False)
                for box in bboxes:
                    box[1:3] = np.dot(rotmat,box[1:3]-size/2)+size/2
            else:
                counter += 1
                if counter ==3:
                    break
    if ifswap:
        if sample.shape[1]==sample.shape[2] and sample.shape[1]==sample.shape[3]:
            axisorder = np.random.permutation(3)
            sample = np.transpose(sample,np.concatenate([[0],axisorder+1]))
            coord = np.transpose(coord,np.concatenate([[0],axisorder+1]))
            target[:3] = target[:3][axisorder]
            bboxes[:,:3] = bboxes[:,:3][:,axisorder]
            
    if ifflip:
#         flipid = np.array([np.random.randint(2),np.random.randint(2),np.random.randint(2)])*2-1
        flipid = np.array([1,np.random.randint(2),np.random.randint(2)])*2-1
        sample = np.ascontiguousarray(sample[:,::flipid[0],::flipid[1],::flipid[2]])
        coord = np.ascontiguousarray(coord[:,::flipid[0],::flipid[1],::flipid[2]])
        for ax in range(3):
            if flipid[ax]==-1:
                target[ax] = np.array(sample.shape[ax+1])-target[ax]
                bboxes[:,ax]= np.array(sample.shape[ax+1])-bboxes[:,ax]
    return sample, target, bboxes, coord 
def select_samples(bbox, anchor, th, oz, oh, ow):
    z, h, w, d = bbox
    max_overlap = min(d, anchor)
    min_overlap = np.power(max(d, anchor), 3) * th / max_overlap / max_overlap#这句理解不了
    if min_overlap > max_overlap:
        return np.zeros((0,), np.int64), np.zeros((0,), np.int64), np.zeros((0,), np.int64)
    else:
        s = z - 0.5 * np.abs(d - anchor) - (max_overlap - min_overlap)
        e = z + 0.5 * np.abs(d - anchor) + (max_overlap - min_overlap)
        mz = np.logical_and(oz >= s, oz <= e)#逻辑与: (24,)
        iz = np.where(mz)[0]#坐标
        
        s = h - 0.5 * np.abs(d - anchor) - (max_overlap - min_overlap)
        e = h + 0.5 * np.abs(d - anchor) + (max_overlap - min_overlap)
        mh = np.logical_and(oh >= s, oh <= e)
        ih = np.where(mh)[0]
            
        s = w - 0.5 * np.abs(d - anchor) - (max_overlap - min_overlap)
        e = w + 0.5 * np.abs(d - anchor) + (max_overlap - min_overlap)
        mw = np.logical_and(ow >= s, ow <= e)
        iw = np.where(mw)[0]#ow中大于等于s,且小于等于e的数的位置(0开始)

        if len(iz) == 0 or len(ih) == 0 or len(iw) == 0:
            return np.zeros((0,), np.int64), np.zeros((0,), np.int64), np.zeros((0,), np.int64)
        #为了遍历所有中心点xyz的所有可能
        lz, lh, lw = len(iz), len(ih), len(iw)
        iz = iz.reshape((-1, 1, 1))#: (6, 1, 1)
        ih = ih.reshape((1, -1, 1))#: (1, 7, 1)
        iw = iw.reshape((1, 1, -1))#: (1, 1, 7)
        iz = np.tile(iz, (1, lh, lw)).reshape((-1))#Numpy的 tile() 函数,就是将原矩阵横向、纵向地复制。tile 是瓷砖的意思,顾名思义,这个函数就是把数组像瓷砖一样铺展开来。
        ih = np.tile(ih, (lz, 1, lw)).reshape((-1))#: (294,)(6*7*7)
        iw = np.tile(iw, (lz, lh, 1)).reshape((-1))#: (294,)
        # tt1=oz[iz]
        # tt2=oh[ih]
        # tt3=ow[iw]
        centers = np.concatenate([
            oz[iz].reshape((-1, 1)),
            oh[ih].reshape((-1, 1)),
            ow[iw].reshape((-1, 1))], axis = 1)#中心z,y,w 的排列组合(每一行都是一种组合)
        
        r0 = anchor / 2
        s0 = centers - r0
        e0 = centers + r0
        
        r1 = d / 2
        s1 = bbox[:3] - r1
        s1 = s1.reshape((1, -1))
        e1 = bbox[:3] + r1
        e1 = e1.reshape((1, -1))
        
        overlap = np.maximum(0, np.minimum(e0, e1) - np.maximum(s0, s1))#np.maximum:(X, Y, out=None) X 与 Y 逐位比较取其大者;最少接收两个参数
        #
        intersection = overlap[:, 0] * overlap[:, 1] * overlap[:, 2]#两个3d矩阵的交集
        union = anchor * anchor * anchor + d * d * d - intersection#两个3d矩阵的并集

        iou = intersection / union

        mask = iou >= th
        #if th > 0.4:
         #   if np.sum(mask) == 0:
          #      print(['iou not large', iou.max()])
           # else:
            #    print(['iou large', iou[mask]])
        iz = iz[mask]
        ih = ih[mask]
        iw = iw[mask]
        return iz, ih, iw
def collate(batch):
    if torch.is_tensor(batch[0]):
        return [b.unsqueeze(0) for b in batch]
    elif isinstance(batch[0], np.ndarray):
        return batch
    elif isinstance(batch[0], int):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], collections.Iterable):
        transposed = zip(*batch)
        return [collate(samples) for samples in transposed]

3.前面只是大概把逻辑理清楚了。现在我再把这部分到底干了什么说下(当然是我认为的)。

1.前面预处理中已经把整个肺的在三维中切出来了。但切出来的只是一个长方体。补成正方体(补的像素值为170),在一定缩比例内缩放与翻转(扩充数据),最后在裁剪为96*96*96.(上面这几步顺序可能有问题,但总体是这个意思,最后还是能还原回去的)。

2.其实这里最重要的是制作标签。制作了一个24*24*24*3*5的标签。一起看你可能理解不了把这拆开看就ok了。首先24*24*24看成一个立方体,这个立方体的每个像素代表这个像素返回输入那个96*96*96中对应4*4*4的中心位置,3代表3个直径尺度,5代表4+1,4代表xyzd的精确修正值,1代表用来判断前景还是背景(是否是肺结节)。中间一段就是根据实际结节做的标签。

3.里面还有一点困惑就是coord这个值。据说是含有位置信息,把它融合进特征图里的。

4.这个方法好的地方和不好的地方。

好:这个方法可以使运算时候占用的显存大大减小。速度变快很多(1080Ti和1080双显卡(并行运算15G多,还有3G空着))6分钟左右一个eps。

不好:输入尺度太小导致精度不好。多尺度考虑的过少只考虑config['anchors'] = [5., 10., 20.]三种尺度还是不够的。还有很多参数论文里没有提及设置的原因,难以理解。

5.其实这里还有一个问题困扰了我好久,就是这个模型到底是不是3d版的faster-rcnn。

其实答案可以说它是也可以说他不是(faster-rcnn还是重要啊,不理解就会对后面的各种方法理解造成困难,我说的理解faster-rcnn不仅仅从原理上,最重要的是从实际代码上,画出具体到每一步的框图)。说它不是因为这是一个一步检测模型(如ssd等,模型出来一个24*24*24*15的结果,标签也是24*24*24*15。),说它是也可以理解因为这是一个二分类问题所以没有必要再加一个前景分类器。

6.最后是如果你有多gpu最好去看DSB2017第一名的代码和论文。deeplung代码里很多参数的设置可以得到解释。

希望我的理解对你有帮助,祝大家好运!!

目前正在奋战DSB2017第一名的代码和论文对照天池一些队伍的思路修改。

 

你可能感兴趣的:(深度学习)