[pytorch] Resnet3D预训练网络 + MedMNIST 3D医学数据分类

[pytorch] MedMNIST 3D医学数据分类

  • MedMNIST数据集
  • OrganMNIST3D 多分类任务
    • 加载库
    • 加载数据
    • 使用Resnet3D预训练网络
    • train
    • 结果
  • VesselMNIST3D 二分类任务

MedMNIST数据集

医学数据集往往比较难找,公开的 3D 数据集更少。MedMNIST v2 是一个大规模的、类似 MNIST 的标准化生物医学图像集合,包括 12 个 2D 数据集和 6 个 3D 数据集。所有图像都被预处理成 28 x 28 (2D) 或 28 x 28 x 28 (3D),并带有相应的分类标签,因此用户不需要医学背景知识即可使用。MedMNIST v2 涵盖生物医学图像中的主要数据模态,数据规模从 100 到 100,000 不等,任务类型包括二分类/多分类、序数回归和多标签,旨在对轻量级 2D 和 3D 图像执行分类。
我们可以使用它来测试我们的3d网络等等。
数据集介绍:MedMNIST v2: A Large-Scale Lightweight Benchmark for 2D and 3D Biomedical Image Classification
github:MedMNIST
我们这里分析两个3d数据集 OrganMNIST3D 和 VesselMNIST3D,分别实现多分类和二分类。
[pytorch] Resnet3D预训练网络 + MedMNIST 3D医学数据分类_第1张图片
[pytorch] Resnet3D预训练网络 + MedMNIST 3D医学数据分类_第2张图片

OrganMNIST3D 多分类任务

加载库

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms

import medmnist
from medmnist import INFO, Evaluator


import os
import time
import torch.nn as nn
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.models as models
import torchsummary
import time
from torch.optim.lr_scheduler import ExponentialLR

使用tensorboard记录结果

from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer used by the training loops below; event files go to ./logs/.
summaryWriter = SummaryWriter("./logs/")

加载数据

# Mini-batch size shared by the train/val/test DataLoaders.
batch_size = 256

数据处理

class Transform3D:
    """Intensity-scaling transform for a voxel array.

    ``mul`` selects the scaling mode:

    * ``'0.5'``    -- multiply every voxel by the constant 0.5;
    * ``'random'`` -- multiply every voxel by one factor drawn from
      ``np.random.uniform()`` on each call;
    * anything else (including the default ``None``) -- no scaling.

    In every mode the result is returned as ``float32``.
    """

    def __init__(self, mul=None):
        self.mul = mul

    def __call__(self, voxel):
        mode = self.mul

        if mode == 'random':
            scale = np.random.uniform()
        elif mode == '0.5':
            scale = 0.5
        else:
            # Unknown / disabled mode: only the dtype conversion is applied.
            return voxel.astype(np.float32)

        scaled = voxel * scale
        return scaled.astype(np.float32)

下载数据

print('==> Preparing data...')
# Training volumes get a random intensity scale; val/test volumes a fixed 0.5 scale.
train_transform = Transform3D(mul='random')
eval_transform = Transform3D(mul='0.5')

data_flag = 'organmnist3d' # Multi-Class (11)
download = True

# Look up the dataset class by the name recorded in medmnist's INFO table.
info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])

# Build the train / val / test splits in that order, each with its transform.
train_dataset, val_dataset, test_dataset = (
    DataClass(split=split_name, transform=split_transform, download=download)
    for split_name, split_transform in (
        ('train', train_transform),
        ('val', eval_transform),
        ('test', eval_transform),
    )
)

3d数据可视化函数

