3D-Resnet-50 医学图像分类(二分类任务)torch代码(精简版)-图像格式为NIFTI

1. 需要有 GPU(推荐 8G 以上),并已配置好 CUDA,可参考:基于win10深度学习环境配置(conda,python,cuda11.7,torch1.13.0)_dr_yingli的博客-CSDN博客
2. 图像文件格式为常见的 NIfTI(.nii)

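img_list(训练列表文件)格式如下:每行依次为图像路径、掩膜路径和类别标签(0 或 1),三者之间以空格分隔。注意:下面示例中的扩展名为 .nrrd,而训练代码使用 nibabel 读取图像,实际使用时请确认文件是 nibabel 可读取的格式(如 .nii / .nii.gz)。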

E:\...\3.nrrd E:\...\3.nrrd 0
E:\...\4.nrrd E:\...\4.nrrd 1

训练代码

import torch
from torch import nn
import os
import numpy as np
from torch.utils.data import Dataset
import nibabel
from scipy import ndimage
from torch import optim
from torch.utils.data import DataLoader
import time
import logging
# ---- Paths ----
# Directory that every path in the image list is joined onto.
root_dir = "./data"
# Text file with one training sample per line.
img_list = "./data/train.txt"
# Checkpoint file holding the pretrained weights.
pretrain_path = "pretrain/resnet_50.pth"
# Prefix of the checkpoint files written during training.
save_folder = "./trails/models/Resnet50"

# ---- Training schedule ----
# Number of passes over the training set.
total_epochs = 20
# Step interval used to decide when a checkpoint is written.
save_intervals = 10
# Learning rate of the pretrained layers (use 0.001 when fine-tuning).
learning_rate = 1e-3
# Name fragments identifying the freshly initialised (non-backbone) layers.
new_layer_names = ["conv_cls"]
# Samples per batch.
batch_size = 1

# ---- Network input size (depth, height, width) ----
input_D = 56
input_H = input_W = 448

# Fix the torch RNG so weight initialisation is repeatable.
torch.manual_seed(1)
class Bottleneck(nn.Module):
    """Residual unit made of 1x1x1, 3x3x3 and 1x1x1 3D convolutions.

    The final convolution outputs ``planes * 4`` channels. The (optionally
    projected) input is added to that result before the last ReLU.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        """Create the three conv/batch-norm pairs.

        Args:
            inplanes: channel count of the incoming tensor.
            planes: channel count of the two inner convolutions.
            stride: stride of the 3x3x3 convolution.
            dilation: dilation of the 3x3x3 convolution (also used as its padding).
            downsample: optional module applied to the input to form the shortcut.
        """
        super().__init__()
        widened = planes * 4

        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)

        # padding == dilation keeps the spatial size unchanged when stride is 1.
        self.conv2 = nn.Conv3d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            dilation=dilation,
            padding=dilation,
            bias=False,
        )
        self.bn2 = nn.BatchNorm3d(planes)

        self.conv3 = nn.Conv3d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(widened)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        """Run the main branch, add the shortcut and apply the final ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Identity shortcut unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
class ResNet(nn.Module):
    """3D ResNet backbone followed by a one-output classification head.

    ``forward`` returns the sigmoid of a single linear output per sample.
    """

    def __init__(self, block, layers, input_D, input_H, input_W):
        """Assemble stem, four residual stages and the head.

        Args:
            block: residual block class; must expose an ``expansion`` attribute.
            layers: number of blocks in each of the four stages.
            input_D, input_H, input_W: accepted for interface compatibility;
                they are not used when building the network.
        """
        super().__init__()
        # Running channel count, read and updated by _make_layer.
        self.inplanes = 64

        # Stem: single-channel input, stride-2 convolution then stride-2 max-pool.
        self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)

        # (planes, stride, dilation) per stage: only stage 2 uses stride 2,
        # stages 3 and 4 keep stride 1 and use dilation 2 and 4 instead.
        stage_specs = ((64, 1, 1), (128, 2, 1), (256, 1, 2), (512, 1, 4))
        for position, (planes, stride, dilation) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, layers[position], stride=stride, dilation=dilation)
            setattr(self, 'layer{}'.format(position + 1), stage)

        # Head: global max-pool -> flatten -> dropout -> one linear output.
        head_inputs = 512 * block.expansion
        self.conv_cls = nn.Sequential(
            nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
            nn.Flatten(start_dim=1),
            nn.Dropout(0.1),
            nn.Linear(head_inputs, 1),
        )

        # Kaiming-normal for every 3D convolution, unit scale / zero shift for batch norm.
        for module in self.modules():
            if isinstance(module, nn.Conv3d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
            elif isinstance(module, nn.BatchNorm3d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one stage of ``blocks`` residual blocks.

        The first block receives ``stride`` and, when the stride is not 1 or
        the channel count changes, a 1x1x1 conv + batch-norm projection for
        its shortcut. ``self.inplanes`` is then updated to the stage's output
        channel count for the remaining blocks and for the next stage.
        """
        out_channels = planes * block.expansion
        needs_projection = stride != 1 or self.inplanes != out_channels
        projection = None
        if needs_projection:
            projection = nn.Sequential(
                nn.Conv3d(self.inplanes, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_channels),
            )

        stage = [block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=projection)]
        self.inplanes = out_channels
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes, dilation=dilation))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the sigmoid of the head output for the input batch ``x``."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return torch.sigmoid_(self.conv_cls(x))
def generate_model(input_D, input_H, input_W, pretrain_path):
    """Build the network on the GPU, load pretrained weights and group its parameters.

    A ResNet with block counts [3, 4, 6, 3] is created and moved to CUDA.
    Entries of the checkpoint's ``'state_dict'`` are copied in when their key,
    with any ``"module."`` text removed, exists in the model's own state dict.

    Returns:
        ``(model, parameters)`` where ``parameters`` is a dict with
        ``'new_parameters'`` (parameters whose name contains an entry of the
        module-level ``new_layer_names``) and ``'base_parameters'`` (the rest).
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3], input_W=input_W, input_H=input_H, input_D=input_D).cuda()
    state = net.state_dict()

    print('loading pretrained model {}'.format(pretrain_path))
    checkpoint = torch.load(pretrain_path)

    # Keep only the checkpoint tensors that have a counterpart in this model.
    matched = {}
    for key, tensor in checkpoint['state_dict'].items():
        bare_key = key.replace("module.", "")
        if bare_key in state:
            matched[bare_key] = tensor
    state.update(matched)
    net.load_state_dict(state)

    # Split by parameter name so the two groups can get different learning rates.
    head_params = [
        param
        for pname, param in net.named_parameters()
        if any(tag in pname for tag in new_layer_names)
    ]
    head_ids = {id(param) for param in head_params}
    backbone_params = [param for param in net.parameters() if id(param) not in head_ids]

    return net, {'base_parameters': backbone_params, 'new_parameters': head_params}
# Build the network, then give the new layers a learning rate 100x that of
# the pretrained layers.
model, parameters = generate_model(input_D, input_H, input_W, pretrain_path)
params = [
    {'params': parameters['base_parameters'], 'lr': learning_rate},
    {'params': parameters['new_parameters'], 'lr': learning_rate * 100},
]
optimizer = optim.SGD(params, momentum=0.9, weight_decay=1e-3)
# Each scheduler step multiplies every group's learning rate by 0.99.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
class Dataset(Dataset):
    """Training dataset yielding (image, mask, class label) triples.

    Each non-blank line of the list file holds an image path, a mask path and
    a numeric class label separated by whitespace. Both paths are joined onto
    ``root_dir``.

    NOTE: the class keeps the name ``Dataset`` for compatibility with existing
    callers; after this definition the name no longer refers to the imported
    ``torch.utils.data.Dataset`` base class.
    """

    def __init__(self, root_dir, img_list, input_D, input_H, input_W):
        """Read the sample list and remember the target volume size.

        Args:
            root_dir: directory the listed paths are joined onto.
            img_list: path of the text file listing the samples.
            input_D, input_H, input_W: size every volume is resized to.
        """
        with open(img_list, 'r') as f:
            # Blank lines (e.g. a trailing newline) are not samples; skip them
            # so they cannot surface later as an indexing error.
            self.img_list = [line.strip() for line in f if line.strip()]
        print("Processing {} datas".format(len(self.img_list)))
        self.root_dir = root_dir
        self.input_D = input_D
        self.input_H = input_H
        self.input_W = input_W

    def __nii2tensorarray__(self, data):
        """Return ``data`` as a float32 array with a leading channel axis of size 1."""
        [z, y, x] = data.shape
        new_data = np.reshape(data, [1, z, y, x])
        new_data = new_data.astype("float32")
        return new_data

    def __len__(self):
        """Return the number of listed samples."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load, preprocess and return sample ``idx``.

        Returns:
            ``(img_array, mask_array, class_array)``: two float32 arrays of
            identical shape with a leading channel axis, and a float32 tensor
            of shape ``(1,)`` holding the class label.

        Raises:
            ValueError: the list entry has fewer than three fields.
            AssertionError: a listed file is missing or the shapes differ.
        """
        # split() tolerates repeated spaces and tabs between the fields.
        ith_info = self.img_list[idx].split()
        if len(ith_info) < 3:
            raise ValueError(
                "expected '<image> <mask> <label>' but got: {!r}".format(self.img_list[idx]))
        img_name = os.path.join(self.root_dir, ith_info[0])
        label_name = os.path.join(self.root_dir, ith_info[1])
        class_array = torch.tensor([float(ith_info[2])], dtype=torch.float32)

        assert os.path.isfile(img_name), "image file not found: {}".format(img_name)
        assert os.path.isfile(label_name), "mask file not found: {}".format(label_name)
        # NOTE(review): volumes are used in the axis order nibabel returns; no
        # transpose is done here - confirm the files are stored in the order
        # (depth, height, width) that input_D/H/W assume.
        img = nibabel.load(img_name)
        mask = nibabel.load(label_name)

        # crop, resize and normalise
        img_array, mask_array = self.__training_data_process__(img, mask)

        # add the channel axis and convert to float32
        img_array = self.__nii2tensorarray__(img_array)
        mask_array = self.__nii2tensorarray__(mask_array)
        assert img_array.shape == mask_array.shape, "img shape:{} is not equal to mask shape:{}".format(img_array.shape, mask_array.shape)
        return img_array, mask_array, class_array

    def __drop_invalid_range__(self, volume, label=None):
        """Crop ``volume`` (and ``label``) to the bounding box of the valid area.

        A voxel is valid when it differs from the corner voxel
        ``volume[0, 0, 0]``. The box includes the last valid index on every
        axis. A volume without any valid voxel is returned unchanged.
        """
        zero_value = volume[0, 0, 0]
        non_zeros_idx = np.where(volume != zero_value)
        if non_zeros_idx[0].size == 0:
            # Constant volume: there is no bounding box to crop to.
            return volume if label is None else (volume, label)

        # Slice ends are exclusive, hence the +1 on the largest valid index.
        box = tuple(slice(int(axis.min()), int(axis.max()) + 1) for axis in non_zeros_idx)
        if label is not None:
            return volume[box], label[box]
        return volume[box]

    def __random_center_crop__(self, data, label):
        """Randomly crop ``data`` and ``label`` around the positive part of ``label``.

        The crop window is drawn at random but always contains the bounding
        box of ``label > 0``. When ``label`` has no positive voxel both arrays
        are returned unchanged. Assumes ``data`` and ``label`` share one shape.
        """
        from random import random

        target_indexs = np.where(label > 0)
        if target_indexs[0].size == 0:
            # Nothing to centre the crop on.
            return data, label

        sizes = data.shape
        lows = [int(axis.min()) for axis in target_indexs]
        highs = [int(axis.max()) for axis in target_indexs]
        extents = [high - low for low, high in zip(lows, highs)]

        # Random start between 0 and half a target extent before the target.
        starts = [round((low - extent / 2) * random()) for low, extent in zip(lows, extents)]
        # Random stop between half a target extent after the target and the end.
        stops = [
            round(size - (size - (high + extent / 2)) * random())
            for size, high, extent in zip(sizes, highs, extents)
        ]

        starts = [max(0, start) for start in starts]
        # Slice ends are exclusive: force the stop past the last target index so
        # a one-voxel-thick target can never yield an empty crop.
        stops = [min(size, max(stop, high + 1)) for stop, size, high in zip(stops, sizes, highs)]

        window = tuple(slice(start, stop) for start, stop in zip(starts, stops))
        return data[window], label[window]

    def __itensity_normalize_one_volume__(self, volume):
        """Normalise ``volume`` with the mean and std of its positive voxels.

        Voxels equal to 0 are replaced by standard-normal noise. If there is
        no positive voxel the mean is taken as 0; if the std is 0 (or there is
        no positive voxel) it is taken as 1, so the result never contains
        NaN or inf from a division by zero.
        """
        pixels = volume[volume > 0]
        if pixels.size == 0:
            mean, std = 0.0, 1.0
        else:
            mean = pixels.mean()
            std = pixels.std()
            if std == 0:
                std = 1.0
        out = (volume - mean) / std
        out_random = np.random.normal(0, 1, size=volume.shape)
        background = volume == 0
        out[background] = out_random[background]
        return out

    def __resize_data__(self, data):
        """Resize ``data`` to (input_D, input_H, input_W) with nearest-neighbour zoom."""
        [depth, height, width] = data.shape
        scale = [self.input_D * 1.0 / depth, self.input_H * 1.0 / height, self.input_W * 1.0 / width]
        data = ndimage.zoom(data, scale, order=0)
        return data

    def __crop_data__(self, data, label):
        """Crop ``data`` and ``label``; currently always a random centre crop."""
        data, label = self.__random_center_crop__(data, label)
        return data, label

    def __training_data_process__(self, data, label):
        """Turn two loaded images into network-ready arrays.

        Steps: read voxel data, drop the invalid border, random crop around
        the mask, resize both arrays, then intensity-normalise the image only.
        """
        data = data.get_fdata()
        label = label.get_fdata()
        # drop out the invalid range
        data, label = self.__drop_invalid_range__(data, label)
        # crop data
        data, label = self.__crop_data__(data, label)
        # resize data
        data = self.__resize_data__(data)
        label = self.__resize_data__(label)
        # normalise the image intensities; the mask is left untouched
        data = self.__itensity_normalize_one_volume__(data)
        return data, label
# Training samples, served in shuffled batches from pinned host memory.
training_dataset = Dataset(
    root_dir=root_dir,
    img_list=img_list,
    input_D=input_D,
    input_H=input_H,
    input_W=input_W,
)
data_loader = DataLoader(
    training_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)

# Root logger at DEBUG level with timestamp, level and source location.
logging.basicConfig(
    format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.DEBUG,
)
log = logging.getLogger()
def train(data_loader, model, optimizer, scheduler, total_epochs, save_interval, save_folder):
    """Train ``model`` with binary cross-entropy on the GPU.

    Args:
        data_loader: yields ``(volumes, label_masks, class_array)`` batches;
            the masks are not used by the loss.
        model: network whose output is compared to ``class_array`` by
            ``nn.BCELoss`` (so it must already output probabilities).
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler, also stepped once per batch.
        total_epochs: number of passes over ``data_loader``.
        save_interval: a checkpoint is written at the first batch of an epoch
            whose global step is a non-zero multiple of this value.
        save_folder: prefix of the checkpoint file names.

    Uses the module-level logger ``log`` and requires CUDA.
    """
    batches_per_epoch = len(data_loader)
    log.info('{} epochs in total, {} batches per epoch'.format(total_epochs, batches_per_epoch))
    loss_seg = nn.BCELoss().cuda()
    model.train()
    train_time_sp = time.time()
    for epoch in range(total_epochs):
        log.info('Start epoch {}'.format(epoch))
        # get_last_lr() reports the rates actually in use; get_lr() is only
        # meant to be called from inside scheduler.step().
        log.info('lr = {}'.format(scheduler.get_last_lr()))
        for batch_id, batch_data in enumerate(data_loader):
            # Global step: batches finished before this one across all epochs.
            batch_id_sp = epoch * batches_per_epoch + batch_id
            volumes, _label_masks, class_array = batch_data
            volumes = volumes.cuda()
            class_array = class_array.cuda()

            optimizer.zero_grad()
            out_masks = model(volumes)
            log.debug('input shape: {}'.format(volumes.shape))

            loss = loss_seg(out_masks, class_array)
            loss.backward()
            optimizer.step()
            # NOTE(review): the learning rate decays after every batch, not
            # every epoch - confirm this is the intended schedule.
            scheduler.step()

            avg_batch_time = (time.time() - train_time_sp) / (1 + batch_id_sp)
            loss_number = loss.item()
            log.info('Batch: {}-{} ({}), loss = {:.3f}, loss_seg = {:.3f}, avg_batch_time = {:.3f}' \
                     .format(epoch, batch_id, batch_id_sp, loss_number, loss_number, avg_batch_time))

            # Checkpoints are only considered at the first batch of an epoch.
            if batch_id == 0 and batch_id_sp != 0 and batch_id_sp % save_interval == 0:
                model_save_path = '{}_epoch_{}_batch_{}.pth.tar'.format(save_folder, epoch, batch_id)
                model_save_dir = os.path.dirname(model_save_path)
                # An empty dirname means "current directory": nothing to create.
                if model_save_dir:
                    os.makedirs(model_save_dir, exist_ok=True)
                log.info('Save checkpoints: epoch = {}, batch_id = {}'.format(epoch, batch_id))
                # The misspelled 'ecpoch' key is kept so existing checkpoint
                # readers continue to work.
                torch.save({'ecpoch': epoch,'batch_id': batch_id,'state_dict': model.state_dict(),'optimizer': optimizer.state_dict()},model_save_path)
    print('Finished training')
# Start training with the objects and settings defined above.
train(
    data_loader=data_loader,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    total_epochs=total_epochs,
    save_interval=save_intervals,
    save_folder=save_folder,
)

你可能感兴趣的:(生信Python,医学图像处理,分类,深度学习,python)