四、肺癌检测-数据集准备 dsets.py文件

一、目标

数据集准备需要完成以下几个工作:

1. 读取annotations.csv内容;

2. 读取candidates.csv内容;

3. 构造Ct类,用于根据输入的series_uid,获取该uid的CT数据的信息。

4. 构造Dataset类,用于加载数据集。

二、要点说明

1. SimpleITK库

读取和解析CT结果的【mhd】文件需要使用SimpleITK库,可通过【conda install simpleitk】命令安装。

其中主要用到以下几个函数说明如下:

# 读取mhd格式文件,并返回一个mhd对象。
ct_mhd = SimpleITK.ReadImage(path)

# 获取ct_mhd对象的XYZ坐标相对于IRC坐标的原点偏移,类型为1x3数组。
origin_xyz = ct_mhd.GetOrigin()


# 获取ct_mhd对象每个体素在xyz坐标轴的大小,用于转换为IRC坐标时进行尺度缩放。类型为1x3数组
vxSize_xyz = ct_mhd.GetSpacing()


# 获取ct_mhd对象从XYZ转换为IRC坐标时的空间转换矩阵,类型为3x3的eye数组
direction_a = ct_mhd.GetDirection()).reshape(3, 3)

2. functools库

代码中用到了functools库,用于将某些函数的结果缓存到内存中。

@functools.lru_cache(1):代表1次缓存。用于存放在需要缓存的函数定义的代码的开头。意义是:如果该函数之前已经输入过相同的参数,下一次再输入相同参数时,函数直接从缓存调用结果,而不会从新执行函数内部代码。

3. diskcache库

代码中用到了diskcache库,用于将CT数据解析后缓存到磁盘中,使用缓存可以较大的提高训练时数据加载速度。库的使用可参考相关文章:

【编程】Python : diskcache 本地缓存持久化,一行代码_哔哩哔哩_bilibili

 

Python 爬虫进阶篇——diskcache缓存_十先生(公众号:Python知识学堂)的博客-CSDN博客_diskcache python

4. CT文件信息

4.1 csv文件

annotations.csv: 记录实际结节。文件结构: uid, x, y, z, diameter
candidates.csv: 记录候选结节。文件结构: uid, x, y, z, class

注意:两个文件中,相同的uid对应的xyz坐标可能有偏差,要将偏差大于半径的一半(即diameter/4)的数据的diameter强制为0,即认为这个结节异常,不处理。

5. XYZ、IRC坐标轴

5.1 坐标轴方向

CT数据中,有XYZ坐标轴,训练时需要转换为IRC坐标轴,两个坐标轴分别对应着:

xyz:各坐标轴正的方向指向的人体的方向为为:
x:左手,y:后背,z:头顶

irc:各坐标轴正的方向指向的人体的方向为为:
i:头顶,r:后背,c:左手

其中i-index,r-row,  c-column

简记为:xyz-左后上,irc-上后左


5.2 坐标轴转换

5.2.1 irc转xyz

step1:将irc矩阵翻转为cri

step2:用体素大小缩放cri坐标

step3:缩放后的cri坐标与空间矩阵叉乘得到xyz坐标

step4:xyz坐标加上原点偏移量。

def irc2xyz(coord_irc, origin_xyz, vxSize_xyz, direction_a):
    """
    irc坐标转为xyz坐标
    step1:将irc矩阵翻转为cri
    step2:用体素大小缩放cri坐标
    step3:缩放后的cri坐标与空间矩阵叉乘得到xyz坐标
    step4:xyz坐标加上原点偏移量。
    
    :param coord_irc: irc坐标
    :param origin_xyz: irc坐标相对于xyz的坐标偏移
    :param vxSize_xyz: 体素在xyz尺度的大小
    :param direction_a: 空间矩阵
    :return: 
    """
    cri_a = np.array(coord_irc)[::-1]
    origin_a = np.array(origin_xyz)
    vxSize_a = np.array(vxSize_xyz)
    coords_xyz = (direction_a @ (cri_a * vxSize_a)) + origin_a
    # coords_xyz = (direction_a @ (idx * vxSize_a)) + origin_a
    return XyzTuple(*coords_xyz)

