MATLAB处理大量三维图像,Data Augmentation

Author: Zongwei Zhou | 周纵苇
Weibo: @MrGiovanni
Email: [email protected]


Please cite this paper if you found it useful. Thanks!
Wang H, Zhou Z, Li Y, et al. Comparison of machine learning methods for classifying mediastinal lymph node metastasis of non-small cell lung cancer from 18F-FDG PET/CT images[J]. 2017, 7.


前言吐个槽:深度学习对于样本的要求是1) 数量多,2) 质量高。相对来讲的话数量比质量更重要一点,即使训练样本中有个别的出错了,也基本不会对卷积神经网络(Convolutional Neural Network, CNN)造成什么影响。所以,要搞深度学习,第一步就是能弄到足够多的较好的数据集,说实话这个门槛是比较高的,虽说现在是什么“大数据时代”,“数据爆炸了”,由于1) 卷积神经网络是有监督学习,如此海量的数据的标签往往很难标注,2) 有些数据(如医疗数据),涉及到隐私,一般是不会给研究者开放的。这两个原因直接让深度学习的研究门槛上升了。因此很少会有看到一些比较小的学校或者研究机构把深度学习做得很6的,这种数据需要很牛逼的老师去申请到基金,或者合作才有开展的可能,在国内比较6的学校也就是985那些土豪名校了,比较6的公司也是bat这种和一些大量与数据打交道的公司了。国外相对开展的更火热一点,因为腰杆子硬嘛呵呵。像Google的AlphaGo,那你说是那些高质量的棋谱难拿到还是深度学习算法难实现呢,网上这么多AlphaGo的开源代码,没有用啊,Google并不开放数据集。。这才是槽点。

MATLAB处理大量三维图像,Data Augmentation_第1张图片
不可否认这个的确是机遇

这就是我为什么公开自己的代码,1)首先我的代码本身没什么难度,稍微学习一下MATLAB就可以写出来,2)就算我把代码公开也没事儿,没有数据集这事儿就白搭,3)交流分享在我这儿比什么都重要:)自己花了很多时间做了这个,要是别人再做这个,没有可以参考的,还得像我一样从零开始一点点摸索,那我做了也是白搭。

数据集我是不会公开的,因为,很贵,而且不是我的,是老师的:)


逼逼的有点多了,现在开始脚踏实地干Data Augmentation。

论文引用

本文用卷积神经网络做有监督学习,使用的数据集是三维的肺部肿瘤PET/CT医学影像,总共1383幅影像,其中有127幅恶性肿瘤以及1256幅良性肿瘤。每幅医学影像的像素尺寸是101×101×101,图像中含有肿瘤的分割标签,精确到每个像素,是肿瘤的标为true,不是的标为false。详细的PET/CT图像数据库的信息如表3.1所示

MATLAB处理大量三维图像,Data Augmentation_第2张图片
PET/CT数据集

所有的肿瘤都位于每幅图的 中心位置。在对数据集的标记统计后发现,一般肿瘤的尺寸在15×15像素左右的正态分布,考虑到保留肿瘤周边的医学信息,我们使用的每个patch的尺寸为51×51像素,各个patch的六通道都是对应的同一个肿瘤。


参数初始化

clear all; close all; clc;

addr = 'F:\\Datasets\\For_ZhouZongwei_3D\\For_DeepLearning\\';  
    % 数据库存放的地址                
save_addr = 'G:\\Data4DeepLearning\\';      
    % 切完的数据存放的地址
PIXEL = 101;        
    % 图像的原始尺寸
CutRange = 51;      
    % 网络输入图像的尺寸: CutRange×CutRange
ThetaStep = 10;     
    % 角度划分的步长
ThetaStart = 0;
    % 角度划分的开始角度     
angnum = 3;         
    % 划分几个角度
    % 也就是说现在我们划分的角度是:0,10,20
DeltaStep = 2;      
    % 平切的步长
DeltaStart = 0;
    % 平切的开始像素
delnum = 3;         
    % 用几个平切
    % 也就是说现在我们的平切像素相对于中心点为:0,2,4
Channels = 6;       
    % CNN输入通道个数
B_M_Prop = 1;       
    % 良性肿瘤的训练数目是恶性肿瘤的1倍
va_m = 27;              
    % 测试集中的恶性个数
va_b = 27;             
    % 测试集中的良性个数