def draw_oct(volume, type_volume = 'np',canal_first = False):
    """Show the three central orthogonal slices of a volume side by side.

    One slice is taken at the midpoint of each of the three spatial axes.
    Each slice is squeezed, converted to a PIL image and drawn in one 1x3
    matplotlib figure, titled with the image size and with the axes hidden.

    Args:
        volume: Volume with three spatial axes plus one channel axis, given
            as a numpy array or a torch tensor (see ``type_volume``).  The
            channel axis is removed with ``np.squeeze`` before plotting.
        type_volume: ``'np'`` when ``volume`` is a numpy array, ``'tensor'``
            when it is a torch tensor.  Any other value draws nothing.
        canal_first: ``True`` when the channel axis is axis 0 and the
            spatial axes are 1-3; ``False`` when the spatial axes are 0-2
            and the channel axis comes last.
    """
    if type_volume not in ('np', 'tensor'):
        # Unknown container type: keep the original silent no-op.
        return

    # Printed for every supported combination (the original skipped it for
    # tensor + canal_first=True only, which looked like a copy/paste slip).
    print("taille du volume = %s (%s)"%(volume.shape,type_volume))

    if type_volume == 'tensor':
        # detach()/cpu() make the conversion work for tensors that live on a
        # GPU or track gradients; for a plain CPU tensor they are no-ops.
        volume = volume.detach().cpu().numpy()

    # Index of the first spatial axis: 1 when the channel axis leads, else 0.
    first_spatial_axis = 1 if canal_first else 0

    central_slices = []
    for axis in range(first_spatial_axis, first_spatial_axis + 3):
        middle = volume.shape[axis] // 2
        plane = np.take(volume, middle, axis=axis)
        central_slices.append(Image.fromarray(np.squeeze(plane)))

    plt.figure(figsize=(21,7))
    for position, image in enumerate(central_slices, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(image)
        plt.title(image.size)
        plt.axis('off')

# Fetch the first training sample and plot its three central slices.
x, y = train_dataset[0]
print(x.shape, y)
# The voxel values are multiplied by 500 before plotting -- presumably only to
# make the slices bright enough to see; TODO confirm the intended value range.
draw_oct(x*500,type_volume = 'np',canal_first = True)

[pytorch] Resnet3D预训练网络 + MedMNIST 3D医学数据分类_第3张图片
产生dataloader

# Only the training loader shuffles; val/test keep the dataset order.
train_loader = data.DataLoader(dataset=train_dataset,
                            batch_size=batch_size,
                            shuffle=True)
val_loader = data.DataLoader(dataset=val_dataset,
                            batch_size=batch_size,
                            shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset,
                            batch_size=batch_size,
                            shuffle=False)
# Sanity check: print the tensor shapes of the first training batch only.
for x, y in train_loader:
    print(x.shape, y.shape)
    break
# Output recorded by the author:
# torch.Size([256, 1, 28, 28, 28]) torch.Size([256, 1])

使用Resnet3D预训练网络

我使用了 MedicalNet 的预训练 resnet 模型。
MedicalNet 的网络是用于分割任务的:其结构是先用 resnet 提取特征图,最后接反卷积层做分割。我们的任务是分类,于是我将最后的反卷积层(代码中的 conv_seg)替换为分类层(全局平均池化 + 全连接层)。
resnet3d 预训练模型参数可以从 MedicalNet 官方的 github 上下载,然后直接像下面一样加载即可。注意:需要使用 MedicalNet 项目代码中的 models 文件夹,将这个文件夹和要加载的预训练参数复制到自己的项目中。

from models import resnet
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device =',device)
# get_device_name() raises on a machine without CUDA, so only query the GPU
# name when a GPU is actually present.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
def generate_model(model_type='resnet', model_depth=50,
                   input_W=224, input_H=224, input_D=224, resnet_shortcut='B',
                   no_cuda=False, gpu_id=None,
                   pretrain_path = 'pretrain/resnet_50.pth',
                   nb_class=1):
    """Build a 3-D ResNet classifier initialised from a pretrained checkpoint.

    The backbone comes from the project-local ``models.resnet`` module.  Its
    ``conv_seg`` head is replaced by global average pooling, a flatten and a
    linear layer with ``nb_class`` outputs.  Weights in the checkpoint whose
    names match the resulting model are then loaded; unmatched parameters
    (such as the new head) keep their fresh initialisation.

    Args:
        model_type: Only ``'resnet'`` is supported.
        model_depth: One of 10, 18, 34, 50, 101, 152, 200.
        input_W, input_H, input_D: Forwarded to the backbone constructor as
            ``sample_input_W`` / ``sample_input_H`` / ``sample_input_D``.
        resnet_shortcut: Forwarded to the backbone as ``shortcut_type``.
        no_cuda: When true the model stays on the CPU and is not wrapped in
            ``nn.DataParallel``.
        gpu_id: GPU ids to use when ``no_cuda`` is false; defaults to ``[0]``.
            More than one id wraps the model in ``nn.DataParallel`` over
            those devices.
        pretrain_path: Checkpoint file; it must contain a ``'state_dict'``
            entry.
        nb_class: Number of outputs of the new linear head.

    Returns:
        The model (wrapped in ``nn.DataParallel`` unless ``no_cuda``).

    Raises:
        ValueError: If ``model_type`` or ``model_depth`` is not supported.
    """
    if model_type != 'resnet':
        raise ValueError("unsupported model_type: {!r}".format(model_type))

    # depth -> (constructor name in models.resnet, input width of the new head)
    # NOTE(review): 256 for depth 10 is kept from the original.  If depth 10
    # is built from the same basic block as depths 18/34 its last stage would
    # emit 512 channels and this entry should be 512 -- confirm against
    # models/resnet.py before using model_depth=10.
    backbones = {
        10: ('resnet10', 256),
        18: ('resnet18', 512),
        34: ('resnet34', 512),
        50: ('resnet50', 2048),
        101: ('resnet101', 2048),
        152: ('resnet152', 2048),
        200: ('resnet200', 2048),
    }
    if model_depth not in backbones:
        raise ValueError("unsupported model_depth: {!r}".format(model_depth))

    # None instead of a list literal avoids a shared mutable default argument.
    gpu_id = [0] if gpu_id is None else list(gpu_id)

    builder_name, fc_input = backbones[model_depth]
    model = getattr(resnet, builder_name)(
        sample_input_W=input_W,
        sample_input_H=input_H,
        sample_input_D=input_D,
        shortcut_type=resnet_shortcut,
        no_cuda=no_cuda,
        num_seg_classes=1)

    # Swap the segmentation head for a classification head.
    model.conv_seg = nn.Sequential(nn.AdaptiveAvgPool3d((1, 1, 1)), nn.Flatten(),
                                   nn.Linear(in_features=fc_input, out_features=nb_class, bias=True))

    if not no_cuda:
        if len(gpu_id) > 1:
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=gpu_id)
        else:
            import os
            # NOTE(review): this variable is normally only honoured when set
            # before CUDA is initialised in the process -- confirm nothing
            # has touched CUDA earlier if device selection matters.
            os.environ["CUDA_VISIBLE_DEVICES"]=str(gpu_id[0])
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=None)
    net_dict = model.state_dict()

    print('loading pretrained model {}'.format(pretrain_path))
    # map_location='cpu' lets a checkpoint saved on a GPU be read on any
    # machine; load_state_dict() below copies the values into the model's own
    # parameters, wherever they live.
    pretrain = torch.load(pretrain_path, map_location='cpu')

    # Keep only checkpoint entries (name -> weight tensor) that exist in the
    # model.  nn.DataParallel prefixes every parameter name with "module.",
    # so a name is tried as-is, without the prefix and with it.  Otherwise a
    # checkpoint saved from a wrapped model would silently match nothing on
    # the no_cuda path, and vice versa.
    prefix = 'module.'
    pretrain_dict = {}
    for name, weight in pretrain['state_dict'].items():
        bare_name = name[len(prefix):] if name.startswith(prefix) else name
        for candidate in (name, bare_name, prefix + bare_name):
            if candidate in net_dict:
                pretrain_dict[candidate] = weight
                break
    if not pretrain_dict:
        print('warning: no parameter name in {} matches the model; '
              'nothing was loaded from it'.format(pretrain_path))

    # Merge the matched pretrained tensors over the model's current values,
    # then copy the merged dictionary into the model.
    net_dict.update(pretrain_dict)
    model.load_state_dict(net_dict)

    print("-------- pre-train model load successfully --------")

    return model