5.2.2 xyz转irc 

def xyz2irc(coord_xyz, origin_xyz, vxSize_xyz, direction_a):
    origin_a = np.array(origin_xyz)
    vxSize_a = np.array(vxSize_xyz)
    coord_a = np.array(coord_xyz)
    cri_a = ((coord_a - origin_a) @ np.linalg.inv(direction_a)) / vxSize_a
    cri_a = np.round(cri_a)
    return IrcTuple(int(cri_a[2]), int(cri_a[1]), int(cri_a[0]))

5.3 CT数据单位 

CT文件中数据单位为HU(HounsField Units,亨氏单位)。其中人体各组织的HU值水平为:

空气:-1000HU,约0g/cm3

水:0HU,约1g/cm3

骨骼:1000HU,约2~3g/cm3。

因此超出-1000HU到1000HU外的数据并不是我们需要关心的数据,可强制转换为限值。

5.4 体素、结节概念

体素:可理解为CT扫描后得到的三维切片矩阵中所对应的一个点(像素),即切片后最小的人体组织,接三维的立体像素。

结节:可能为恶性也可能是良性,CT扫描后可根据体素的尺寸,结节中心坐标,结节直径截取出结节所对应的坐标值已经HU值。

6. 数据可视化

下图第一行是对CT文件中,三维CT矩阵用不同维度索引下的结果;

下图第二行是对某个结节中,三维结节矩阵用不同维度索引下的结果。

更多可视化内容可参照原书代码的ipynb文件。

四、肺癌检测-数据集准备 dsets.py文件_第1张图片

三、函数说明

1. getCandidateInfoList

candidateInfo_list = getCandidateInfoList(requireOnDisk_bool=True)
返回candidates.csv文件对应的list,其中每个元素为名称为candidateInfoTuple的元组,元组有如下节点:
class, diameter, id, xyz

2. Ct类

属性如下:

CT.hu_a:以HU为单位的三维array,存储的是CT的所有体素数据。

CT.origin_xyz:xyz坐标和irc坐标的原点偏移量

CT.vzSize_xyz:体素在xyz坐标轴的尺度大小

CT.direction_a:体素的空间矩阵

CT.getRawCandidate函数

ct_chunk, center_irc = getRawCandidate(center_xyz, width_irc)

center_xyz:结节在xyz坐标系的坐标值。

width_irc:体素在irc坐标系的尺度大小。

ct_chunk:结节在irc坐标轴的HU值的三维矩阵。

center_irc:结节中心在irc坐标系的坐标值。

3. LunaDataset类

ds = LunaDataset(val_stride=0, isValSet_bool=False, series_uid=None)

val_stride:作为验证集时,从数据集中抽取样本作为验证集样本的步长。即每隔val_stride抽取一个样本作为验证集样本。

isValSet_bool:是否作为验证集。

series_uid:获取某个uid对应的所有样本。

四、代码

1. 原书代码

书中代码【dsets.py】如下:

import copy
import csv
import functools
import glob
import os

from collections import namedtuple

import SimpleITK as sitk
import numpy as np

import torch
import torch.cuda
from torch.utils.data import Dataset

from util.disk import getCache
from util.util import XyzTuple, xyz2irc
from util.logconf import logging

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

raw_cache = getCache('part2ch10_raw')

CandidateInfoTuple = namedtuple(
    'CandidateInfoTuple',
    'isNodule_bool, diameter_mm, series_uid, center_xyz',
)

