1. 需要有GPU(推荐8G以上),并已设置好CUDA,可参考博文:《基于win10深度学习环境配置(conda,python,cuda11.7,torch1.13.0)》(dr_yingli的博客,CSDN博客)。
2. 文件格式为常见的nii。
img_list格式如下(每行三个字段,以空格分隔,依次为:图像路径、标签路径、类别标签):
E:\...\3.nrrd E:\...\3.nrrd 0
E:\...\4.nrrd E:\...\4.nrrd 1
训练代码
import torch
from torch import nn
import os
import numpy as np
from torch.utils.data import Dataset
import nibabel
from scipy import ndimage
from torch import optim
from torch.utils.data import DataLoader
import time
import logging
# ---- Data and checkpoint locations ----
# Root directory path of the data.
root_dir = './data'
# Path of the image list file.
img_list = './data/train.txt'
# Path of the pretrained model.
pretrain_path = 'pretrain/resnet_50.pth'
# Prefix used when building checkpoint file names.
save_folder = "./trails/models/Resnet50"

# ---- Training schedule ----
# Number of total epochs to run.
total_epochs = 20
# Iteration interval for saving the model.
save_intervals = 10
# Initial learning rate (set to 0.001 when fine-tuning).
learning_rate = 0.001
# Names of the new layers that are not part of the backbone.
new_layer_names = ['conv_cls']
# Batch size.
batch_size = 1

# ---- Network input size: depth, height, width ----
input_D, input_H, input_W = 56, 448, 448

# Fix the torch RNG seed.
torch.manual_seed(1)
class Bottleneck(nn.Module):
    """3D bottleneck residual block: 1x1x1 conv -> 3x3x3 conv -> 1x1x1 conv.

    The last convolution outputs ``planes * expansion`` channels. The block
    input (optionally passed through ``downsample``) is added to that output
    before the final ReLU.

    Args:
        inplanes: Number of input channels.
        planes: Number of channels of the two inner convolutions.
        stride: Stride of the 3x3x3 convolution.
        dilation: Dilation (and padding) of the 3x3x3 convolution.
        downsample: Optional module applied to the input to produce the
            residual; needed whenever the input shape differs from the
            block's output shape.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super().__init__()
        # Use the class attribute rather than a literal 4 so the block stays
        # consistent with code that sizes layers from ``block.expansion``.
        out_planes = planes * self.expansion
        # Submodule names and creation order are kept unchanged so that
        # state-dict keys and seeded default initialisation stay the same.
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            dilation=dilation,
            padding=dilation,  # padding == dilation keeps the size at stride 1
            bias=False,
        )
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        """Apply the three conv/BN stages and add the residual connection."""
        residual = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        return self.relu(out)
class ResNet(nn.Module):
    """3D ResNet backbone with a single-output sigmoid classification head.

    Args:
        block: Residual block class. It must expose an ``expansion``
            attribute and accept ``(inplanes, planes, stride=, dilation=,
            downsample=)``.
        layers: Sequence of four ints, the number of blocks in each stage.
        input_D: Input depth. Accepted for interface compatibility; not
            used by this class.
        input_H: Input height. Accepted for interface compatibility; not
            used by this class.
        input_W: Input width. Accepted for interface compatibility; not
            used by this class.
    """

    def __init__(self, block, layers, input_D, input_H, input_W):
        # Initialise nn.Module before setting any attribute on the instance.
        super().__init__()
        # Channel count fed to the next stage; updated by _make_layer.
        self.inplanes = 64

        self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
        self.bn1 = nn.BatchNorm3d(64)  # matches the output channels of conv1
        self.relu = nn.ReLU(inplace=True)
        # Stride 2 halves every spatial dimension; channels are unchanged.
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)

        # Each stage outputs planes * block.expansion channels.
        # layer1: stride 1, spatial size unchanged.
        self.layer1 = self._make_layer(block, 64, layers[0])
        # layer2: stride 2, spatial size halved.
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        # layer3 / layer4: stride 1 with dilation 2 / 4, so the spatial size
        # is NOT reduced any further in these stages.
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)

        # Classification head: global max pool -> flatten -> dropout -> linear.
        self.conv_cls = nn.Sequential(
            nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
            nn.Flatten(start_dim=1),
            nn.Dropout(0.1),
            nn.Linear(512 * block.expansion, 1),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                # kaiming_normal_ initialises in place; no rebinding needed.
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one stage made of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the stride is not
        1 or the channel count changes, a 1x1x1 conv + BN shortcut.
        """
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_planes),
            )

        layers = [block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample)]
        # Later blocks, and the next call to _make_layer, see the expanded
        # channel count.
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid of the head's single output for each sample."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.conv_cls(x)
        return torch.sigmoid(x)