# ResNet-50 backbone with an 11-output head for OrganMNIST3D, initialised
# from the local checkpoint file.
# NOTE(review): input_W/H/D are 224 here although the batches printed above
# are 28x28x28 and the VesselMNIST3D call further down passes 28 -- confirm
# whether the backbone actually depends on these values.
model = generate_model(model_type='resnet', model_depth=50,
                   input_W=224, input_H=224, input_D=224, resnet_shortcut='B',
                   no_cuda=False, gpu_id=[0],
                   pretrain_path = './resnet_50_23dataset.pth',
                   nb_class=11)

在这里插入图片描述

train

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# Multiply the learning rate by 0.99 after every epoch.
scheduler = ExponentialLR(optimizer, gamma=0.99)
num_epochs = 800
total_step = len(train_loader)
time_list = []
for epoch in range(num_epochs):
    start = time.time()
    per_epoch_loss = 0
    num_correct= 0
    val_num_correct = 0

    # ---- training pass ----
    model.train()
    with torch.enable_grad():
        for x,label in tqdm(train_loader):

            x = x.to(device)
            # Labels arrive as [batch, 1]; flatten to [batch] for the loss.
            # reshape(-1) keeps the batch axis even for a batch of one sample,
            # which squeeze() would collapse to a 0-d tensor.
            label = label.to(device).reshape(-1)

            # Forward pass
            logits = model(x)
            loss = criterion(logits, label)

            per_epoch_loss += loss.item()

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pred = logits.argmax(dim=1)
            num_correct += torch.eq(pred, label).sum().item()

        train_loss = per_epoch_loss/total_step
        train_acc = num_correct/len(train_loader.dataset)
        print("Train Epoch: {}\t Loss: {:.6f}\t Acc: {:.6f}".format(epoch,train_loss,train_acc))
        summaryWriter.add_scalars('loss', {"loss":train_loss}, epoch)
        summaryWriter.add_scalars('acc', {"acc":train_acc}, epoch)

    # ---- validation pass ----
    model.eval()
    with torch.no_grad():
        for x,label in tqdm(val_loader):
            x = x.to(device)
            label = label.to(device).reshape(-1)

            # Forward pass
            logits = model(x)
            pred = logits.argmax(dim=1)
            val_num_correct += torch.eq(pred, label).sum().item()

        # Report the validation accuracy (the original printed the training
        # accuracy under the "val" label here).
        val_acc = val_num_correct/len(val_loader.dataset)
        print("val Epoch: {}\t Acc: {:.6f}".format(epoch,val_acc))

        summaryWriter.add_scalars('acc', {"val_acc":val_acc}, epoch)
        summaryWriter.add_scalars('time', {"time":(time.time() - start)}, epoch)
    scheduler.step()