@functools.lru_cache(1)
def getCandidateInfoList(requireOnDisk_bool=True):
    # We construct a set with all series_uids that are present on disk.
    # This will let us use the data, even if we haven't downloaded all of
    # the subsets yet.
    mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}

    diameter_dict = {}
    with open('data/part2/luna/annotations.csv', "r") as f:
        for row in list(csv.reader(f))[1:]:
            series_uid = row[0]
            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
            annotationDiameter_mm = float(row[4])

            diameter_dict.setdefault(series_uid, []).append(
                (annotationCenter_xyz, annotationDiameter_mm)
            )

    candidateInfo_list = []
    with open('data/part2/luna/candidates.csv', "r") as f:
        for row in list(csv.reader(f))[1:]:
            series_uid = row[0]

            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
                continue

            isNodule_bool = bool(int(row[4]))
            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])

            candidateDiameter_mm = 0.0
            for annotation_tup in diameter_dict.get(series_uid, []):
                annotationCenter_xyz, annotationDiameter_mm = annotation_tup
                for i in range(3):
                    delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
                    if delta_mm > annotationDiameter_mm / 4:
                        break
                else:
                    candidateDiameter_mm = annotationDiameter_mm
                    break

            candidateInfo_list.append(CandidateInfoTuple(
                isNodule_bool,
                candidateDiameter_mm,
                series_uid,
                candidateCenter_xyz,
            ))

    candidateInfo_list.sort(reverse=True)
    return candidateInfo_list

class Ct:
    def __init__(self, series_uid):
        mhd_path = glob.glob(
            'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid)
        )[0]

        ct_mhd = sitk.ReadImage(mhd_path)
        ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
        # The lower bound gets rid of negative density stuff used to indicate out-of-FOV
        # The upper bound nukes any weird hotspots and clamps bone down
        ct_a.clip(-1000, 1000, ct_a)

        self.series_uid = series_uid
        self.hu_a = ct_a

        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)

    def getRawCandidate(self, center_xyz, width_irc):
        center_irc = xyz2irc(
            center_xyz,
            self.origin_xyz,
            self.vxSize_xyz,
            self.direction_a,
        )

        slice_list = []
        for axis, center_val in enumerate(center_irc):
            start_ndx = int(round(center_val - width_irc[axis]/2))
            end_ndx = int(start_ndx + width_irc[axis])

            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])

            if start_ndx < 0:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                start_ndx = 0
                end_ndx = int(width_irc[axis])

            if end_ndx > self.hu_a.shape[axis]:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                end_ndx = self.hu_a.shape[axis]
                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])

            slice_list.append(slice(start_ndx, end_ndx))

        ct_chunk = self.hu_a[tuple(slice_list)]

        return ct_chunk, center_irc


@functools.lru_cache(1, typed=True)
def getCt(series_uid):
    return Ct(series_uid)

@raw_cache.memoize(typed=True)
def getCtRawCandidate(series_uid, center_xyz, width_irc):
    ct = getCt(series_uid)
    ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
    return ct_chunk, center_irc

class LunaDataset(Dataset):
    def __init__(self,
                 val_stride=0,
                 isValSet_bool=None,
                 series_uid=None,
            ):
        self.candidateInfo_list = copy.copy(getCandidateInfoList())

        if series_uid:
            self.candidateInfo_list = [
                x for x in self.candidateInfo_list if x.series_uid == series_uid
            ]

        if isValSet_bool:
            assert val_stride > 0, val_stride
            self.candidateInfo_list = self.candidateInfo_list[::val_stride]
            assert self.candidateInfo_list
        elif val_stride > 0:
            del self.candidateInfo_list[::val_stride]
            assert self.candidateInfo_list

        log.info("{!r}: {} {} samples".format(
            self,
            len(self.candidateInfo_list),
            "validation" if isValSet_bool else "training",
        ))

    def __len__(self):
        return len(self.candidateInfo_list)

    def __getitem__(self, ndx):
        candidateInfo_tup = self.candidateInfo_list[ndx]
        width_irc = (32, 48, 48)

        candidate_a, center_irc = getCtRawCandidate(
            candidateInfo_tup.series_uid,
            candidateInfo_tup.center_xyz,
            width_irc,
        )

        candidate_t = torch.from_numpy(candidate_a)
        candidate_t = candidate_t.to(torch.float32)
        candidate_t = candidate_t.unsqueeze(0)

        pos_t = torch.tensor([
                not candidateInfo_tup.isNodule_bool,
                candidateInfo_tup.isNodule_bool
            ],
            dtype=torch.long,
        )

        return (
            candidate_t,
            pos_t,
            candidateInfo_tup.series_uid,
            torch.tensor(center_irc),
        )