hwait = waitbar(0, 'Cutting Now, Pls Wait >>>');
child_h = waitbar(0, 'Rotate and Parallel Cut, Pls Wait >>>');
MATLAB处理大量三维图像,Data Augmentation_第3张图片
你看,有了进度条心里就很有底了,因为运行的速度太慢啦

参数预处理

CtImaFiles = dir(strcat(addr, '*', '_CT3d.raw'));       
    % 所有三维CT图像的文件名
PetImaFiles = dir(strcat(addr, '*', '_PT3d.raw'));       
    % 所有三维PET图像的文件名
ImageNum = length(CtImaFiles);                      
    % CT图像样本的个数,也是全部数据集的个数,因为是一一对应的
tr = 1;               
    % 统计切割的训练集数目,初始化为1
va = 1;             
    % 统计测试集数目,初始化为1
va_m_c = 1;     
    % 统计测试集中的恶性数目,初始化为1
va_b_c = 1;       
    % 统计测试集中的良性数目,初始化为1
tr_m_c = 1;       
    % 统计训练集中的恶性数目,初始化为1
tr_b_c = 1;        
    % 统计训练集中的良性数目,初始化为1

计算正负样本数

disp('>> Label Counting ...')
tic;
label = ones(1, ImageNum);                                     
    % 存放所有的标签
for i=1:ImageNum
    label(:, i) = str2double(CtImaFiles(i, 1).name(strfind(CtImaFiles(i, 1).name, 'Class') + 5));
    % 由于肿瘤的标签信息是在每个文件的名字里面,1表示恶性,0表示良性
    % 请参照下图,你就会了解我的数据库是如何命名的啦
    % 因此只需要用strfind( )函数在文件名组成的字符串中查找Class的位置,然后向右进五格即可        
end
data_m = length(find(label==1));
data_b = length(find(label==0));
clear label;
tr_m = data_m - va_m;       % 训练集中的恶性数目
tr_b = tr_m * B_M_Prop;     % 训练集中的良性数目
disp(['      ', num2str(data_m), ' Malignance and ', num2str(data_b), ' Benign.'])
toc;disp(' ')
clear data_m data_b;
MATLAB处理大量三维图像,Data Augmentation_第4张图片
数据集长这样的,标签信息都在文件名里头

矩阵预处理

这儿我要说明一下为什么要先做矩阵预处理。如果不是一开始就开辟耗内存,那么后续的矩阵就会动态分配内存,这样是很慢的,特别是当矩阵大了的时候,速度呈指数式的变慢。举例子,一开始没有分配耗内存,然后我来个for循环,i每加一,给矩阵赋值Array[i] = ***,好,Array相应的就得开辟一块内存出来,当i=10000时,Array已经很大很大了,i<=10001,速度,爆慢哦。
那么如何给矩阵初始化呢,首先你得知道你要多大的矩阵,比如我需要的是[100000, 6, 51, 51],那就用ones( )或者zero( )给初始化全1或全0的这么大的矩阵,具体的代码如下:

TrNum = angnum * angnum * angnum * delnum * delnum * delnum * (tr_m + tr_b);
    % 训练集的大小,用tr_m+tr_b个样本,然后做数据拓展
    % angnum个角度变换和delnum和平移变换
    % 因为是三维图像,在x, y, z轴方向上都有旋转、平移变换
    % 因此每个样本拓展成了angnum×angnum×angnum×delnum×delnum×delnum个
VaNum = va_m + va_b;
    % 测试集的大小
    % 测试集不参加网络训练,因此不需要数据拓展
disp('>> Memory Distribute ...')
tic;
TrData = ones(CutRange, CutRange, Channels, TrNum);
VaData = ones(CutRange, CutRange, Channels, VaNum);
TrLabel = ones(1, TrNum);
VaLabel = ones(1, VaNum);
    % 分配给训练集和测试集所需要的内存
toc;disp(' ');
clear TrNum VaNum;

正式开始数据拓展

这儿有一个大循环,是从1一直循环到数据集结束,一共有1383(ImageNum)个样本

disp('>> Data Augmentation ...')
r = randperm(ImageNum);
    % 这是为了打乱数据,方便后续的交叉验证