结果

最后让我们看一下训练结果。
[pytorch] Resnet3D预训练网络 + MedMNIST 3D医学数据分类_第4张图片

VesselMNIST3D 二分类任务

大体上和多分类任务是一样的,有几段代码需要修改。
数据下载

data_flag = 'vesselmnist3d' # Binary-Class (2)
download = True

# Look up the dataset class by the name recorded in medmnist's INFO table.
info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])

# Build the train / val / test splits in that order, each with its transform.
train_dataset, val_dataset, test_dataset = (
    DataClass(split=split_name, transform=split_transform, download=download)
    for split_name, split_transform in (
        ('train', train_transform),
        ('val', eval_transform),
        ('test', eval_transform),
    )
)

可视化
[pytorch] Resnet3D预训练网络 + MedMNIST 3D医学数据分类_第5张图片
加载模型

# ResNet-50 backbone with a single-output head for the binary VesselMNIST3D
# task; the single logit is passed through a sigmoid in the training loop.
model = generate_model(model_type='resnet', model_depth=50,
                   input_W=28, input_H=28, input_D=28, resnet_shortcut='B',
                   no_cuda=False, gpu_id=[0],
                   pretrain_path = './resnet_50_23dataset.pth',
                   nb_class=1)

训练参数,二分类使用BCEWithLogitsLoss

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# pos_weight=10 up-weights the positive class to counter the class imbalance.
# The criterion is moved to the same `device` as the inputs instead of a
# hard-coded .cuda(), so it follows the CPU fallback chosen for `device`.
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.0])).to(device)
# Multiply the learning rate by 0.99 after every epoch.
scheduler = ExponentialLR(optimizer, gamma=0.99)
num_epochs = 1200

我们使用acc和auc作为指标