2. 我注释的代码

import functools
import glob
import os.path
import csv
import SimpleITK as sitk
import numpy as np
import copy

import torch
import torch.cuda
from torch.utils.data import Dataset

from collections import namedtuple

from util.disk import getCache
from util.util import XyzTuple, xyz2irc
from util.logconf import logging

log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

# annotations.csv: 记录实际结节。文件结构: uid, x, y, z, diameter
# candidates.csv: 记录候选结节。文件结构: uid, x, y, z, class

raw_cache = getCache('part2ch10_raw')


# 构建用于存储候选结节的元组, 结构: class, diameter, id, xyz
candidateInfoTuple = namedtuple('candidateInfoTuple',
                                'isNodule_bool, diameter_mm, series_uid, center_xyz')

@functools.lru_cache(1)     # 缓存一次调用结果
def getCandidateInfoList(requireOnDisk_bool=True):
    """
    加载annotations.csv和candidates.csv,分别存到diameter_list和candidateInfo_list
    :param      requireOnDisk_bool. 如果文件不存在,是否跳过
    :return     candidateInfo_list. 由candidateInfoTuple构成的list
    """
    mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}        # 提取所有文件名,即uid

    diameter_dict= {}
    with open('data/part2/luna/annotations.csv', 'r') as f:
        for row in list(csv.reader(f))[1:]:
            series_uid = row[0]
            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
            annotationDiameter_mm = float(row[4])

            diameter_dict.setdefault(series_uid, []).append(
                (annotationCenter_xyz, annotationDiameter_mm)
            )

    candidateInfo_list = []
    with open('data/part2/luna/candidates.csv', 'r') as f:
        for row in list(csv.reader(f))[1:]:
            series_uid = row[0]

            # 如果annotations.csv中找不到这个id,则跳过
            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
                continue

            candidateDiameter_xyz = tuple([float(x) for x in row[1:4]])
            isNodule_bool = bool(int(row[4]))

            # 如果candidate中的xyz坐标和annotation中的xyz坐标偏差大于半径的一半,
            # 则认为它们不是同一个节点,将直接用零代替,即认为这不是结节
            candidateDiameter_mm = 0.0
            for annotation_tup in diameter_dict.get(series_uid, []):
                annotation_xyz, annotationDiameter_mm = annotation_tup
                for i in range(3):
                    delta_mm = abs(candidateDiameter_xyz[i] - annotation_xyz[i])
                    if delta_mm > annotationDiameter_mm/4:
                        break
                    else:
                        candidateDiameter_mm = annotationDiameter_mm
                        break

            candidateInfo_list.append(candidateInfoTuple(
                isNodule_bool,
                candidateDiameter_mm,
                series_uid,
                candidateDiameter_xyz,
            ))

    candidateInfo_list.sort(reverse=True)
    return candidateInfo_list