for i=1:ImageNum
    
    index = r(i);
    
    %% 考虑到正负样本比,看看是不是不需要再读入了
    % 由于后续的切割很费时间,所以弄一个这个监测
    labeltmp = str2double(CtImaFiles(index, 1).name(strfind(CtImaFiles(index, 1).name, 'Class') + 5));     
        % 读入一幅图像的标签
    if tr_m_c > tr_m && labeltmp == 1
        % 如果训练集的恶性样本已经够了,就跳过,continue
        waitbar(i/ImageNum, hwait, 'Reading Now, Pls Wait >>>');
        continue;
    end
    if tr_b_c > tr_b && labeltmp == 0
        % 如果训练集的良性样本已经够了,就跳过,continue
        waitbar(i/ImageNum, hwait, 'Reading Now, Pls Wait >>>');
        continue;
    end
    
    %% 读入一张三维CT图像
    fid = fopen(strcat(addr, CtImaFiles(index).name));
    CTimage = fread(fid, PIXEL^3, 'float');
    CTimage = reshape(CTimage,PIXEL,PIXEL,PIXEL);
    fclose(fid);
    
    %% 读入一张三维PET图像
    fid = fopen(strcat(addr, PetImaFiles(index).name));
    PETimage = fread(fid, PIXEL^3, 'float');
    PETimage = reshape(PETimage,PIXEL,PIXEL,PIXEL);
    fclose(fid);  
    
    % 至此,CTimage中存的是一幅三维CT图像
    % PETimage中存的是一幅三维PET图像
    % 这两个是像素一一对应的
    % label中存的是标签

    %% 由于测试集不需要data argumentation,所以先处理    
    if labeltmp == 1 && va_m_c <= va_m
        % 读入的是恶性且测试集中的恶性还不够
        [temp1, temp2, temp3] = RotParGio(CTimage, 0, 0, 0, 0, 0, 0, CutRange);
        VaData(:, :, 1, va) = temp1; VaData(:, :, 2, va) = temp2; VaData(:, :, 3, va) = temp3; 
        [temp1, temp2, temp3] = RotParGio(PETimage, 0, 0, 0, 0, 0, 0, CutRange);
        VaData(:, :, 4, va) = temp1; VaData(:, :, 5, va) = temp2; VaData(:, :, 6, va) = temp3; 
        VaLabel(:, va) = labeltmp;
        va = va + 1;
        va_m_c = va_m_c + 1;

    elseif labeltmp == 0 && va_b_c <= va_b
        % 读入的是良性且测试集中的良性还不够
        [temp1, temp2, temp3] = RotParGio(CTimage, 0, 0, 0, 0, 0, 0, CutRange);
        VaData(:, :, 1, va) = temp1; VaData(:, :, 2, va) = temp2; VaData(:, :, 3, va) = temp3; 
        [temp1, temp2, temp3] = RotParGio(PETimage, 0, 0, 0, 0, 0, 0, CutRange);
        VaData(:, :, 4, va) = temp1; VaData(:, :, 5, va) = temp2; VaData(:, :, 6, va) = temp3; 
        VaLabel(:, va) = labeltmp;
        va = va + 1;
        va_b_c = va_b_c + 1;
    
    %% 测试集这个标签的已经满了,那么就要data augmentation
    else    
        %% 沿着x,y,z轴旋转一个角度,然后裁剪
        % 成功的算法~64s 将循环变成矩阵!
        ind_all = 1:angnum * angnum * angnum * delnum * delnum * delnum;
        [theta_z, theta_y, theta_x, delta_z, delta_y, delta_x] = ind2sub([angnum, angnum, angnum, delnum, delnum, delnum], ind_all);
    
        theta_x = ThetaStep * theta_x - (ThetaStep * theta_x(1) - ThetaStart);
        theta_y = ThetaStep * theta_y - (ThetaStep * theta_y(1) - ThetaStart);
        theta_z = ThetaStep * theta_z - (ThetaStep * theta_z(1) - ThetaStart);
        theta_x = theta_x';
        theta_y = theta_y';
        theta_z = theta_z';
    
        delta_x = DeltaStep * delta_x - (DeltaStep * delta_x(1) - DeltaStart);
        delta_y = DeltaStep * delta_y - (DeltaStep * delta_y(1) - DeltaStart);
        delta_z = DeltaStep * delta_z - (DeltaStep * delta_z(1) - DeltaStart);
        delta_x = delta_x';
        delta_y = delta_y';
        delta_z = delta_z';
    
      
        waitbar(0, child_h, 'Rotate and Parallel Cut, Pls Wait >>>');
        for t = 1 : angnum * angnum * angnum * delnum * delnum * delnum
        %% 调用Rotation3dGio()函数,输入3d图,沿x,y,z轴转动的角度和patch的尺寸,输出三张2.5D图temp1,temp2,temp3
            % disp([num2str(theta_x(t)), ' ', num2str(theta_y(t)), ' ', num2str(theta_z(t)), ' ', num2str(delta_x(t)), ' ', num2str(delta_y(t)), ' ', num2str(delta_z(t))]);
            [temp1, temp2, temp3] = RotParGio(CTimage, theta_x(t), theta_y(t), theta_z(t), delta_x(t), delta_y(t), delta_z(t), CutRange);
            TrData(:, :, 1, tr) = temp1; TrData(:, :, 2, tr) = temp2; TrData(:, :, 3, tr) = temp3; 
            [temp1, temp2, temp3] = RotParGio(PETimage, theta_x(t), theta_y(t), theta_z(t), delta_x(t), delta_y(t), delta_z(t), CutRange);
            TrData(:, :, 4, tr) = temp1; TrData(:, :, 5, tr) = temp2; TrData(:, :, 6, tr) = temp3; 
            TrLabel(1, tr) = labeltmp;
            tr = tr + 1;
            
            waitbar(t/(angnum * angnum * angnum * delnum * delnum * delnum), child_h, ['No. ', num2str(i), ': Rotate and Parallel Cut, Pls Wait >>>']);
        end
        if labeltmp == 1
            tr_m_c = tr_m_c + 1;
        else
            tr_b_c = tr_b_c + 1;
        end
        
    end
    
    %% 进度条,注意:这儿不要clear一些中间矩阵,测试发现clear以后再重新创建速度回变慢!!!
    waitbar(i/ImageNum, hwait, 'Reading Now, Pls Wait >>>');
