img_list 文件格式如下（每行一个样本，以空格分隔：图像路径 掩膜路径 类别标签）：
E:\...\3.nrrd E:\...\3.nrrd 0
E:\...\4.nrrd E:\...\4.nrrd 1
训练代码
import torch
from torch import nn
import os
import numpy as np
from torch.utils.data import Dataset
from scipy import ndimage
from torch import optim
from torch.utils.data import DataLoader
import time
import logging
import nrrd
# Training configuration. The trailing "type=..., help=..." notes look like
# leftovers from argparse options; the values are hard-coded here.
img_list = 'data/train.txt' # type=str, help='Path for image list file' (one "image_path mask_path label" per line)
pretrain_path = 'pretrain/resnet_50.pth' # type=str, help='Path for pretrained model.'
save_folder = "trails/models/Resnet50" # path prefix: checkpoints are written as <save_folder>_epoch_<e>_batch_<b>.pth.tar
total_epochs = 20 # type=int, help='Number of total epochs to run'
save_intervals = 10 # type=int, help='Interval (in global batch index) for saving the model'; only tested at the first batch of each epoch
learning_rate = 0.001 # initial lr of the pretrained backbone; layers in new_layer_names use learning_rate*100, and every group is multiplied by 0.99 per ExponentialLR step
new_layer_names = ['conv_cls'] # type=list, help='New layer except for backbone' (substrings of parameter names that get the larger learning rate)
batch_size = 1 # type=int, help='Batch Size'
input_D = 56 # type=int, help='Input size of depth' (volumes are resized to this depth)
input_H = 448 # type=int, help='Input size of height'
input_W = 448 # type=int, help='Input size of width'
# Seeds torch only; Python's `random` and NumPy's global RNGs (used by the
# dataset's random crop and background noise) are not seeded here.
torch.manual_seed(1)
class Bottleneck(nn.Module):
    """3-D ResNet bottleneck block.

    Main path: 1x1x1 conv (reduce to ``planes``) -> 3x3x3 conv (optionally
    strided/dilated) -> 1x1x1 conv (expand to ``planes * expansion``), each
    followed by BatchNorm. The shortcut is the input itself, or
    ``downsample(input)`` when a projection is supplied. The sum goes through a
    final ReLU.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        width_out = planes * self.expansion
        # Attribute names must stay as-is so pretrained state_dict keys match;
        # submodules are created in the same order to keep random init identical.
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(width_out)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y += identity
        return self.relu(y)
class ResNet(nn.Module):
    """3-D ResNet with a single sigmoid output (binary classification).

    Args:
        block: residual block class exposing an ``expansion`` attribute.
        layers: number of blocks in each of the four stages.
        input_D, input_H, input_W: accepted for interface compatibility; not
            used (the head pools adaptively, so any input size works).

    Resolution: conv1 (stride 2) and maxpool (stride 2) each halve D/H/W, and
    layer2 halves them again. layer3 and layer4 keep the resolution and grow the
    receptive field with dilation 2 and 4 instead.
    """

    def __init__(self, block, layers, input_D, input_H, input_W):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
        # (planes, stride, dilation) per stage; stage output channels are
        # planes * block.expansion (256, 512, 1024, 2048 for Bottleneck).
        stage_cfg = ((64, 1, 1), (128, 2, 1), (256, 1, 2), (512, 1, 4))
        for idx, (planes, stride, dilation) in enumerate(stage_cfg):
            stage = self._make_layer(block, planes, layers[idx], stride=stride, dilation=dilation)
            setattr(self, 'layer{}'.format(idx + 1), stage)
        self.conv_cls = nn.Sequential(
            nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
            nn.Flatten(start_dim=1),
            nn.Dropout(0.1),
            nn.Linear(512 * block.expansion, 1),
        )
        for module in self.modules():
            if isinstance(module, nn.Conv3d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
            elif isinstance(module, nn.BatchNorm3d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build a stage of ``blocks`` residual blocks and advance ``self.inplanes``."""
        out_channels = planes * block.expansion
        downsample = None
        # The first block needs a projection shortcut whenever it changes the shape.
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_channels),
            )
        stage = [block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample)]
        # Later blocks (and the next stage) receive the expanded channel count.
        self.inplanes = out_channels
        stage.extend(block(self.inplanes, planes, dilation=dilation) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        logits = self.conv_cls(x)
        return torch.sigmoid_(logits)
def generate_model(input_D, input_H, input_W, pretrain_path):
    """Build a 3-D ResNet-50 on the GPU and initialise it from a checkpoint.

    Only checkpoint tensors whose names (after removing any ``module.``
    prefix) exist in the model are copied. Everything else keeps its random
    initialisation, and the number of matched tensors is printed so that a
    wrong checkpoint does not silently train from scratch.

    Args:
        input_D, input_H, input_W: forwarded to ``ResNet``.
        pretrain_path: path of a checkpoint dict with a ``'state_dict'`` entry.

    Returns:
        ``(model, parameters)``. ``parameters['new_parameters']`` holds the
        parameters whose names contain an entry of the module-level
        ``new_layer_names``, and ``parameters['base_parameters']`` holds all
        others, so the caller can use different learning rates.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], input_W=input_W, input_H=input_H, input_D=input_D)
    model = model.cuda()
    net_dict = model.state_dict()
    print('loading pretrained model {}'.format(pretrain_path))
    pretrain = torch.load(pretrain_path)
    pretrain_dict = {}
    for key, value in pretrain['state_dict'].items():
        name = key.replace("module.", "")
        if name in net_dict:
            pretrain_dict[name] = value
    # Surface a key mismatch (e.g. a checkpoint for another architecture).
    print('matched {} of {} model tensors from the pretrained checkpoint'.format(len(pretrain_dict), len(net_dict)))
    if not pretrain_dict:
        print('WARNING: no pretrained weights were loaded; training starts from random initialisation')
    net_dict.update(pretrain_dict)
    model.load_state_dict(net_dict)
    new_parameters = [
        p for pname, p in model.named_parameters()
        if any(layer_name in pname for layer_name in new_layer_names)
    ]
    # Set of ids gives O(1) membership instead of scanning a list per parameter.
    new_parameter_ids = {id(p) for p in new_parameters}
    base_parameters = [p for p in model.parameters() if id(p) not in new_parameter_ids]
    parameters = {'base_parameters': base_parameters, 'new_parameters': new_parameters}
    return model, parameters
# Build the pretrained model and split its parameters into backbone vs. new head.
model, parameters = generate_model(input_D, input_H, input_W, pretrain_path)
# The newly added layers train with a 100x larger learning rate than the backbone.
params = [{'params': parameters['base_parameters'], 'lr': learning_rate },{'params': parameters['new_parameters'], 'lr': learning_rate*100}]
optimizer = torch.optim.SGD(params, momentum=0.9, weight_decay=1e-3)
# Every scheduler.step() call multiplies each group's lr by 0.99.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
class Dataset(Dataset):
    """Classification dataset of paired NRRD image/mask volumes.

    Each non-empty line of ``img_list`` holds three whitespace-separated fields:
    image path, mask path and class label. ``__getitem__`` returns
    ``(image, mask, label)``: image and mask are float32 arrays of shape
    (1, input_D, input_H, input_W), and the label is a float32 tensor of shape (1,).

    NOTE: the class deliberately keeps the name ``Dataset`` (shadowing
    ``torch.utils.data.Dataset``) because the script instantiates it by that name.
    """

    def __init__(self, img_list, input_D, input_H, input_W):
        with open(img_list, 'r') as f:
            # Skip blank lines (e.g. a trailing newline) so __len__ only counts parseable entries.
            self.img_list = [line.strip() for line in f if line.strip()]
        print("Processing {} datas".format(len(self.img_list)))
        self.input_D = input_D
        self.input_H = input_H
        self.input_W = input_W

    def __nii2tensorarray__(self, data):
        """Add a leading channel axis and cast to float32: (z, y, x) -> (1, z, y, x)."""
        [z, y, x] = data.shape
        new_data = np.reshape(data, [1, z, y, x])
        new_data = new_data.astype("float32")
        return new_data

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load, crop, resize and normalise sample ``idx``."""
        # Fields: image path, mask path, class label. split() accepts any run of
        # whitespace, unlike split(" ") which yields empty fields for double spaces/tabs.
        ith_info = self.img_list[idx].split()
        img_name = ith_info[0]
        label_name = ith_info[1]
        class_array = torch.tensor([float(ith_info[2])], dtype=torch.float32)
        # Explicit checks instead of asserts, which are stripped under `python -O`.
        if not os.path.isfile(img_name):
            raise FileNotFoundError(img_name)
        if not os.path.isfile(label_name):
            raise FileNotFoundError(label_name)
        img, _ = nrrd.read(img_name, index_order='C')  # transposed the data from WHD format to DHW
        assert img is not None
        mask, _ = nrrd.read(label_name, index_order='C')
        assert mask is not None
        img_array, mask_array = self.__training_data_process__(img, mask)
        img_array = self.__nii2tensorarray__(img_array)
        mask_array = self.__nii2tensorarray__(mask_array)
        assert img_array.shape == mask_array.shape, "img shape:{} is not equal to mask shape:{}".format(img_array.shape, mask_array.shape)
        return img_array, mask_array, class_array

    def __drop_invalid_range__(self, volume, label=None):
        """Crop to the bounding box of voxels that differ from ``volume[0, 0, 0]``.

        The corner voxel is taken as the background value. The box is inclusive
        of its last voxel on every axis. A constant volume is returned unchanged.
        """
        zero_value = volume[0, 0, 0]
        non_zeros_idx = np.where(volume != zero_value)
        if non_zeros_idx[0].size == 0:
            # Nothing but background: np.max on an empty array would raise.
            return (volume, label) if label is not None else volume
        coords = np.array(non_zeros_idx)
        # +1 turns the inclusive max index into an exclusive slice stop.
        [max_z, max_h, max_w] = coords.max(axis=1) + 1
        [min_z, min_h, min_w] = coords.min(axis=1)
        region = (slice(min_z, max_z), slice(min_h, max_h), slice(min_w, max_w))
        if label is not None:
            return volume[region], label[region]
        return volume[region]

    def __random_center_crop__(self, data, label):
        """Randomly crop around the foreground (``label > 0``) bounding box.

        On each axis the crop start is drawn uniformly between 0 and
        ``min - extent/2``. The stop is drawn between ``max + extent/2`` and
        the volume size, and is never placed before the last foreground index.
        A mask without foreground leaves the data uncropped.
        """
        from random import random
        target_indexs = np.where(label > 0)
        if target_indexs[0].size == 0:
            return data, label
        shape = data.shape
        coords = np.array(target_indexs)
        lo = coords.min(axis=1)
        hi = coords.max(axis=1)
        extent = hi - lo
        # Draw order (three starts, then three stops) matches the original implementation.
        starts = [round((lo[a] - extent[a] * 1.0 / 2) * random()) for a in range(3)]
        stops = [round(shape[a] - ((shape[a] - (hi[a] + extent[a] * 1.0 / 2)) * random())) for a in range(3)]
        region = []
        for a in range(3):
            start = max(0, int(starts[a]))
            stop = min(int(shape[a]), int(stops[a]))
            # Rounding can put the stop on the last foreground index, which would exclude it.
            stop = max(stop, int(hi[a]) + 1)
            region.append(slice(start, stop))
        region = tuple(region)
        return data[region], label[region]

    def __itensity_normalize_one_volume__(self, volume):
        """Z-score normalise with the statistics of the positive voxels.

        Voxels equal to 0 are replaced by N(0, 1) noise. If no voxel is positive,
        whole-volume statistics are used. A zero std is treated as 1 so the
        result never contains NaN/inf.
        """
        pixels = volume[volume > 0]
        if pixels.size == 0:
            pixels = volume
        mean = pixels.mean()
        std = pixels.std()
        if std == 0:
            std = 1.0
        out = (volume - mean) / std
        out_random = np.random.normal(0, 1, size=volume.shape)
        out[volume == 0] = out_random[volume == 0]
        return out

    def __resize_data__(self, data):
        """Resize ``data`` to (input_D, input_H, input_W) with nearest-neighbour (order=0) zoom."""
        [depth, height, width] = data.shape
        scale = [self.input_D * 1.0 / depth, self.input_H * 1.0 / height, self.input_W * 1.0 / width]
        data = ndimage.zoom(data, scale, order=0)
        return data

    def __crop_data__(self, data, label):
        """Crop ``data`` and ``label`` together (currently only the random centre crop)."""
        data, label = self.__random_center_crop__(data, label)
        return data, label

    def __training_data_process__(self, data, label):
        """Apply background drop -> random crop -> resize -> intensity normalisation (image only)."""
        data, label = self.__drop_invalid_range__(data, label)
        data, label = self.__crop_data__(data, label)
        data = self.__resize_data__(data)
        label = self.__resize_data__(label)
        data = self.__itensity_normalize_one_volume__(data)
        return data, label
# Training data: samples are reshuffled every epoch; pin_memory=True returns
# batches in page-locked host memory (faster .cuda() copies per PyTorch docs).
training_dataset = Dataset(img_list, input_D, input_H, input_W)
data_loader = DataLoader(training_dataset, batch_size, shuffle=True, pin_memory=True)
# Root logger at DEBUG level; train() writes its progress through `log`.
logging.basicConfig(format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s', datefmt='%Y-%m-%d %H:%M:%S', level=logging.DEBUG)
log = logging.getLogger()
def train(data_loader, model, optimizer, scheduler, total_epochs, save_interval, save_folder):
    """Train ``model`` with binary cross-entropy on the per-volume class label.

    Args:
        data_loader: iterable of ``(volumes, label_masks, class_array)`` batches;
            ``label_masks`` is not used.
        model: network returning probabilities (it is fed to ``nn.BCELoss``);
            batches are moved to the GPU with ``.cuda()``.
        optimizer: optimizer over the model's parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        total_epochs: number of passes over ``data_loader``.
        save_interval: a checkpoint is written at the first batch of an epoch
            whose global batch index is a non-zero multiple of this value.
        save_folder: path prefix for checkpoint files.

    Returns:
        ``(class_arrays, probs)``: flat lists of labels and predicted
        probabilities from the last epoch (collected in train mode).
    """
    batches_per_epoch = len(data_loader)
    log.info('{} epochs in total, {} batches per epoch'.format(total_epochs, batches_per_epoch))
    loss_seg = nn.BCELoss()
    loss_seg = loss_seg.cuda()
    model.train()
    train_time_sp = time.time()
    probs = []
    class_arrays = []
    for epoch in range(total_epochs):
        log.info('Start epoch {}'.format(epoch))
        # get_last_lr() is the lr in use; get_lr() is internal and, for
        # ExponentialLR, returns lr * gamma once the scheduler has stepped.
        log.info('lr = {}'.format(scheduler.get_last_lr()))
        probs = []
        class_arrays = []
        for batch_id, batch_data in enumerate(data_loader):
            # Global index of this batch across all epochs.
            batch_id_sp = epoch * batches_per_epoch + batch_id
            volumes, label_masks, class_array = batch_data
            volumes = volumes.cuda()
            class_array = class_array.cuda()
            optimizer.zero_grad()
            prob = model(volumes)
            loss_value_seg = loss_seg(prob, class_array)
            loss = loss_value_seg
            # No requires_grad_(True) here: the loss already tracks gradients, and
            # forcing it would hide a detached graph by making backward a no-op.
            loss.backward()
            optimizer.step()
            avg_batch_time = (time.time() - train_time_sp) / (1 + batch_id_sp)
            log.info('Batch: {}-{} ({}), loss = {:.3f}, loss_seg = {:.3f}, avg_batch_time = {:.3f}'
                     .format(epoch, batch_id, batch_id_sp, loss.item(), loss_value_seg.item(), avg_batch_time))
            if batch_id == 0 and batch_id_sp != 0 and batch_id_sp % save_interval == 0:
                model_save_path = '{}_epoch_{}_batch_{}.pth.tar'.format(save_folder, epoch, batch_id)
                model_save_dir = os.path.dirname(model_save_path)
                if model_save_dir:
                    os.makedirs(model_save_dir, exist_ok=True)
                log.info('Save checkpoints: epoch = {}, batch_id = {}'.format(epoch, batch_id))
                # NOTE: the misspelled 'ecpoch' key is kept for compatibility with existing checkpoints.
                torch.save({'ecpoch': epoch, 'batch_id': batch_id, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, model_save_path)
            # Flatten so batch_size > 1 works (.item() requires a single element).
            probs.extend(prob.detach().reshape(-1).tolist())
            class_arrays.extend(class_array.reshape(-1).tolist())
        # Epoch-level schedule: stepping per batch would decay the lr by
        # gamma ** len(data_loader) every epoch and drive it to ~0.
        scheduler.step()
    print('Finished training')
    return class_arrays, probs
# Run training. The returned (labels, probabilities) of the last epoch are discarded.
train(data_loader, model, optimizer, scheduler, total_epochs, save_interval=save_intervals, save_folder=save_folder)
验证代码
import torch
from torch import nn
import os
import numpy as np
from torch.utils.data import Dataset
import nrrd
from scipy import ndimage
from torch.utils.data import DataLoader
from sklearn.metrics import roc_curve, auc
# Evaluation configuration.
img_list = 'data/val.txt' # type=str, help='Path for image list file' (one "image_path mask_path label" per line)
resume_path = 'trails/models/best.tar' # checkpoint dict with a 'state_dict' entry; NOTE(review): the training script writes '<save_folder>_epoch_<e>_batch_<b>.pth.tar', so this is presumably a renamed copy of one of those
batch_size = 1 # type=int, help='Batch Size'
input_D = 56 # type=int, help='Input size of depth' (should match the training configuration)
input_H = 448 # type=int, help='Input size of height'
input_W = 448 # type=int, help='Input size of width'
# Seeds torch only; Python's `random` and NumPy's global RNGs are not seeded here.
torch.manual_seed(1)
class Bottleneck(nn.Module):
    """3-D ResNet bottleneck block.

    Main path: 1x1x1 conv (reduce to ``planes``) -> 3x3x3 conv (optionally
    strided/dilated) -> 1x1x1 conv (expand to ``planes * expansion``), each
    followed by BatchNorm. The shortcut is the input itself, or
    ``downsample(input)`` when a projection is supplied. The sum goes through a
    final ReLU.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        width_out = planes * self.expansion
        # Attribute names must stay as-is so checkpoint state_dict keys match;
        # submodules are created in the same order to keep random init identical.
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(width_out)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y += identity
        return self.relu(y)
class ResNet(nn.Module):
    """3-D ResNet with a single sigmoid output (binary classification).

    Args:
        block: residual block class exposing an ``expansion`` attribute.
        layers: number of blocks in each of the four stages.
        input_D, input_H, input_W: accepted for interface compatibility; not
            used (the head pools adaptively, so any input size works).

    Resolution: conv1 (stride 2) and maxpool (stride 2) each halve D/H/W, and
    layer2 halves them again. layer3 and layer4 keep the resolution and grow the
    receptive field with dilation 2 and 4 instead.
    """

    def __init__(self, block, layers, input_D, input_H, input_W):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
        # (planes, stride, dilation) per stage; stage output channels are
        # planes * block.expansion (256, 512, 1024, 2048 for Bottleneck).
        stage_cfg = ((64, 1, 1), (128, 2, 1), (256, 1, 2), (512, 1, 4))
        for idx, (planes, stride, dilation) in enumerate(stage_cfg):
            stage = self._make_layer(block, planes, layers[idx], stride=stride, dilation=dilation)
            setattr(self, 'layer{}'.format(idx + 1), stage)
        self.conv_cls = nn.Sequential(
            nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
            nn.Flatten(start_dim=1),
            nn.Dropout(0.1),
            nn.Linear(512 * block.expansion, 1),
        )
        for module in self.modules():
            if isinstance(module, nn.Conv3d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
            elif isinstance(module, nn.BatchNorm3d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build a stage of ``blocks`` residual blocks and advance ``self.inplanes``."""
        out_channels = planes * block.expansion
        downsample = None
        # The first block needs a projection shortcut whenever it changes the shape.
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_channels),
            )
        stage = [block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample)]
        # Later blocks (and the next stage) receive the expanded channel count.
        self.inplanes = out_channels
        stage.extend(block(self.inplanes, planes, dilation=dilation) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        logits = self.conv_cls(x)
        return torch.sigmoid_(logits)
class Dataset(Dataset):
    """Evaluation dataset of paired NRRD image/mask volumes.

    Each non-empty line of ``img_list`` holds three whitespace-separated fields:
    image path, mask path and class label. ``__getitem__`` returns
    ``(image, mask, label)``: image and mask are float32 arrays of shape
    (1, input_D, input_H, input_W), and the label is a float32 tensor of shape (1,).

    The preprocessing includes a random crop and random background noise. To keep
    the evaluation reproducible, sample ``idx`` draws from its own
    ``np.random.RandomState(seed + idx)``. Pass ``seed=None`` to fall back to the
    unseeded global RNGs.

    NOTE: the class deliberately keeps the name ``Dataset`` (shadowing
    ``torch.utils.data.Dataset``) because the script instantiates it by that name.
    """

    def __init__(self, img_list, input_D, input_H, input_W, seed=0):
        with open(img_list, 'r') as f:
            # Skip blank lines (e.g. a trailing newline) so __len__ only counts parseable entries.
            self.img_list = [line.strip() for line in f if line.strip()]
        print("Processing {} datas".format(len(self.img_list)))
        self.input_D = input_D
        self.input_H = input_H
        self.input_W = input_W
        self.seed = seed

    def __nii2tensorarray__(self, data):
        """Add a leading channel axis and cast to float32: (z, y, x) -> (1, z, y, x)."""
        [z, y, x] = data.shape
        new_data = np.reshape(data, [1, z, y, x])
        new_data = new_data.astype("float32")
        return new_data

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load, crop, resize and normalise sample ``idx`` (reproducibly when ``seed`` is set)."""
        # Fields: image path, mask path, class label; split() accepts any whitespace run.
        ith_info = self.img_list[idx].split()
        img_name = ith_info[0]
        label_name = ith_info[1]
        class_array = torch.tensor([float(ith_info[2])], dtype=torch.float32)
        # Explicit checks instead of asserts, which are stripped under `python -O`.
        if not os.path.isfile(img_name):
            raise FileNotFoundError(img_name)
        if not os.path.isfile(label_name):
            raise FileNotFoundError(label_name)
        img, _ = nrrd.read(img_name, index_order='C')  # transposed the data from WHD format to DHW
        assert img is not None
        mask, _ = nrrd.read(label_name, index_order='C')
        assert mask is not None
        rng = None if self.seed is None else np.random.RandomState(self.seed + idx)
        img_array, mask_array = self.__training_data_process__(img, mask, rng)
        img_array = self.__nii2tensorarray__(img_array)
        mask_array = self.__nii2tensorarray__(mask_array)
        assert img_array.shape == mask_array.shape, "img shape:{} is not equal to mask shape:{}".format(img_array.shape, mask_array.shape)
        return img_array, mask_array, class_array

    def __drop_invalid_range__(self, volume, label=None):
        """Crop to the bounding box of voxels that differ from ``volume[0, 0, 0]``.

        The corner voxel is taken as the background value. The box is inclusive
        of its last voxel on every axis. A constant volume is returned unchanged.
        """
        zero_value = volume[0, 0, 0]
        non_zeros_idx = np.where(volume != zero_value)
        if non_zeros_idx[0].size == 0:
            # Nothing but background: np.max on an empty array would raise.
            return (volume, label) if label is not None else volume
        coords = np.array(non_zeros_idx)
        # +1 turns the inclusive max index into an exclusive slice stop.
        [max_z, max_h, max_w] = coords.max(axis=1) + 1
        [min_z, min_h, min_w] = coords.min(axis=1)
        region = (slice(min_z, max_z), slice(min_h, max_h), slice(min_w, max_w))
        if label is not None:
            return volume[region], label[region]
        return volume[region]

    def __random_center_crop__(self, data, label, rng=None):
        """Randomly crop around the foreground (``label > 0``) bounding box.

        On each axis the crop start is drawn uniformly between 0 and
        ``min - extent/2``. The stop is drawn between ``max + extent/2`` and
        the volume size, and is never placed before the last foreground index.
        Uniform draws come from ``rng`` when given, else from ``random.random``.
        A mask without foreground leaves the data uncropped.
        """
        if rng is not None:
            draw = rng.random_sample
        else:
            from random import random as draw
        target_indexs = np.where(label > 0)
        if target_indexs[0].size == 0:
            return data, label
        shape = data.shape
        coords = np.array(target_indexs)
        lo = coords.min(axis=1)
        hi = coords.max(axis=1)
        extent = hi - lo
        starts = [round((lo[a] - extent[a] * 1.0 / 2) * draw()) for a in range(3)]
        stops = [round(shape[a] - ((shape[a] - (hi[a] + extent[a] * 1.0 / 2)) * draw())) for a in range(3)]
        region = []
        for a in range(3):
            start = max(0, int(starts[a]))
            stop = min(int(shape[a]), int(stops[a]))
            # Rounding can put the stop on the last foreground index, which would exclude it.
            stop = max(stop, int(hi[a]) + 1)
            region.append(slice(start, stop))
        region = tuple(region)
        return data[region], label[region]

    def __itensity_normalize_one_volume__(self, volume, rng=None):
        """Z-score normalise with the statistics of the positive voxels.

        Voxels equal to 0 are replaced by N(0, 1) noise drawn from ``rng`` (or
        NumPy's global RNG). If no voxel is positive, whole-volume statistics are
        used. A zero std is treated as 1 so the result never contains NaN/inf.
        """
        pixels = volume[volume > 0]
        if pixels.size == 0:
            pixels = volume
        mean = pixels.mean()
        std = pixels.std()
        if std == 0:
            std = 1.0
        out = (volume - mean) / std
        normal = rng.normal if rng is not None else np.random.normal
        out_random = normal(0, 1, size=volume.shape)
        out[volume == 0] = out_random[volume == 0]
        return out

    def __resize_data__(self, data):
        """Resize ``data`` to (input_D, input_H, input_W) with nearest-neighbour (order=0) zoom."""
        [depth, height, width] = data.shape
        scale = [self.input_D * 1.0 / depth, self.input_H * 1.0 / height, self.input_W * 1.0 / width]
        data = ndimage.zoom(data, scale, order=0)
        return data

    def __crop_data__(self, data, label, rng=None):
        """Crop ``data`` and ``label`` together (currently only the random centre crop)."""
        data, label = self.__random_center_crop__(data, label, rng)
        return data, label

    def __training_data_process__(self, data, label, rng=None):
        """Apply background drop -> random crop -> resize -> intensity normalisation (image only)."""
        data, label = self.__drop_invalid_range__(data, label)
        data, label = self.__crop_data__(data, label, rng)
        data = self.__resize_data__(data)
        label = self.__resize_data__(label)
        data = self.__itensity_normalize_one_volume__(data, rng)
        return data, label
def test(data_loader, model):
probs = []
class_arrays = []
model.eval() # for testing
for batch_id, batch_data in enumerate(data_loader):
# forward
volume, label_mask, class_array = batch_data #####
volume = volume.cuda()
with torch.no_grad():
prob = model(volume)
probs.append(prob.cpu().item())
class_arrays.append(class_array.cpu().item())
return class_arrays, probs
# Rebuild the training architecture. Keyword arguments avoid the previous
# positional W/H/D order, which did not match ResNet(input_D, input_H, input_W).
net = ResNet(Bottleneck, [3, 4, 6, 3], input_D=input_D, input_H=input_H, input_W=input_W)
net = net.cuda()
checkpoint = torch.load(resume_path)
net.load_state_dict(checkpoint['state_dict'])
testing_data = Dataset(img_list, input_D, input_H, input_W)
data_loader = DataLoader(testing_data, batch_size, shuffle=False, pin_memory=False)
class_arrays, probs = test(data_loader, net)
# ROC AUC over all validation samples (requires both classes to be present).
fpr, tpr, thresholds = roc_curve(class_arrays, probs)
roc_auc = auc(fpr, tpr)
print(roc_auc)