class Ct:
    def __init__(self, series_uid):
        mhd_path = glob.glob(r'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]

        # 用SampleSTK包可直接读取CT扫描数据
        ct_mhd = sitk.ReadImage(mhd_path)

        # HU: 亨氏单位,Hounsfield Unit.
        # 空气为-1000 HU,约等于0 g/cm3. 水为0 HU,约等于1 g/cm3, 骨骼至少时1000HU,约等于2~3g/cm3
        ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)   # 读取到的数据单位为HU
        # 将数据限定再-1000~1000 HU
        ct_a.clip(-1000, 1000, ct_a)
        self.series_uid = series_uid
        self.hu_a = ct_a

        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())     # xyz坐标和irc坐标的原点偏移量
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())    # 体素在xyz坐标轴的大小
        self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)    # 体素方向矩阵,等于eye(3)

    def getRawCandidate(self, center_xyz, width_irc):
        """
        根据xyz坐标算出病人坐标irc。然后根据每个结节的irc和体素宽度,算出结节包含的体素块数据
        :param center_xyz: 结节的xyz坐标
        :param width_irc: 体素宽度
        :return ct_chunk: 结节包含的体素块的HU值,array
        :return center_irc: 结节的病人坐标信息
        """
        center_irc = xyz2irc(
            center_xyz,
            self.origin_xyz,
            self.vxSize_xyz,
            self.direction_a
        )

        slice_list = []
        for axis, center_val in enumerate(center_irc):
            start_ndx = int(round(center_val - width_irc[axis]/2))
            end_ndx = int(start_ndx + width_irc[axis])

            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])

            if start_ndx < 0:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                start_ndx = 0
                end_ndx = int(width_irc[axis])

            if end_ndx > self.hu_a.shape[axis]:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                end_ndx = self.hu_a.shape[axis]
                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])

            slice_list.append(slice(start_ndx, end_ndx))

        ct_chunk = self.hu_a[tuple(slice_list)]

        return ct_chunk, center_irc


@functools.lru_cache(1, typed=True)     # 保留一次缓存结果
def getCt(series_uid):
    return Ct(series_uid)


@raw_cache.memoize(typed=True)      # 数据缓存到同路径的cache文件夹下
def getCtRawCandidate(series_uid, center_xyz, width_irc):
    ct = getCt(series_uid)
    ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
    return ct_chunk, center_irc


class LunaDataset(Dataset):
    def __init__(self, val_stride=0, isValSet_bool=False, series_uid=None):
        """
        val_stride:作为验证集时,从数据集中抽取样本作为验证集样本的步长。即每隔val_stride抽取一个样本作为验证集样本。
        isValSet_bool:是否作为验证集。
        series_uid:获取某个uid对应的所有样本。   
        """
        self.candidateInfo_list = copy.copy(getCandidateInfoList())

        if series_uid:
            self.candidateInfo_list = [x for x in self.candidateInfo_list if x.series_uid==series_uid]

        if isValSet_bool:
            assert val_stride > 0, val_stride
            self.candidateInfo_list = self.candidateInfo_list[::val_stride]
            assert self.candidateInfo_list
        elif val_stride > 0:
            del self.candidateInfo_list[::val_stride]
            assert self.candidateInfo_list

        log.info("(!r): {} {} samples".format(
            self,
            len(self.candidateInfo_list),
            "validation" if isValSet_bool else "training",
        ))

    def __len__(self):
        return len(self.candidateInfo_list)

    def __getitem__(self, ndx):
        """
        返回指定索引对应的结节信息
        :param ndx: 某个ct数据中的第ndx个结节索引
        :return: candidate_t. 结节所包含的所有体素的三位数组。t代表数组时个tensor
        :return: post_t. 结节是否为肿瘤。0代表不是,1代表肿瘤。
        :return: series_uid. ndx所对应的结节uid
        :return: center_irc. 结节的重心坐标。类型为tensor
        """
        candidateInfo_tup = self.candidateInfo_list[ndx]
        width_irc = (32, 48, 48)

        candidate_a, center_irc = getCtRawCandidate(
            candidateInfo_tup.series_uid,
            candidateInfo_tup.center_xyz,
            width_irc,
        )

        candidate_t = torch.from_numpy(candidate_a)
        candidate_t = candidate_t.to(torch.float32)
        candidate_t = candidate_t.unsqueeze(0)

        post_t = torch.tensor([
            not candidateInfo_tup.isNodule_bool,
            candidateInfo_tup.isNodule_bool
            ],
            dtype=torch.long,
        )

        return (
            candidate_t,
            post_t,
            candidateInfo_tup.series_uid,
            torch.tensor(center_irc)
        )

你可能感兴趣的:(#,肺癌检测,深度学习,python,人工智能)