def generate_model(input_D, input_H, input_W, pretrain_path):
    """Build the ResNet on the GPU and load matching pretrained weights.

    Args:
        input_D: Input depth, forwarded to ``ResNet``.
        input_H: Input height, forwarded to ``ResNet``.
        input_W: Input width, forwarded to ``ResNet``.
        pretrain_path: Path of a checkpoint whose ``'state_dict'`` entry maps
            parameter names (optionally containing ``"module."``) to tensors.

    Returns:
        A ``(model, parameters)`` tuple. ``parameters`` is a dict with
        ``'base_parameters'`` and ``'new_parameters'``: parameters whose name
        contains an entry of the module-level ``new_layer_names`` are "new",
        all the others are "base".
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], input_W=input_W, input_H=input_H, input_D=input_D)
    model = model.cuda()
    net_dict = model.state_dict()

    print('loading pretrained model {}'.format(pretrain_path))
    # Deserialize onto the CPU so loading does not depend on the CUDA device
    # the checkpoint was saved from; load_state_dict copies the values onto
    # the model's own device afterwards.
    pretrain = torch.load(pretrain_path, map_location='cpu')

    # Keep only the checkpoint entries whose (de-prefixed) name also exists
    # in this model.
    pretrain_dict = {}
    for key, value in pretrain['state_dict'].items():
        name = key.replace("module.", "")
        if name in net_dict:
            pretrain_dict[name] = value
    net_dict.update(pretrain_dict)
    model.load_state_dict(net_dict)

    # Split the parameters into the newly added layers and the rest.
    new_parameters = [
        p
        for pname, p in model.named_parameters()
        if any(layer_name in pname for layer_name in new_layer_names)
    ]
    # A set of ids gives O(1) membership tests when filtering below.
    new_parameter_ids = {id(p) for p in new_parameters}
    base_parameters = [p for p in model.parameters() if id(p) not in new_parameter_ids]

    parameters = {'base_parameters': base_parameters, 'new_parameters': new_parameters}
    return model, parameters
# Build the network and obtain its two parameter groups.
model, parameters = generate_model(input_D, input_H, input_W, pretrain_path)

# The 'new_parameters' group is trained with a learning rate 100 times
# larger than the 'base_parameters' group.
params = [
    {'params': parameters['base_parameters'], 'lr': learning_rate},
    {'params': parameters['new_parameters'], 'lr': learning_rate * 100},
]
optimizer = torch.optim.SGD(params, momentum=0.9, weight_decay=1e-3)

# Multiply every group's learning rate by 0.99 on each scheduler step.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
class Dataset(Dataset):
    """Dataset of 3D volumes described by a text list file.

    Every non-empty line of the list file holds three whitespace-separated
    fields: image path, label (mask) path and a numeric class label. Both
    paths are joined onto ``root_dir``.

    Each item is ``(img_array, mask_array, class_array)``: two float32
    arrays of shape ``(1, D, H, W)`` and a float32 tensor of shape ``(1,)``.

    Args:
        root_dir: Directory the listed paths are joined onto.
        img_list: Path of the list file.
        input_D: Target depth after resizing.
        input_H: Target height after resizing.
        input_W: Target width after resizing.
    """

    def __init__(self, root_dir, img_list, input_D, input_H, input_W):
        with open(img_list, 'r') as f:
            # Skip blank lines so that, e.g., a trailing empty line does not
            # become an unparsable sample.
            self.img_list = [line.strip() for line in f if line.strip()]
        print("Processing {} datas".format(len(self.img_list)))
        self.root_dir = root_dir
        self.input_D = input_D
        self.input_H = input_H
        self.input_W = input_W

    def __nii2tensorarray__(self, data):
        """Add a leading channel axis to a 3D array and cast to float32."""
        z, y, x = data.shape
        return np.reshape(data, [1, z, y, x]).astype("float32")

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load, preprocess and return sample ``idx``.

        Raises:
            ValueError: If the list line has fewer than three fields.
            FileNotFoundError: If the image or label file does not exist.
        """
        # Read image path, label path and class label.
        ith_info = self.img_list[idx].split()
        if len(ith_info) < 3:
            raise ValueError(
                "expected '<image> <label> <class>' in list entry {}, got: {!r}".format(idx, self.img_list[idx]))
        img_name = os.path.join(self.root_dir, ith_info[0])
        label_name = os.path.join(self.root_dir, ith_info[1])
        class_array = torch.tensor([float(ith_info[2])], dtype=torch.float32)

        for path in (img_name, label_name):
            if not os.path.isfile(path):
                raise FileNotFoundError("no such file: {}".format(path))

        # NOTE(review): the original comment states the data were transposed
        # from WHD to DHW beforehand; that is not verifiable here.
        img = nibabel.load(img_name)
        mask = nibabel.load(label_name)

        # Data processing.
        img_array, mask_array = self.__training_data_process__(img, mask)

        # Convert to channel-first float32 arrays.
        img_array = self.__nii2tensorarray__(img_array)
        mask_array = self.__nii2tensorarray__(mask_array)
        assert img_array.shape == mask_array.shape, "img shape:{} is not equal to mask shape:{}".format(img_array.shape, mask_array.shape)
        return img_array, mask_array, class_array

    def __drop_invalid_range__(self, volume, label=None):
        """Crop to the bounding box of voxels differing from ``volume[0, 0, 0]``.

        The value of the first voxel is treated as background. If ``label``
        is given it is cropped with the same box and ``(volume, label)`` is
        returned; otherwise only the cropped volume is returned. A volume
        with no differing voxel is returned unchanged.
        """
        zero_value = volume[0, 0, 0]
        non_zeros_idx = np.array(np.where(volume != zero_value))
        if non_zeros_idx.shape[1] == 0:
            # Constant volume: there is no bounding box to crop to.
            return volume if label is None else (volume, label)

        min_z, min_h, min_w = non_zeros_idx.min(axis=1)
        # +1 because slice ends are exclusive; without it the last valid
        # slice along every axis would be cut off.
        max_z, max_h, max_w = non_zeros_idx.max(axis=1) + 1

        if label is not None:
            return volume[min_z:max_z, min_h:max_h, min_w:max_w], label[min_z:max_z, min_h:max_h, min_w:max_w]
        return volume[min_z:max_z, min_h:max_h, min_w:max_w]

    def __random_center_crop__(self, data, label):
        """Randomly crop ``data`` and ``label`` around the positive label region.

        The crop box is drawn at random but always contains the bounding box
        of ``label > 0``. If the label has no positive voxel, both inputs are
        returned unchanged.
        """
        from random import random

        target_indexs = np.array(np.where(label > 0))
        if target_indexs.shape[1] == 0:
            # Nothing to centre the crop on.
            return data, label

        img_d, img_h, img_w = data.shape
        sizes = (img_d, img_h, img_w)
        lows = [int(v) for v in target_indexs.min(axis=1)]
        highs = [int(v) for v in target_indexs.max(axis=1)]
        extents = [hi - lo for lo, hi in zip(lows, highs)]

        # All three starts are drawn before the three stops, in Z, Y, X order.
        starts = [round((lo - ext * 1.0 / 2) * random()) for lo, ext in zip(lows, extents)]
        stops = [
            round(size - ((size - (hi + ext * 1.0 / 2)) * random()))
            for size, hi, ext in zip(sizes, highs, extents)
        ]

        # Clamp to the volume and make sure the target box stays inside the
        # crop (this also guarantees a non-empty crop for tiny targets).
        starts = [min(max(0, start), lo) for start, lo in zip(starts, lows)]
        stops = [max(min(size, stop), hi + 1) for stop, size, hi in zip(stops, sizes, highs)]

        box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
        return data[box], label[box]

    def __itensity_normalize_one_volume__(self, volume):
        """Normalize intensities with the mean/std of the positive voxels.

        Voxels equal to zero are replaced by samples from a standard normal
        distribution. If there is no positive voxel the mean is taken as 0,
        and if the standard deviation is 0 (or undefined) it is taken as 1,
        so the result never contains NaN or inf caused by the statistics.

        Args:
            volume: The input nd volume.

        Returns:
            The normalized nd volume.
        """
        pixels = volume[volume > 0]
        if pixels.size == 0:
            mean, std = 0.0, 1.0
        else:
            mean = pixels.mean()
            std = pixels.std()
            if std == 0:
                std = 1.0
        out = (volume - mean) / std
        out_random = np.random.normal(0, 1, size=volume.shape)
        background = volume == 0
        out[background] = out_random[background]
        return out

    def __resize_data__(self, data):
        """Resize a 3D array to the configured input size (order-0 zoom)."""
        depth, height, width = data.shape
        scale = [self.input_D * 1.0 / depth, self.input_H * 1.0 / height, self.input_W * 1.0 / width]
        return ndimage.zoom(data, scale, order=0)

    def __crop_data__(self, data, label):
        """Crop data and label; currently always a random center crop."""
        return self.__random_center_crop__(data, label)

    def __training_data_process__(self, data, label):
        """Run the preprocessing pipeline on a loaded image/label pair.

        Steps: read the arrays, drop the invalid border, random crop, resize
        both arrays to the input size, then normalize the image intensities.
        """
        data = data.get_fdata()
        label = label.get_fdata()

        # Drop out the invalid range.
        data, label = self.__drop_invalid_range__(data, label)
        # Crop data.
        data, label = self.__crop_data__(data, label)
        # Resize data.
        data = self.__resize_data__(data)
        label = self.__resize_data__(label)
        # Normalize the image only; the label keeps its values.
        data = self.__itensity_normalize_one_volume__(data)
        return data, label