end

旋转、平移函数编写RotParGio( )

这儿需要一点点的计算机图形学的知识,就是三维图像的旋转、平移公式

MATLAB处理大量三维图像,Data Augmentation_第5张图片
MATLAB处理大量三维图像,Data Augmentation_第6张图片
MATLAB处理大量三维图像,Data Augmentation_第7张图片
MATLAB处理大量三维图像,Data Augmentation_第8张图片

三维变换可以叠加,体现在矩阵运算上是mat'=Px·Py·Pz·Pm·mat,其中mat表示变换前的齐次坐标点,mat表示经过任意三维旋转,平移变换后的齐次坐标点。在编程方面,考虑到程序的效率,并没有直接对三维PET/CT进行旋转平移变换,而是对横截面-冠状面-矢状面所组成的一组切面进行变换。考虑到计算机内存的限制,取旋转角度分别为-10°,0°, 10°,平移像素分别为-2,0,2像素,即每个样本共扩充成1,008,207组PET/CT的六通道2.5D的输入(3×3×3×3×3×3×1383),样本大小共46.4GB。

MATLAB处理大量三维图像,Data Augmentation_第9张图片
RotParGio.m描述
% 输入:3d图像,沿x,y,z轴旋转的角度,patch的尺寸
% 输出:2.5D的三个图
% 编写人:周纵苇

function [slide_x, slide_y, slide_z] = RotParGio(image, theta_x, theta_y, theta_z, delta_x, delta_y, delta_z, size)

%% 三维旋转变换矩阵
rotation_x = [1 0 0 0;
            0 cosd(theta_x) sind(theta_x) 0;
            0 -sind(theta_x) cosd(theta_x) 0;
            0 0 0 1];
rotation_y = [cosd(theta_y) 0 -sind(theta_y) 0
            0 1 0 0;
            sind(theta_y) 0 cosd(theta_y) 0;
            0 0 0 1];
rotation_z = [cosd(theta_z) sind(theta_z) 0 0;
            -sind(theta_z) cosd(theta_z) 0 0;
            0 0 1 0;
            0 0 0 1];
%% 平移变换矩阵
parallel_xyz = [1 0 0 delta_x;
                0 1 0 delta_y;
                0 0 1 delta_z;
                0 0 0 1];

tmp = 101;