from sklearn.metrics import roc_curve
from sklearn.metrics import auc
for epoch in range(num_epochs):
    start = time.time()
    per_epoch_loss = 0
    num_correct= 0
    score_list = []
    label_list = []

    val_num_correct = 0
    val_score_list = []
    val_label_list = []

    # ---- training pass ----
    model.train()
    with torch.enable_grad():
        for x,label in tqdm(train_loader):
            x = x.float()
            x = x.to(device)
            # Flatten labels to [batch]; reshape(-1) keeps the batch axis even
            # for a batch of one sample, which squeeze() would collapse.
            label = label.to(device).reshape(-1)
            # Copy the labels to the host once per batch, not once per sample.
            label_np = label.cpu().numpy()
            label_list.extend(label_np)

            # Forward pass: one logit per sample, shaped like the labels.
            logits = model(x)
            logits = logits.reshape(label.shape)
            prob_out = torch.sigmoid(logits)

            pro_list = prob_out.detach().cpu().numpy()
            # A sample is correct when "probability > 0.5" equals its label.
            num_correct += int(((pro_list > 0.5) == label_np).sum())

            score_list.extend(pro_list)

            loss = criterion(logits, label.float())

            per_epoch_loss += loss.item()

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # ROC AUC over all training samples seen in this epoch.
        score_array = np.array(score_list)
        label_array = np.array(label_list)
        fpr_keras_1, tpr_keras_1, thresholds_keras_1 = roc_curve(label_array, score_array)
        auc_keras_1 = auc(fpr_keras_1,tpr_keras_1)

        print("Train EVpoch: {}\t Loss: {:.6f}\t Acc: {:.6f} AUC: {:.6f} ".format(epoch,per_epoch_loss/len(train_loader),num_correct/len(train_loader.dataset),auc_keras_1))
        summaryWriter.add_scalars('loss', {"loss":(per_epoch_loss/len(train_loader))}, epoch)
        summaryWriter.add_scalars('acc', {"acc":num_correct/len(train_loader.dataset)}, epoch)
        summaryWriter.add_scalars('auc', {"auc":auc_keras_1}, epoch)

    # ---- validation pass ----
    model.eval()
    with torch.no_grad():
        for x,label in tqdm(val_loader):
            x = x.float()
            x = x.to(device)
            # Flattened exactly as in the training pass, so roc_curve receives
            # a 1-D label array rather than a column vector.
            label = label.to(device).reshape(-1)
            label_np = label.cpu().numpy()

            val_label_list.extend(label_np)

            # Forward pass
            logits = model(x)
            logits = logits.reshape(label.shape)
            prob_out = torch.sigmoid(logits)

            pro_list = prob_out.detach().cpu().numpy()
            val_num_correct += int(((pro_list > 0.5) == label_np).sum())

            val_score_list.extend(pro_list)

        # ROC AUC over the whole validation split.
        score_array = np.array(val_score_list)
        label_array = np.array(val_label_list)
        fpr_keras_1, tpr_keras_1, thresholds_keras_1 = roc_curve(label_array, score_array)
        auc_keras_1 = auc(fpr_keras_1,tpr_keras_1)

        print("val Epoch: {}\t Acc: {:.6f} AUC: {:.6f} ".format(epoch,val_num_correct/len(val_loader.dataset),auc_keras_1))
        summaryWriter.add_scalars('acc', {"val_acc":val_num_correct/len(val_loader.dataset)}, epoch)
        summaryWriter.add_scalars('auc', {"val_auc":auc_keras_1}, epoch)
        summaryWriter.add_scalars('time', {"time":(time.time() - start)}, epoch)

    scheduler.step()

    #filepath = "./weights"
    #folder = os.path.exists(filepath)
    #if not folder:
    #    # 判断是否存在文件夹如果不存在则创建为文件夹
    #    os.makedirs(filepath)
    #path = './weights/model' + str(epoch) + '.pth'
    #torch.save(model.state_dict(), path)

结果
[pytorch] Resnet3D预训练网络 + MedMNIST 3D医学数据分类_第6张图片[pytorch] Resnet3D预训练网络 + MedMNIST 3D医学数据分类_第7张图片

你可能感兴趣的:(医学图像,pytorch,深度学习,分类,图像处理)