数据集准备需要完成以下几个工作:
1. 读取annotations.csv内容;
2. 读取candidates.csv内容;
3. 构造Ct类,用于根据输入的series_uid,获取该uid的CT数据的信息。
4. 构造Dataset类,用于加载数据集。
读取和解析CT结果的【mhd】文件需要使用SimpleITK库,可通过【conda install simpleitk】命令安装。
其中主要用到以下几个函数说明如下:
# 读取mhd格式文件,并返回一个mhd对象。
ct_mhd = SimpleITK.ReadImage(path)
# 获取ct_mhd对象的XYZ坐标相对于IRC坐标的原点偏移,类型为1x3数组。
origin_xyz = ct_mhd.GetOrigin()
# 获取ct_mhd对象每个体素在xyz坐标轴的大小,用于转换为IRC坐标时进行尺度缩放。类型为1x3数组
vxSize_xyz = ct_mhd.GetSpacing()
# 获取ct_mhd对象从XYZ转换为IRC坐标时的空间转换矩阵,类型为3x3的eye数组
direction_a = ct_mhd.GetDirection()).reshape(3, 3)
代码中用到了functools库,用于将某些函数的结果缓存到内存中。
@functools.lru_cache(1):代表1次缓存。用于存放在需要缓存的函数定义的代码的开头。意义是:如果该函数之前已经输入过相同的参数,下一次再输入相同参数时,函数直接从缓存调用结果,而不会从新执行函数内部代码。
代码中用到了diskcache库,用于将CT数据解析后缓存到磁盘中,使用缓存可以较大的提高训练时数据加载速度。库的使用可参考相关文章:
【编程】Python : diskcache 本地缓存持久化,一行代码_哔哩哔哩_bilibili
Python 爬虫进阶篇——diskcache缓存_十先生(公众号:Python知识学堂)的博客-CSDN博客_diskcache python
annotations.csv: 记录实际结节。文件结构: uid, x, y, z, diameter candidates.csv: 记录候选结节。文件结构: uid, x, y, z, class
注意:两个文件中,相同的uid对应的xyz坐标可能有偏差,要将偏差大于半径的一半(即diameter/4)的数据的diameter强制为0,即认为这个结节异常,不处理。
CT数据中,有XYZ坐标轴,训练时需要转换为IRC坐标轴,两个坐标轴分别对应着:
xyz:各坐标轴正的方向指向的人体的方向为为:
x:左手,y:后背,z:头顶
irc:各坐标轴正的方向指向的人体的方向为为:
i:头顶,r:后背,c:左手
其中i-index,r-row, c-column
简记为:xyz-左后上,irc-上后左
5.2.1 irc转xyz
step1:将irc矩阵翻转为cri
step2:用体素大小缩放cri坐标
step3:缩放后的cri坐标与空间矩阵叉乘得到xyz坐标
step4:xyz坐标加上原点偏移量。
def irc2xyz(coord_irc, origin_xyz, vxSize_xyz, direction_a):
"""
irc坐标转为xyz坐标
step1:将irc矩阵翻转为cri
step2:用体素大小缩放cri坐标
step3:缩放后的cri坐标与空间矩阵叉乘得到xyz坐标
step4:xyz坐标加上原点偏移量。
:param coord_irc: irc坐标
:param origin_xyz: irc坐标相对于xyz的坐标偏移
:param vxSize_xyz: 体素在xyz尺度的大小
:param direction_a: 空间矩阵
:return:
"""
cri_a = np.array(coord_irc)[::-1]
origin_a = np.array(origin_xyz)
vxSize_a = np.array(vxSize_xyz)
coords_xyz = (direction_a @ (cri_a * vxSize_a)) + origin_a
# coords_xyz = (direction_a @ (idx * vxSize_a)) + origin_a
return XyzTuple(*coords_xyz)
5.2.2 xyz转irc
def xyz2irc(coord_xyz, origin_xyz, vxSize_xyz, direction_a):
origin_a = np.array(origin_xyz)
vxSize_a = np.array(vxSize_xyz)
coord_a = np.array(coord_xyz)
cri_a = ((coord_a - origin_a) @ np.linalg.inv(direction_a)) / vxSize_a
cri_a = np.round(cri_a)
return IrcTuple(int(cri_a[2]), int(cri_a[1]), int(cri_a[0]))
CT文件中数据单位为HU(HounsField Units,亨氏单位)。其中人体各组织的HU值水平为:
空气:-1000HU,约0g/cm3
水:0HU,约1g/cm3
骨骼:1000HU,约2~3g/cm3。
因此超出-1000HU到1000HU外的数据并不是我们需要关心的数据,可强制转换为限值。
体素:可理解为CT扫描后得到的三维切片矩阵中所对应的一个点(像素),即切片后最小的人体组织,接三维的立体像素。
结节:可能为恶性也可能是良性,CT扫描后可根据体素的尺寸,结节中心坐标,结节直径截取出结节所对应的坐标值已经HU值。
下图第一行是对CT文件中,三维CT矩阵用不同维度索引下的结果;
下图第二行是对某个结节中,三维结节矩阵用不同维度索引下的结果。
更多可视化内容可参照原书代码的ipynb文件。
candidateInfo_list = getCandidateInfoList(requireOnDisk_bool=True)
返回candidates.csv文件对应的list,其中每个元素为名称为candidateInfoTuple的元组,元组有如下节点:
class, diameter, id, xyz
属性如下:
CT.hu_a:以HU为单位的三维array,存储的是CT的所有体素数据。
CT.origin_xyz:xyz坐标和irc坐标的原点偏移量
CT.vzSize_xyz:体素在xyz坐标轴的尺度大小
CT.direction_a:体素的空间矩阵
CT.getRawCandidate函数:
ct_chunk, center_irc = getRawCandidate(center_xyz, width_irc)
center_xyz:结节在xyz坐标系的坐标值。
width_irc:体素在irc坐标系的尺度大小。
ct_chunk:结节在irc坐标轴的HU值的三维矩阵。
center_irc:结节中心在irc坐标系的坐标值。
ds = LunaDataset(val_stride=0, isValSet_bool=False, series_uid=None)
val_stride:作为验证集时,从数据集中抽取样本作为验证集样本的步长。即每隔val_stride抽取一个样本作为验证集样本。
isValSet_bool:是否作为验证集。
series_uid:获取某个uid对应的所有样本。
书中代码【dsets.py】如下:
import copy
import csv
import functools
import glob
import os
from collections import namedtuple
import SimpleITK as sitk
import numpy as np
import torch
import torch.cuda
from torch.utils.data import Dataset
from util.disk import getCache
from util.util import XyzTuple, xyz2irc
from util.logconf import logging
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)
raw_cache = getCache('part2ch10_raw')
CandidateInfoTuple = namedtuple(
'CandidateInfoTuple',
'isNodule_bool, diameter_mm, series_uid, center_xyz',
)
@functools.lru_cache(1)
def getCandidateInfoList(requireOnDisk_bool=True):
# We construct a set with all series_uids that are present on disk.
# This will let us use the data, even if we haven't downloaded all of
# the subsets yet.
mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
diameter_dict = {}
with open('data/part2/luna/annotations.csv', "r") as f:
for row in list(csv.reader(f))[1:]:
series_uid = row[0]
annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
annotationDiameter_mm = float(row[4])
diameter_dict.setdefault(series_uid, []).append(
(annotationCenter_xyz, annotationDiameter_mm)
)
candidateInfo_list = []
with open('data/part2/luna/candidates.csv', "r") as f:
for row in list(csv.reader(f))[1:]:
series_uid = row[0]
if series_uid not in presentOnDisk_set and requireOnDisk_bool:
continue
isNodule_bool = bool(int(row[4]))
candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
candidateDiameter_mm = 0.0
for annotation_tup in diameter_dict.get(series_uid, []):
annotationCenter_xyz, annotationDiameter_mm = annotation_tup
for i in range(3):
delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
if delta_mm > annotationDiameter_mm / 4:
break
else:
candidateDiameter_mm = annotationDiameter_mm
break
candidateInfo_list.append(CandidateInfoTuple(
isNodule_bool,
candidateDiameter_mm,
series_uid,
candidateCenter_xyz,
))
candidateInfo_list.sort(reverse=True)
return candidateInfo_list
class Ct:
def __init__(self, series_uid):
mhd_path = glob.glob(
'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid)
)[0]
ct_mhd = sitk.ReadImage(mhd_path)
ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
# CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
# HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
# The lower bound gets rid of negative density stuff used to indicate out-of-FOV
# The upper bound nukes any weird hotspots and clamps bone down
ct_a.clip(-1000, 1000, ct_a)
self.series_uid = series_uid
self.hu_a = ct_a
self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)
def getRawCandidate(self, center_xyz, width_irc):
center_irc = xyz2irc(
center_xyz,
self.origin_xyz,
self.vxSize_xyz,
self.direction_a,
)
slice_list = []
for axis, center_val in enumerate(center_irc):
start_ndx = int(round(center_val - width_irc[axis]/2))
end_ndx = int(start_ndx + width_irc[axis])
assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
if start_ndx < 0:
# log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
# self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
start_ndx = 0
end_ndx = int(width_irc[axis])
if end_ndx > self.hu_a.shape[axis]:
# log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
# self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
end_ndx = self.hu_a.shape[axis]
start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
slice_list.append(slice(start_ndx, end_ndx))
ct_chunk = self.hu_a[tuple(slice_list)]
return ct_chunk, center_irc
@functools.lru_cache(1, typed=True)
def getCt(series_uid):
return Ct(series_uid)
@raw_cache.memoize(typed=True)
def getCtRawCandidate(series_uid, center_xyz, width_irc):
ct = getCt(series_uid)
ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
return ct_chunk, center_irc
class LunaDataset(Dataset):
def __init__(self,
val_stride=0,
isValSet_bool=None,
series_uid=None,
):
self.candidateInfo_list = copy.copy(getCandidateInfoList())
if series_uid:
self.candidateInfo_list = [
x for x in self.candidateInfo_list if x.series_uid == series_uid
]
if isValSet_bool:
assert val_stride > 0, val_stride
self.candidateInfo_list = self.candidateInfo_list[::val_stride]
assert self.candidateInfo_list
elif val_stride > 0:
del self.candidateInfo_list[::val_stride]
assert self.candidateInfo_list
log.info("{!r}: {} {} samples".format(
self,
len(self.candidateInfo_list),
"validation" if isValSet_bool else "training",
))
def __len__(self):
return len(self.candidateInfo_list)
def __getitem__(self, ndx):
candidateInfo_tup = self.candidateInfo_list[ndx]
width_irc = (32, 48, 48)
candidate_a, center_irc = getCtRawCandidate(
candidateInfo_tup.series_uid,
candidateInfo_tup.center_xyz,
width_irc,
)
candidate_t = torch.from_numpy(candidate_a)
candidate_t = candidate_t.to(torch.float32)
candidate_t = candidate_t.unsqueeze(0)
pos_t = torch.tensor([
not candidateInfo_tup.isNodule_bool,
candidateInfo_tup.isNodule_bool
],
dtype=torch.long,
)
return (
candidate_t,
pos_t,
candidateInfo_tup.series_uid,
torch.tensor(center_irc),
)
import functools
import glob
import os.path
import csv
import SimpleITK as sitk
import numpy as np
import copy
import torch
import torch.cuda
from torch.utils.data import Dataset
from collections import namedtuple
from util.disk import getCache
from util.util import XyzTuple, xyz2irc
from util.logconf import logging
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
# annotations.csv: 记录实际结节。文件结构: uid, x, y, z, diameter
# candidates.csv: 记录候选结节。文件结构: uid, x, y, z, class
raw_cache = getCache('part2ch10_raw')
# 构建用于存储候选结节的元组, 结构: class, diameter, id, xyz
candidateInfoTuple = namedtuple('candidateInfoTuple',
'isNodule_bool, diameter_mm, series_uid, center_xyz')
@functools.lru_cache(1) # 缓存一次调用结果
def getCandidateInfoList(requireOnDisk_bool=True):
"""
加载annotations.csv和candidates.csv,分别存到diameter_list和candidateInfo_list
:param requireOnDisk_bool. 如果文件不存在,是否跳过
:return candidateInfo_list. 由candidateInfoTuple构成的list
"""
mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list} # 提取所有文件名,即uid
diameter_dict= {}
with open('data/part2/luna/annotations.csv', 'r') as f:
for row in list(csv.reader(f))[1:]:
series_uid = row[0]
annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
annotationDiameter_mm = float(row[4])
diameter_dict.setdefault(series_uid, []).append(
(annotationCenter_xyz, annotationDiameter_mm)
)
candidateInfo_list = []
with open('data/part2/luna/candidates.csv', 'r') as f:
for row in list(csv.reader(f))[1:]:
series_uid = row[0]
# 如果annotations.csv中找不到这个id,则跳过
if series_uid not in presentOnDisk_set and requireOnDisk_bool:
continue
candidateDiameter_xyz = tuple([float(x) for x in row[1:4]])
isNodule_bool = bool(int(row[4]))
# 如果candidate中的xyz坐标和annotation中的xyz坐标偏差大于半径的一半,
# 则认为它们不是同一个节点,将直接用零代替,即认为这不是结节
candidateDiameter_mm = 0.0
for annotation_tup in diameter_dict.get(series_uid, []):
annotation_xyz, annotationDiameter_mm = annotation_tup
for i in range(3):
delta_mm = abs(candidateDiameter_xyz[i] - annotation_xyz[i])
if delta_mm > annotationDiameter_mm/4:
break
else:
candidateDiameter_mm = annotationDiameter_mm
break
candidateInfo_list.append(candidateInfoTuple(
isNodule_bool,
candidateDiameter_mm,
series_uid,
candidateDiameter_xyz,
))
candidateInfo_list.sort(reverse=True)
return candidateInfo_list
class Ct:
def __init__(self, series_uid):
mhd_path = glob.glob(r'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
# 用SampleSTK包可直接读取CT扫描数据
ct_mhd = sitk.ReadImage(mhd_path)
# HU: 亨氏单位,Hounsfield Unit.
# 空气为-1000 HU,约等于0 g/cm3. 水为0 HU,约等于1 g/cm3, 骨骼至少时1000HU,约等于2~3g/cm3
ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32) # 读取到的数据单位为HU
# 将数据限定再-1000~1000 HU
ct_a.clip(-1000, 1000, ct_a)
self.series_uid = series_uid
self.hu_a = ct_a
self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin()) # xyz坐标和irc坐标的原点偏移量
self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing()) # 体素在xyz坐标轴的大小
self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3) # 体素方向矩阵,等于eye(3)
def getRawCandidate(self, center_xyz, width_irc):
"""
根据xyz坐标算出病人坐标irc。然后根据每个结节的irc和体素宽度,算出结节包含的体素块数据
:param center_xyz: 结节的xyz坐标
:param width_irc: 体素宽度
:return ct_chunk: 结节包含的体素块的HU值,array
:return center_irc: 结节的病人坐标信息
"""
center_irc = xyz2irc(
center_xyz,
self.origin_xyz,
self.vxSize_xyz,
self.direction_a
)
slice_list = []
for axis, center_val in enumerate(center_irc):
start_ndx = int(round(center_val - width_irc[axis]/2))
end_ndx = int(start_ndx + width_irc[axis])
assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
if start_ndx < 0:
# log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
# self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
start_ndx = 0
end_ndx = int(width_irc[axis])
if end_ndx > self.hu_a.shape[axis]:
# log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
# self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
end_ndx = self.hu_a.shape[axis]
start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
slice_list.append(slice(start_ndx, end_ndx))
ct_chunk = self.hu_a[tuple(slice_list)]
return ct_chunk, center_irc
@functools.lru_cache(1, typed=True) # 保留一次缓存结果
def getCt(series_uid):
return Ct(series_uid)
@raw_cache.memoize(typed=True) # 数据缓存到同路径的cache文件夹下
def getCtRawCandidate(series_uid, center_xyz, width_irc):
ct = getCt(series_uid)
ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
return ct_chunk, center_irc
class LunaDataset(Dataset):
def __init__(self, val_stride=0, isValSet_bool=False, series_uid=None):
"""
val_stride:作为验证集时,从数据集中抽取样本作为验证集样本的步长。即每隔val_stride抽取一个样本作为验证集样本。
isValSet_bool:是否作为验证集。
series_uid:获取某个uid对应的所有样本。
"""
self.candidateInfo_list = copy.copy(getCandidateInfoList())
if series_uid:
self.candidateInfo_list = [x for x in self.candidateInfo_list if x.series_uid==series_uid]
if isValSet_bool:
assert val_stride > 0, val_stride
self.candidateInfo_list = self.candidateInfo_list[::val_stride]
assert self.candidateInfo_list
elif val_stride > 0:
del self.candidateInfo_list[::val_stride]
assert self.candidateInfo_list
log.info("(!r): {} {} samples".format(
self,
len(self.candidateInfo_list),
"validation" if isValSet_bool else "training",
))
def __len__(self):
return len(self.candidateInfo_list)
def __getitem__(self, ndx):
"""
返回指定索引对应的结节信息
:param ndx: 某个ct数据中的第ndx个结节索引
:return: candidate_t. 结节所包含的所有体素的三位数组。t代表数组时个tensor
:return: post_t. 结节是否为肿瘤。0代表不是,1代表肿瘤。
:return: series_uid. ndx所对应的结节uid
:return: center_irc. 结节的重心坐标。类型为tensor
"""
candidateInfo_tup = self.candidateInfo_list[ndx]
width_irc = (32, 48, 48)
candidate_a, center_irc = getCtRawCandidate(
candidateInfo_tup.series_uid,
candidateInfo_tup.center_xyz,
width_irc,
)
candidate_t = torch.from_numpy(candidate_a)
candidate_t = candidate_t.to(torch.float32)
candidate_t = candidate_t.unsqueeze(0)
post_t = torch.tensor([
not candidateInfo_tup.isNodule_bool,
candidateInfo_tup.isNodule_bool
],
dtype=torch.long,
)
return (
candidate_t,
post_t,
candidateInfo_tup.series_uid,
torch.tensor(center_irc)
)