X = -floor(tmp/2) : floor(tmp/2);
Y = -floor(tmp/2) : floor(tmp/2);
Z = -floor(tmp/2) : floor(tmp/2);
ind_all = 1 : size * size;
[y_all, x_all] = ind2sub(size, ind_all);
x_all = x_all'; y_all = y_all';
origin_x = [zeros(size*size,1), x_all-floor(size/2)-1, y_all-floor(size/2)-1, ones(size*size,1)];
origin_y = [x_all-floor(size/2)-1, zeros(size*size,1), y_all-floor(size/2)-1, ones(size*size,1)];
origin_z = [x_all-floor(size/2)-1, y_all-floor(size/2)-1, zeros(size*size,1), ones(size*size,1)];
origin = [origin_x; origin_y; origin_z];
origin = origin';
P = parallel_xyz * rotation_x * rotation_y * rotation_z * origin;
slide_y = interp3(X, Y, Z, image, P(1,1:size*size), P(2,1:size*size), P(3,1:size*size));
slide_x = interp3(X, Y, Z, image, P(1, size*size+1:2*size*size), P(2, size*size+1:2*size*size), P(3, size*size+1:2*size*size));
slide_z = interp3(X, Y, Z, image, P(1, 2*size*size+1:3*size*size), P(2, 2*size*size+1:3*size*size), P(3, 2*size*size+1:3*size*size));
slide_x = reshape(slide_x, size, size);
slide_y = reshape(slide_y, size, size);
slide_z = reshape(slide_z, size, size);
slide_z = slide_z';

以下的ppt是我在设计RotParGio函数的心路历程,具体展示了如何实现旋转功能的

MATLAB处理大量三维图像,Data Augmentation_第10张图片
MATLAB处理大量三维图像,Data Augmentation_第11张图片

我猜你一定会注意到前面的代码有一部分没有看懂,我又没有足够的注释,没错,这个是我想展开来讲的内容:一个小trick来避免MATLAB中使用大量for循环。


避免MATLAB中大量使用for循环的方法——ind2sub( )

是时候亮出这篇文章的大招了,要是你看了这个标题就明白了的话,,,那就别往下看了,这是我能写出来的最难的trick了:)
MATLAB的全称是Matrix Lab(矩阵实验室),底层是C++写的,强项是处理矩阵的运算,由于矩阵的运算用了大量的循环,速度非常的慢,MATLAB一开始就是为了解决这个问题的。但是呢,真的要让MATLAB来处理for循环语句,那速度是很慢的。所以要加速程序运行的速度,一个简单的想法就是把循环运算转换成矩阵运算。这就是这个小trick的由来。

MATLAB处理大量三维图像,Data Augmentation_第12张图片
由于要对每个像素做旋转平移变换,所以讲道理要用三重for循环,每层是1:51像素遍历,效率极低!

我们来回想一下两层for循环是怎么来的

for i = 1 : 3
  for j = 1 : 4

零中心化,标准化

这个步骤的方法是所有值减掉平均值再除以标准差

MATLAB处理大量三维图像,Data Augmentation_第13张图片
数据预处理,仅此而已~
disp('>> After Data Preprocess')
disp('      Training data range')
for i=1:6    
    tmp = squeeze(TrData(:, :, i, :));
    TrData(:, :, i, :) = TrData(:, :, i, :) - mean(tmp(:));
    VaData(:, :, i, :) = VaData(:, :, i, :) - mean(tmp(:));
    tmp = squeeze(TrData(:, :, i, :));
    TrData(:, :, i, :) = TrData(:, :, i, :) / std(tmp(:));
    VaData(:, :, i, :) = VaData(:, :, i, :) / std(tmp(:));
    tmp = squeeze(TrData(:, :, i, :));
    disp(['      ', num2str(min(tmp(:))), '~', num2str(max(tmp(:))), ' : ', num2str(mean(tmp(:))), ' : ', num2str(std(tmp(:)))])
end
disp('      Validation data range')
for i=1:6
    tmp = squeeze(VaData(:, :, i, :));
    disp(['      ', num2str(min(tmp(:))), '~', num2str(max(tmp(:))), ' : ', num2str(mean(tmp(:))), ' : ', num2str(std(tmp(:)))])
end
clear i tmp;

保存到mat文件,为后续python调用准备

disp('>> Saving TrData to dataset.');
tic;
datasave = [save_addr, num2str(length(TrLabel)), 'TrData.mat'];
save(datasave, 'TrData', 'TrLabel');
toc;disp(' ')
disp('>> Saving VaData to dataset.');
tic;
datasave = [save_addr, num2str(length(VaLabel)), 'VaData.mat'];
save(datasave, 'VaData', 'VaLabel');
toc;disp(' ')
clear datasave addr save_addr;

你可能感兴趣的:(MATLAB处理大量三维图像,Data Augmentation)