# Training dataset and a shuffling loader with pinned host memory.
training_dataset = Dataset(
    root_dir=root_dir,
    img_list=img_list,
    input_D=input_D,
    input_H=input_H,
    input_W=input_W,
)
data_loader = DataLoader(
    training_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)

# Root logger at DEBUG level with timestamp, level and source location.
logging.basicConfig(
    format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.DEBUG,
)
log = logging.getLogger()
def train(data_loader, model, optimizer, scheduler, total_epochs, save_interval, save_folder):
    """Train ``model`` with binary cross-entropy on the class labels.

    Args:
        data_loader: Iterable yielding ``(volumes, label_masks, class_array)``
            batches; ``label_masks`` is not used for the loss.
        model: Network whose output is compared to ``class_array`` with
            ``nn.BCELoss``. Batches are moved to the device of its parameters.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler, also stepped once per batch.
        total_epochs: Number of epochs to run.
        save_interval: A checkpoint is written at the first batch of an epoch
            (other than epoch 0) whose global step is a multiple of this.
        save_folder: Prefix of the checkpoint path; the file name is
            ``'<save_folder>_epoch_<e>_batch_<b>.pth.tar'``.
    """
    batches_per_epoch = len(data_loader)
    log.info('{} epochs in total, {} batches per epoch'.format(total_epochs, batches_per_epoch))

    # Follow the model's device instead of hard-coding CUDA; for a model that
    # already lives on the GPU this is equivalent to calling .cuda().
    device = next(model.parameters()).device
    loss_seg = nn.BCELoss().to(device)

    model.train()
    train_time_sp = time.time()
    for epoch in range(total_epochs):
        log.info('Start epoch {}'.format(epoch))
        # get_last_lr() reports the rates currently in use; get_lr() is only
        # meant to be called from inside scheduler.step().
        log.info('lr = {}'.format(scheduler.get_last_lr()))

        for batch_id, batch_data in enumerate(data_loader):
            # Global step: batches finished in earlier epochs plus this one.
            batch_id_sp = epoch * batches_per_epoch + batch_id
            volumes, _label_masks, class_array = batch_data
            volumes = volumes.to(device)
            class_array = class_array.to(device)

            optimizer.zero_grad()
            out_masks = model(volumes)
            log.debug('input shape: %s', tuple(volumes.shape))

            # Calculating loss. The loss already carries gradient history
            # through the model, so no requires_grad_ call is needed.
            loss_value_seg = loss_seg(out_masks, class_array)
            loss = loss_value_seg
            loss.backward()
            optimizer.step()
            # NOTE: the scheduler is stepped per batch, so the learning rate
            # decays with the number of batches, not the number of epochs.
            scheduler.step()

            avg_batch_time = (time.time() - train_time_sp) / (1 + batch_id_sp)
            log.info('Batch: {}-{} ({}), loss = {:.3f}, loss_seg = {:.3f}, avg_batch_time = {:.3f}'
                     .format(epoch, batch_id, batch_id_sp, loss.item(), loss_value_seg.item(), avg_batch_time))

            # Save model.
            if batch_id == 0 and batch_id_sp != 0 and batch_id_sp % save_interval == 0:
                model_save_path = '{}_epoch_{}_batch_{}.pth.tar'.format(save_folder, epoch, batch_id)
                model_save_dir = os.path.dirname(model_save_path)
                # dirname is '' when save_folder has no directory part.
                if model_save_dir:
                    os.makedirs(model_save_dir, exist_ok=True)

                log.info('Save checkpoints: epoch = {}, batch_id = {}'.format(epoch, batch_id))
                torch.save(
                    {
                        'epoch': epoch,
                        # Misspelled key kept so existing checkpoint readers
                        # that look up 'ecpoch' keep working.
                        'ecpoch': epoch,
                        'batch_id': batch_id,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    },
                    model_save_path,
                )

    print('Finished training')
# Run training with the objects and settings defined above.
train(
    data_loader=data_loader,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    total_epochs=total_epochs,
    save_interval=save_intervals,
    save_folder=save_folder,
)