如何来实现模型并行?【在batch_size为1的情况下?】

【提问】在batch size为1时,使用数据并行和模型并行哪个更合理?
【解答】在batch size为1时,使用模型并行比数据并行更合理。

模型并行允许将模型的不同部分分配到不同的设备上,并在每个设备上处理不同的部分。因此,即使batch size为1,仍然可以利用多个设备来处理模型。

相比之下,数据并行将输入数据分割到不同的设备上进行处理,每个设备只处理一部分数据,因此对于batch size为1的情况,数据并行无法发挥作用。

需要注意的是,使用模型并行需要对模型进行适当的分解和重新组织,以确保每个设备处理的部分可以协同工作,而不是产生矛盾或重复的结果。这可能需要对模型进行重构和调整,并且可能会带来额外的复杂性和实现难度。

综上所述,在batch size为1时,使用模型并行比数据并行更合理,但需要进行适当的模型重构和调整。
首先对待修改模型中的部分代码的解释:

def relocate(self):
    """Place every sub-module of the model on the available hardware.

    With more than one GPU, each sub-module is wrapped in nn.DataParallel
    (rooted on cuda:0) so it is replicated across all visible devices.
    Otherwise everything is simply moved to the single GPU, or to the CPU
    when CUDA is unavailable.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    parts = ("layer1", "pos_layer1", "layer2", "pos_layer2",
             "norm", "layer3", "conv1", "_fc1", "_fc2")
    if torch.cuda.device_count() > 1:
        gpu_ids = list(range(torch.cuda.device_count()))
        for name in parts:
            wrapped = nn.DataParallel(getattr(self, name), device_ids=gpu_ids)
            setattr(self, name, wrapped.to('cuda:0'))
    else:
        # single device (GPU or CPU): no DataParallel wrapper needed
        for name in parts:
            setattr(self, name, getattr(self, name).to(target))

具体的描述信息解释:

This is a method that relocates a PyTorch model to either the CPU or GPU based on the availability of CUDA.

If CUDA is available and the machine has multiple GPUs, the model is parallelized using nn.DataParallel and each module is sent to the first GPU using to('cuda:0'). This allows the model to utilize multiple GPUs for faster training and inference.

If CUDA is available but the machine has only one GPU, each module is sent to that GPU using to(device).

If CUDA is not available, the model is sent to the CPU using to(device).

Overall, this method is useful for ensuring that a PyTorch model runs on the appropriate device, whether it is a single CPU or multiple GPUs.

1.代码:Transformer_LeFF_MSA.py中 代码实现并行,将模型中的不同block加载到不同的gpu显卡上。

#######################test37 for model parallel ################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from nystrom_attention import NystromAttention
from performer_pytorch import Performer
from performer_pytorch import SelfAttention
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'

#ref:https://zhuanlan.zhihu.com/p/517730496
#ref:https://zhuanlan.zhihu.com/p/517730496
class TransLayer(nn.Module):
    """Pre-norm transformer layer with a Performer self-attention residual.

    Args:
        norm_layer: normalization constructor applied before attention.
        dim: token embedding dimension.
    """
    def __init__(self, norm_layer=nn.LayerNorm, dim=512):
        super().__init__()
        self.norm = norm_layer(dim)

        # Performer SelfAttention (https://github.com/lucidrains/performer-pytorch),
        # without dropout. Fix: `dim` used to be hard-coded to 512, silently
        # ignoring the constructor argument; passing it through keeps the
        # default behavior (dim=512) identical while honoring other dims.
        self.attn = SelfAttention(
            dim=dim,
            heads=8,
            causal=False
        )

    def forward(self, x):
        # residual connection around norm -> attention; shape is preserved
        x = x + self.attn(self.norm(x))
        return x



#test 9:LeFF(cnn_feat+proj+proj1+proj2)+Performer+MIL,the LeFF(cnn_feat+proj+proj1+proj2)-->PPEG,##Attention (4):Performer with SelfAttention##
#test 9:LeFF(cnn_feat+proj+proj1+proj2)+Performer+MIL,the LeFF(cnn_feat+proj+proj1+proj2)-->PPEG,##Attention (4):Performer with SelfAttention##
class LeFF(nn.Module):
    """Locally Enhanced Feed-Forward block.

    Restores the token sequence to an H x W grid, mixes it with three
    depth-wise (grouped) convolutions of kernel size 1, 5 and 3 plus an
    identity path, then flattens back to a sequence. The class token is
    passed through untouched.
    """
    def __init__(self, dim=512):
        super(LeFF, self).__init__()
        # depth-wise convolutions; padding k // 2 keeps H and W unchanged
        # (group convolution ref: https://www.jianshu.com/p/a936b7bc54e3)
        self.proj = nn.Conv2d(dim, dim, 1, 1, 1 // 2, groups=dim)
        self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5 // 2, groups=dim)
        self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3 // 2, groups=dim)

    def forward(self, x, H, W):
        batch, _, channels = x.shape

        # split off the class token; the rest are H*W feature tokens
        cls_token, feat_token = x[:, 0], x[:, 1:]
        grid = feat_token.transpose(1, 2).view(batch, channels, H, W)

        # identity path + three depth-wise convolution paths
        mixed = self.proj(grid) + grid + self.proj1(grid) + self.proj2(grid)

        # back to a token sequence, class token first
        flat = mixed.flatten(2).transpose(1, 2)
        return torch.cat((cls_token.unsqueeze(1), flat), dim=1)


#https://www.bilibili.com/video/av380166304
class TransformerMIL(nn.Module):
    def __init__(self, n_classes):
        super(TransformerMIL, self).__init__()
        self._fc1 = nn.Sequential(nn.Linear(1024, 512), nn.ReLU()).to('cuda:0')
        self.cls_token = nn.Parameter(torch.randn(1, 1, 512))
        self.n_classes = n_classes
        self.layer1 = TransLayer(dim=512).to('cuda:0')
        self.pos_layer1 = LeFF(dim=512).to('cuda:0')
        self.layer2 = TransLayer(dim=512).to('cuda:1')
        self.pos_layer2 = LeFF(dim=512).to('cuda:1')
        #add another layer
        #self.layer3 = TransLayer(dim=512)
        #self.norm = nn.LayerNorm(8)
        #self._fc2 = nn.Linear(512, self.n_classes)
        #self.conv1=nn.Conv1d(512,8,1)
        
        self.layer3 = TransLayer(dim=512).to('cuda:1')
        self.norm = nn.LayerNorm(n_classes).to('cuda:1')
        self._fc2 = nn.Linear(512, self.n_classes).to('cuda:1')
        self.conv1=nn.Conv1d(512,n_classes,1).to('cuda:1')
  

    def relocate(self):
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.device_count() <1:
            device_ids = list(range(torch.cuda.device_count()))
            self.layer1 = nn.DataParallel(self.layer1, device_ids=device_ids).to('cuda:0')
            self.pos_layer1 = nn.DataParallel(self.pos_layer1, device_ids=device_ids).to('cuda:0')
            self.layer2 = nn.DataParallel(self.layer2, device_ids=device_ids).to('cuda:0')
            self.pos_layer2 = nn.DataParallel(self.pos_layer2, device_ids=device_ids).to('cuda:0')
            self.norm = nn.DataParallel(self.norm, device_ids=device_ids).to('cuda:0')
            self.layer3 = nn.DataParallel(self.layer3, device_ids=device_ids).to('cuda:0')
            self.conv1 = nn.DataParallel(self.conv1, device_ids=device_ids).to('cuda:0')
            self._fc1 = nn.DataParallel(self._fc1, device_ids=device_ids).to('cuda:0')
            self._fc2 = nn.DataParallel(self._fc2, device_ids=device_ids).to('cuda:0')
        else: #DELL--->run 
            self.layer1 = self.layer1.to(device)
            self.pos_layer1 = self.pos_layer1.to(device)
            self.layer2 = self.layer2.to(device)
            self.pos_layer2= self.pos_layer2.to(device)
            self.norm = self.norm.to(device)
            self.layer3 = self.layer3.to(device)
            self._fc1 = self._fc1.to(device)
            self._fc2 = self._fc2.to(device)
            self.conv1 = self.conv1.to(device)


    # def forward(self, **kwargs):
    #     h = kwargs['data'].float() #[B, n, 1024]
    #     print("original h shape:",h.shape)


    def forward(self,h,attention_only=False):
        h = h.float()#.to(device) #[B, n, 1024]
        h = h.expand(1, -1, -1)
        n=h.shape[1]  
        #print("original input token shape:",n) 
        #print("h.shape:",h.shape)
        
        h = self._fc1(h.to('cuda:0')) #[B, n, 512]
        # print("after _fc1 layer shape",h.shape)     # after _fc1 layer shape torch.Size([1, 6000, 512])
        
        #---->pad
        H = h.shape[1]
        # print("H shape",H)                          # H shape 6000
        _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
        # print("_H,_W",_H,_W)                        # _H,_W 78 78
        add_length = _H * _W - H
        # print("add_length",add_length)              # add_length 84


        #add the code to deal with add_length
        #feature token concat the first add_length feature token
        if add_length!=0:
            h = torch.cat([h, h[:,-add_length:,:]],dim = 1) #[B, N, 512]
            # print("----h shape----",h.shape)

        #print("h.shape",h.shape)                    # h.shape torch.Size([1, 6084, 512])

        #---->cls_token
        B = h.shape[0]
        # print("B",B)                                # B 1
        cls_tokens = self.cls_token.expand(B, -1, -1).cuda()
        #print("cls_tokens is on the cuda:",cls_tokens.device)
        # print("cls_tokens shape",cls_tokens.shape)  # cls_tokens shape torch.Size([1, 1, 512])
        h = torch.cat((cls_tokens, h), dim=1)

        # print("h shape",h.shape)                    # h shape torch.Size([1, 6085, 512])

        #---->Translayer x1------->MSA(h)
        h = self.layer1(h.to('cuda:0')) #[B, N, 512]
        # print("after layer1 h shape",h.shape)       # after layer1 h shape torch.Size([1, 6085, 512])

        #---->LeFF(Locally Enhanced Feed-Forward)
        # Linear Projection-->Spatial Restoration-->Depth-wise Convolution-->Flatten
        h = self.pos_layer1(h, _H, _W).to('cuda:1') #[B, N, 512]
        # print("after pos_layer h shape",h.shape)    # after pos_layer h shape torch.Size([1, 6085, 512])

        #---->Translayer x2------->MSA(h)
        h = self.layer2(h)  #[B, N, 512]
        # print("after layer2 shape",h.shape)         # after layer2 shape torch.Size([1, 6085, 512]

        # ---->LeFF(Locally Enhanced Feed-Forward)
        h = self.pos_layer2(h, _H, _W) #[B, N, 512]


        #update the GAP
        h = self.layer3(h) #[B, N, 512]


        h=h[:,1:].expand(1, -1, -1).transpose(1, 2)
        h=self.conv1(h)
        # print("after con1 shape:",h.shape)                     # torch.Size([1, 8, 6084])


        #the first class predicted results
        #class_token_attention_scores=h[:, :1, :n]
        #print("class_token_attention_scores shape:",class_token_attention_scores.shape)
        #class_token_to_remain_tokens=class_token_attention_scores[0][0:1].transpose(1,0)
        #print("class_token_to_remain_tokens shape:",class_token_to_remain_tokens.shape)   #class_token_to_remain_tokens shape: torch.Size([6000, 1])

        #the kidney class
        class_token_attention_scores=h[:, :3, :n]
        #print("class_token_attention_scores shape:",class_token_attention_scores.shape)
        class_token_to_remain_tokens=class_token_attention_scores[0][-1:].transpose(1,0)
        #print("class_token_to_remain_tokens shape:",class_token_to_remain_tokens.shape) 

        if attention_only:
            return class_token_to_remain_tokens


        B, C , _= h.shape
        h=h.view(B, C, _H, _W)   
        # print("after view h shape:",h.shape)                             #after view h shape: torch.Size([1, 8, 78, 78])
        h=torch.nn.functional.adaptive_avg_pool2d(h, (1,1))             
        # print("after adaptive_avg_pool2d h shape:",h.shape)              # after adaptive_avg_pool2d h shape: torch.Size([1, 8, 1, 1])
        h = h.flatten(2).transpose(1, 2)                                 #  torch.Size([1, 1, 8])
        # print("after flatten  shape:",h.shape)
        # print("after flatten transpose h shape:",h.shape)

        # print("class_token2 shape:",h.shape)             #class_token2 shape: torch.Size([1, 512])

        #update the GAP
        h = self.norm(h)[:,0]
        # print("after norm layer shape:",self.norm(h).shape)
        # print("after norm layer shape",h.shape)     # after norm layer shape torch.Size([1, 512])

        #---->predict---->类似于MLP Head
        #logits = self._fc2(h) #[B, n_classes]
        #print("after logits shape",logits.shape)    # after _f2 shape torch.Size([1, 2])

        Y_hat = torch.argmax(h, dim=1)         # original the predicted result
        # print("Y_hat shape",Y_hat.shape)            # Y_hat shape torch.Size([1])
        
        Y_prob = F.softmax(h, dim = 1)         # after softmax:--->predicted result
        # print("Y_prob shape",Y_prob.shape)          # Y_prob shape torch.Size([1, 2])

        results_dict = {'logits': h, 'Y_prob': Y_prob, 'Y_hat': Y_hat,'A':class_token_to_remain_tokens}
        #print(results_dict)                         # {'logits': tensor([[0.4693, 0.3050]], device='cuda:0', grad_fn=), 'Y_prob': tensor([[0.5410, 0.4590]], device='cuda:0', grad_fn=), 'Y_hat': tensor([0], device='cuda:0')}
        return results_dict
        
if __name__ == "__main__":
    data = torch.randn((1, 6000, 1024))
    model = TransformerMIL(n_classes=8)
    print(model.eval())
    results_dict = model(data = data)
    print(results_dict)

2.在程序运行的过程中出现了如下错误,【RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument target in method wrapper_nll_loss_forward)
】,主要因为样本的label并没有放置在gpu 1上。
注:调用损失函数时,只需确保标签(label)与输出(output)在同一设备(on the same device)上。
core_utils_mtl_concat_transformer.py的代码描述如下:

import numpy as np
import torch
import pickle 
from utils.utils import *
import os
from datasets.dataset_mtl_concat import save_splits
from sklearn.metrics import roc_auc_score
#from models.model_toad import TOAD_fc_mtl_concat
from models.TransformerMIL_LeFF_MSA import TransformerMIL
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.metrics import auc as calc_auc
from sklearn.preprocessing import label_binarize

class Accuracy_Logger(object):
    """Tracks per-class prediction counts and reports per-class accuracy."""
    def __init__(self, n_classes):
        super(Accuracy_Logger, self).__init__()
        self.n_classes = n_classes
        self.initialize()

    def initialize(self):
        # one {count, correct} bucket per ground-truth class
        self.data = [{"count": 0, "correct": 0} for _ in range(self.n_classes)]

    def log(self, Y_hat, Y):
        """Record one prediction/label pair (both coerced to int)."""
        predicted, truth = int(Y_hat), int(Y)
        bucket = self.data[truth]
        bucket["count"] += 1
        bucket["correct"] += int(predicted == truth)

    def log_batch(self, count, correct, c):
        """Record an aggregated batch result for class `c`."""
        bucket = self.data[c]
        bucket["count"] += count
        bucket["correct"] += correct

    def get_summary(self, c):
        """Return (accuracy-or-None, correct, count) for class `c`."""
        bucket = self.data[c]
        count, correct = bucket["count"], bucket["correct"]
        acc = None if count == 0 else float(correct) / count
        return acc, correct, count

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=20, stop_epoch=50, verbose=False):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 20
            stop_epoch (int): Earliest epoch possible for stopping
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
        """
        self.patience = patience
        self.stop_epoch = stop_epoch
        self.verbose = verbose
        self.counter = 0          # epochs since the last improvement
        self.best_score = None    # best (negated) validation loss so far
        self.early_stop = False
        # fix: np.Inf was removed in NumPy 2.0; np.inf is the supported spelling
        self.val_loss_min = np.inf

    def __call__(self, epoch, val_loss, model, ckpt_name='checkpoint.pt'):
        """Update the stopper with this epoch's validation loss.

        Saves a checkpoint on improvement; sets `early_stop` once the loss
        has not improved for `patience` epochs AND `epoch > stop_epoch`.
        """
        # higher score == lower validation loss
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, ckpt_name)
        elif score < self.best_score:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience and epoch > self.stop_epoch:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, ckpt_name)  # Saves model when validation loss decrease.
            self.counter = 0

    def save_checkpoint(self, val_loss, model, ckpt_name):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), ckpt_name)
        self.val_loss_min = val_loss

def train(datasets, cur, args):
    """Train a single cross-validation fold.

    Args:
        datasets: (train_split, val_split, test_split) tuple.
        cur: fold index; used for the result sub-directory and checkpoint names.
        args: experiment namespace (results_dir, log_data, n_classes,
              max_epochs, early_stopping, testing, weighted_sample, ...).

    Returns:
        (results_dict, test_auc, val_auc, 1 - test_error, 1 - val_error)
    """
    print('\nTraining Fold {}!'.format(cur))
    # per-fold results directory
    writer_dir = os.path.join(args.results_dir, str(cur))
    if not os.path.isdir(writer_dir):
        os.mkdir(writer_dir)

    if args.log_data:
        from tensorboardX import SummaryWriter
        # flush_secs: interval (seconds) between writes to the event file
        writer = SummaryWriter(writer_dir, flush_secs=15)
    else:
        writer = None

    print('\nInit train/val/test splits...', end=' ')
    train_split, val_split, test_split = datasets

    # persist the split assignment for reproducibility
    save_splits(datasets, ['train', 'val', 'test'], os.path.join(args.results_dir, 'splits_{}.csv'.format(cur)))
    print('Done!')
    print("Training on {} samples".format(len(train_split)))
    print("Validating on {} samples".format(len(val_split)))
    print("Testing on {} samples".format(len(test_split)))

    loss_fn = nn.CrossEntropyLoss()

    print('\nInit Model...', end=' ')
    model_dict = {'n_classes': args.n_classes}
    # the model places its own sub-modules on cuda:0/cuda:1 (model parallel),
    # so there is deliberately no .cuda() / relocate() call here
    model = TransformerMIL(**model_dict)
    print('Done!')
    print_network(model)

    print('\nInit optimizer ...', end=' ')
    optimizer = get_optim(model, args)
    print('Done!')

    print('\nInit Loaders...', end=' ')
    train_loader = get_split_loader(train_split, training=True, testing=args.testing, weighted=args.weighted_sample)
    val_loader = get_split_loader(val_split)
    test_loader = get_split_loader(test_split)
    print('Done!')

    print('\nSetup EarlyStopping...', end=' ')
    if args.early_stopping:
        early_stopping = EarlyStopping(patience=20, stop_epoch=50, verbose=True)
    else:
        early_stopping = None
    print('Done!')

    for epoch in range(args.max_epochs):
        train_loop(epoch, model, train_loader, optimizer, args.n_classes, writer, loss_fn)
        stop = validate(cur, epoch, model, val_loader, args.n_classes,
            early_stopping, writer, loss_fn, args.results_dir)

        if stop:
            break

    # restore the best early-stopped weights, or save the final ones
    if args.early_stopping:
        model.load_state_dict(torch.load(os.path.join(args.results_dir, "s_{}_checkpoint.pt".format(cur))))
    else:
        torch.save(model.state_dict(), os.path.join(args.results_dir, "s_{}_checkpoint.pt".format(cur)))

    _, cls_val_error, cls_val_auc, _ = summary(model, val_loader, args.n_classes)
    print('Cls Val error: {:.4f}, Cls ROC AUC: {:.4f}'.format(cls_val_error, cls_val_auc))

    results_dict, cls_test_error, cls_test_auc, acc_loggers = summary(model, test_loader, args.n_classes)
    print('Cls Test error: {:.4f}, Cls ROC AUC: {:.4f}'.format(cls_test_error, cls_test_auc))
    print(acc_loggers)

    if writer:
        writer.add_scalar('final/cls_val_error', cls_val_error, 0)
        writer.add_scalar('final/cls_val_auc', cls_val_auc, 0)
        writer.add_scalar('final/cls_test_error', cls_test_error, 0)
        writer.add_scalar('final/cls_test_auc', cls_test_auc, 0)
        # bug fix: writer is None when args.log_data is False; the previous
        # unconditional writer.close() raised AttributeError in that case
        writer.close()
    return results_dict, cls_test_auc, cls_val_auc, 1 - cls_test_error, 1 - cls_val_error


def train_loop(epoch, model, loader, optimizer, n_classes, writer = None, loss_fn = None):
    """Run one training epoch.

    Data and labels are moved to cuda:1 because the model's final stage
    (and therefore its logits) lives on cuda:1 — the loss needs both
    tensors on the same device.
    """
    model.train()
    cls_logger = Accuracy_Logger(n_classes=n_classes)

    cls_train_error = 0.
    cls_train_loss = 0.

    print('\n')
    for batch_idx, (data, label) in enumerate(loader):
        # model parallel: outputs emerge on cuda:1, so the label goes there too
        data = data.to('cuda:1')
        label = label.to('cuda:1')

        results_dict = model(data)
        logits, Y_prob, Y_hat = results_dict['logits'], results_dict['Y_prob'], results_dict['Y_hat']

        cls_logger.log(Y_hat, label)

        cls_loss = loss_fn(logits, label)
        loss = cls_loss

        cls_loss_value = cls_loss.item()
        cls_train_loss += cls_loss_value

        if (batch_idx + 1) % 5 == 0:
            print('batch {}, cls loss: {:.4f}, '.format(batch_idx, cls_loss_value) +
                'label: {},  bag_size: {}'.format(label.item(),  data.size(0)))

        cls_error = calculate_error(Y_hat, label)
        cls_train_error += cls_error

        # backward pass, then reset the gradients for the next batch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # epoch averages
    cls_train_loss /= len(loader)
    cls_train_error /= len(loader)

    print('Epoch: {}, cls train_loss: {:.4f}, cls train_error: {:.4f}'.format(epoch, cls_train_loss, cls_train_error))
    for i in range(n_classes):
        acc, correct, count = cls_logger.get_summary(i)
        print('class {}: tpr {}, correct {}/{}'.format(i, acc, correct, count))
        # fix: acc is None for a class unseen this epoch; add_scalar(None) fails
        if writer and acc is not None:
            writer.add_scalar('train/class_{}_tpr'.format(i), acc, epoch)

    if writer:
        writer.add_scalar('train/cls_loss', cls_train_loss, epoch)
        writer.add_scalar('train/cls_error', cls_train_error, epoch)
         
def validate(cur, epoch, model, loader, n_classes, early_stopping = None, writer = None, loss_fn = None, results_dir=None):
    """Run one validation pass.

    Returns True when early stopping has triggered, otherwise False.
    Data/labels go to cuda:1 to match the device of the model's logits.
    """
    model.eval()
    cls_logger = Accuracy_Logger(n_classes=n_classes)
    cls_val_error = 0.
    cls_val_loss = 0.

    cls_probs = np.zeros((len(loader), n_classes))
    cls_labels = np.zeros(len(loader))

    with torch.no_grad():
        for batch_idx, (data, label) in enumerate(loader):
            # model parallel: the logits come out on cuda:1
            data = data.to('cuda:1')
            label = label.to('cuda:1')

            results_dict = model(data)
            logits, Y_prob, Y_hat = results_dict['logits'], results_dict['Y_prob'], results_dict['Y_hat']
            del results_dict

            cls_logger.log(Y_hat, label)

            cls_loss = loss_fn(logits, label)
            cls_loss_value = cls_loss.item()

            cls_probs[batch_idx] = Y_prob.cpu().numpy()
            cls_labels[batch_idx] = label.item()

            cls_val_loss += cls_loss_value

            cls_error = calculate_error(Y_hat, label)
            cls_val_error += cls_error

    cls_val_error /= len(loader)
    cls_val_loss /= len(loader)

    if n_classes == 2:
        cls_auc = roc_auc_score(cls_labels, cls_probs[:, 1])
        cls_aucs = []
    else:
        # macro-average one-vs-rest AUC, skipping classes absent from the labels
        cls_aucs = []
        binary_labels = label_binarize(cls_labels, classes=[i for i in range(n_classes)])
        for class_idx in range(n_classes):
            if class_idx in cls_labels:
                fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], cls_probs[:, class_idx])
                cls_aucs.append(calc_auc(fpr, tpr))
            else:
                cls_aucs.append(float('nan'))

        cls_auc = np.nanmean(np.array(cls_aucs))

    if writer:
        writer.add_scalar('val/cls_loss', cls_val_loss, epoch)
        writer.add_scalar('val/cls_auc', cls_auc, epoch)
        writer.add_scalar('val/cls_error', cls_val_error, epoch)

    print('\nVal Set, cls val_loss: {:.4f}, cls val_error: {:.4f}, cls auc: {:.4f}'.format(cls_val_loss, cls_val_error, cls_auc))
    for i in range(n_classes):
        acc, correct, count = cls_logger.get_summary(i)
        print('class {}: tpr {}, correct {}/{}'.format(i, acc, correct, count))
        # fix: acc is None for classes unseen in this split
        if writer and acc is not None:
            writer.add_scalar('val/class_{}_tpr'.format(i), acc, epoch)

    if early_stopping:
        assert results_dir
        early_stopping(epoch, cls_val_loss, model, ckpt_name = os.path.join(results_dir, "s_{}_checkpoint.pt".format(cur)))

        if early_stopping.early_stop:
            print("Early stopping")
            return True

    return False

def summary(model, loader, n_classes):
    """Evaluate `model` on `loader`.

    Returns:
        (patient_results, cls_test_error, cls_auc, cls_logger) where
        patient_results maps slide_id -> {slide_id, cls_prob, cls_label}.
    """
    # cleanup: removed unused locals (`device`, `site_logger`, `cls_test_loss`)
    cls_logger = Accuracy_Logger(n_classes=n_classes)
    model.eval()
    cls_test_error = 0.

    all_cls_probs = np.zeros((len(loader), n_classes))
    all_cls_labels = np.zeros(len(loader))

    slide_ids = loader.dataset.slide_data['slide_id']
    patient_results = {}

    for batch_idx, (data, label) in enumerate(loader):
        # model parallel: the logits come out on cuda:1
        data = data.to('cuda:1')
        label = label.to('cuda:1')

        slide_id = slide_ids.iloc[batch_idx]
        with torch.no_grad():
            results_dict = model(data)

        logits, Y_prob, Y_hat = results_dict['logits'], results_dict['Y_prob'], results_dict['Y_hat']
        del results_dict

        cls_logger.log(Y_hat, label)
        cls_probs = Y_prob.cpu().numpy()
        all_cls_probs[batch_idx] = cls_probs
        all_cls_labels[batch_idx] = label.item()

        patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'cls_prob': cls_probs, 'cls_label': label.item()}})
        cls_error = calculate_error(Y_hat, label)
        cls_test_error += cls_error

    cls_test_error /= len(loader)

    if n_classes == 2:
        cls_auc = roc_auc_score(all_cls_labels, all_cls_probs[:, 1])
    else:
        cls_auc = roc_auc_score(all_cls_labels, all_cls_probs, multi_class='ovr')

    return patient_results, cls_test_error, cls_auc, (cls_logger)

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.TransformerMIL_LeFF_MSA import TransformerMIL
#from models.model_toad import TOAD_fc_mtl_concat
import pdb
import os
import pandas as pd
from utils.utils import *
from utils.core_utils_mtl_concat import EarlyStopping,  Accuracy_Logger
from utils.file_utils import save_pkl, load_pkl
from sklearn.metrics import roc_auc_score, roc_curve, auc
from sklearn.metrics import precision_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import average_precision_score
import h5py
from models.resnet_custom import resnet50_baseline
import math
from sklearn.preprocessing import label_binarize
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt

from scipy import interp
#define the model
#define the model
def initiate_model(args, ckpt_path=None):
    """Build a TransformerMIL model, optionally restore a checkpoint, and
    return it in eval mode."""
    print('Init Model')

    config = {'n_classes': args.n_classes}
    model = TransformerMIL(**config)

    print_network(model)

    if ckpt_path is not None:
        state = torch.load(ckpt_path)
        # strict=False tolerates missing/extra keys in older checkpoints
        model.load_state_dict(state, strict=False)

    model.eval()
    return model

def eval(dataset, args, ckpt_path):
    """Run inference over `dataset` with the checkpointed model.

    NOTE(review): this shadows the builtin `eval`; renaming would break
    existing callers, so the name is kept as-is.
    """
    model = initiate_model(args, ckpt_path)

    print('Init Loaders')
    loader = get_simple_loader(dataset)  # batch size defaults to 1
    results_dict = summary(model, loader, args)

    print('cls_test_error: ', results_dict['cls_test_error'])
    print('cls_auc: ', results_dict['cls_auc'])

    return model, results_dict

# Code taken from pytorch/examples for evaluating topk classification on on ImageNet
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        print("maxk:",maxk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        print("correct shape:",correct.shape)

        res = []
        for k in topk:
            print("correct[:k] shape:",k,correct[:k].shape)
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            print("correct[:k].view(-1) shape:",k,correct[:k].reshape(-1).shape)
            res.append(correct_k.mul_(1.0 / batch_size))
            print("res:",res)
        return res

def summary(model, loader, args):
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cls_logger = Accuracy_Logger(n_classes=args.n_classes)  #acc, correct, count
    #site_logger = Accuracy_Logger(n_classes=2)
    model.eval()
    cls_test_error = 0.
    cls_test_loss = 0.
    site_test_error = 0.
    site_test_loss = 0.

    '''
    all_cls_probs->  [batch size,n_classes]   会输出属于每一类别的概率
    all_cls_labels-> [[batch size]            true label only one
    自定义的数据集类MyDataset对象data的长度和 DataLoader对象data_loader的长度,我们会发现:data_loader的长度是data的长度除以batch_size。
    ref:https://blog.csdn.net/weixin_45901519/article/details/115672355
    '''
    all_cls_probs = np.zeros((len(loader), args.n_classes))  #loader-->[img, label, site, sex]  len(loader)->batch size;
    #print("all_cls_probs:",all_cls_probs.shape)
    all_cls_labels = np.zeros(len(loader))
    #print("all_cls_labels:",all_cls_labels.shape)
    all_cls_hats = np.zeros(len(loader))

    slide_ids = loader.dataset.slide_data['slide_id'] #所有的slide_id
    patient_results = {}

    for batch_idx, (data, label) in enumerate(loader):  #batch size=1   len(data_loader)=len(dataset)
        #data =  data.to(device)
        ##print("data shape:",data.shape)
        #label = label.to(device)

        
        # add the code for parellel
        data =  data.to('cuda:1')
        label = label.to('cuda:1')
        
        
        slide_id = slide_ids.iloc[batch_idx]
        with torch.no_grad(): #require_grad=false-->一是减少内存,二是可以把这个operator从computation graph中detach出来,这样就不会在BP过程中计算到。
            model_results_dict = model(data)  #组织切片提取的特征向量+sex
            #print("model_results_dict:",model_results_dict)

        '''
        logits-->   logits  = self.classifier(M[0].unsqueeze(0))   1×2  属于每一类别的概率
        Y_hat-->    Y_hat = torch.topk(logits, 1, dim = 1)[1]  #dim=0表示按照列求topn,dim=1表示按照行求topn  得到的是预测元素下标
        Y_prob-->   Y_prob = F.softmax(logits, dim = 1)        #按行softmax,行和为1     1×2  属于每一类别的概率
        '''
        logits, Y_prob, Y_hat  = model_results_dict['logits'], model_results_dict['Y_prob'], model_results_dict['Y_hat']
        del model_results_dict
        '''
        def log(self, Y_hat, Y):   
            Y_hat = int(Y_hat)
            Y = int(Y)
            self.data[Y]["count"] += 1
            self.data[Y]["correct"] += (Y_hat == Y)
        '''
        cls_logger.log(Y_hat, label)      #统计总类别数和预测正确数目
        cls_probs = Y_prob.cpu().numpy()  #预测的属于每一类别的概率
        cls_hats=Y_hat.cpu().numpy()
        all_cls_hats[batch_idx]=cls_hats
        all_cls_probs[batch_idx] = cls_probs
        all_cls_labels[batch_idx] = label.item()

        #all_sexes[batch_idx] = sex.item()

        
        patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'cls_prob': cls_probs, 'cls_label': label.item()}})
        '''
        def calculate_error(Y_hat, Y):
	        error = 1. - Y_hat.float().eq(Y.float()).float().mean().item()
	        return error
        '''
        cls_error = calculate_error(Y_hat, label)
        cls_test_error += cls_error

    cls_test_error /= len(loader)
    #print("cls_test_error:",cls_test_error)
    '''
    >>a = np.array([[1, 5, 5, 2],
               [9, 6, 2, 8],
               [3, 7, 9, 1]])
    >>np.argmax(a,axis=0)
    >>array([1, 2, 2, 1], dtype=int64)
    >>np.argmax(a,axis=1)
    >>array([1, 0, 2], dtype=int64)
    '''
    all_cls_preds = np.argmax(all_cls_probs, axis=1)
    topk=()
    if args.n_classes > 2:
        if args.n_classes > 5:
            topk = (1,3,5)
        else:
            topk = (1,3)
        
        #print("all_cls_probs shape:",all_cls_probs.shape)              # all_cls_probs shape: (215, 8)

        #print("all_cls_labels shape:",all_cls_labels.shape)            # all_cls_labels shape: (215,)
        topk_accs = accuracy(torch.from_numpy(all_cls_probs), torch.from_numpy(all_cls_labels), topk=topk)
        #print("topk_accs:",topk_accs)
        for k in range(len(topk)):
            print('top{} acc: {:.3f}'.format(topk[k], topk_accs[k].item()))

    if len(np.unique(all_cls_labels)) == 1:  # 只有一个类别时 无法计算auc
        cls_auc = -1
        cls_aucs = []
    else:
        if args.n_classes == 2:
            '''
            y_true = np.array([0, 0, 1, 1])
            y_scores = np.array([0.1, 0.4, 0.35, 0.8])
            roc_auc_score(y_true, y_scores)
            0.75
            '''
            cls_auc = roc_auc_score(all_cls_labels, all_cls_probs[:, 1])
            cls_aucs = []
        else:
            cls_aucs = []
            cls_recalls=[]
            cls_precisions=[]
            fprs=[]
            tprs=[]
            binary_labels = label_binarize(all_cls_labels, classes=[i for i in range(args.n_classes)])
            binary_hats = label_binarize(all_cls_hats, classes=[i for i in range(args.n_classes)])
            print(all_cls_labels)
            print("all_cls_labels_shape",all_cls_labels.shape)
            print(all_cls_hats)
            print("all_cls_hats_shape",all_cls_hats.shape)

            a=precision_score(all_cls_labels,all_cls_hats,average=None)
            print("precision",a)
            b=recall_score(all_cls_labels,all_cls_hats,average=None)
            print("recall",b)
            c=f1_score(all_cls_labels,all_cls_hats,average=None)
            print("f1_score",c)
            d=average_precision_score(binary_labels,all_cls_probs,average=None)
            print("average_precision_score",d)
            e=roc_auc_score(binary_labels,all_cls_probs,average=None)
            print("roc_auc_score", e)



            a1=precision_score(all_cls_labels,all_cls_hats,average='micro')
            a2=precision_score(all_cls_labels,all_cls_hats,average='macro')
            a3=precision_score(all_cls_labels,all_cls_hats,average='weighted')
            print("precision micro;macro;weighted",a1,a2,a3)

            b1=recall_score(all_cls_labels,all_cls_hats,average='micro')
            b2=recall_score(all_cls_labels,all_cls_hats,average='macro')
            b3=recall_score(all_cls_labels,all_cls_hats,average='weighted')
            print("recall micro;macro;weighted",b1,b2,b3)

            c1 = f1_score(all_cls_labels, all_cls_hats, average='micro')
            c2= f1_score(all_cls_labels, all_cls_hats, average='macro')
            c3 = f1_score(all_cls_labels, all_cls_hats, average='weighted')
            print("f1_score micro;macro;weighted", c1, c2, c3)

            d1 = average_precision_score(binary_labels, all_cls_probs, average='micro')
            d2 = average_precision_score(binary_labels, all_cls_probs, average='macro')
            d3 = average_precision_score(binary_labels, all_cls_probs, average='weighted')
            print("AP micro;macro;weighted", d1, d2, d3)

            e1 = roc_auc_score(binary_labels, all_cls_probs, average='micro')
            e2 = roc_auc_score(binary_labels, all_cls_probs, average='macro')
            e3= roc_auc_score(binary_labels, all_cls_probs, average='weighted')
            print("roc_auc micro;macro;weighted", e1, e2, e3)

            for class_idx in range(args.n_classes):
                if class_idx in all_cls_labels:
                    fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_cls_probs[:, class_idx])
                    fprs.append(fpr)
                    tprs.append(tpr)
                    cls_aucs.append(auc(fpr, tpr))       # cacluate the every type auc

                    #precision=precision_score(binary_labels[:, class_idx], binary_hats[:, class_idx],average='micro')
                    #cls_precisions.append(precision)
                    #recall=recall_score(binary_labels[:, class_idx], binary_hats[:, class_idx],average='micro')
                    #cls_recalls.append(recall)
                else:
                    cls_aucs.append(float('nan'))
                    cls_recalls.append(float('nan'))
                    cls_precisions.append(float('nan'))
            # ref:https://www.cnblogs.com/laozhanghahaha/p/12499979.html
            print(cls_recalls)
            print(cls_precisions)
            # plt.figure(figsize=(4, 4))
            plt.plot(fprs[0], tprs[0], lw=1.5, label="Lung AUC=%.3f" % cls_aucs[0])
            plt.plot(fprs[1], tprs[1], lw=1.5, label="Skin AUC=%.3f" % cls_aucs[1])
            plt.plot(fprs[2], tprs[2], lw=1.5, label="Kidney AUC=%.3f" % cls_aucs[2])
            plt.plot(fprs[3], tprs[3], lw=1.5, label="Uterus Endometrium AUC=%.3f" % cls_aucs[3])
            plt.plot(fprs[4], tprs[4], lw=1.5, label="Pancreas AUC=%.3f" % cls_aucs[4])
            plt.plot(fprs[5], tprs[5], lw=1.5, label="Soft Tissue AUC=%.3f" % cls_aucs[5])
            plt.plot(fprs[6], tprs[6], lw=1.5, label="Head Neck AUC=%.3f" % cls_aucs[6])
            plt.plot(fprs[7], tprs[7], lw=1.5, label="Brain AUC=%.3f" % cls_aucs[7])
            #micro
            fpr_micro, tpr_micro, _ = roc_curve(binary_labels.ravel(), all_cls_probs.ravel())
            roc_auc_micro = auc(fpr_micro, tpr_micro)
            plt.plot(fpr_micro, tpr_micro, lw=1.5, label="micro AUC=%.3f" % roc_auc_micro)

            #macro
            all_fpr = np.unique(np.concatenate([fprs[i] for i in range(args.n_classes)]))
            mean_tpr = np.zeros_like(all_fpr)
            for i in range(args.n_classes):
                mean_tpr += interp(all_fpr, fprs[i], tprs[i])
            mean_tpr /= args.n_classes
            roc_auc_macro = auc(all_fpr, mean_tpr)
            plt.plot(all_fpr, mean_tpr, lw=1.5, label="macro AUC=%.3f" % roc_auc_macro)
            # plt.xlabel("FPR", fontsize=15)
            # plt.ylabel("TPR", fontsize=15)
            plt.xlabel("1-Specificity")
            plt.ylabel("Sensitivity")
            # plt.title("ROC")
            plt.legend(loc="lower right")
            # plt.gca().set_aspect(1)
            plt.savefig("eval_results_Transformer_test37/EVAL_dummy_mtl_sex_s1_eval/eight_type.pdf")
            


            if args.micro_average:
                #average=micro情况,就是计算以各类作为Positve时的预测正确TP的和再除以以各类作为Positve时的TP+FP
                #average=macro情况,与average=micro情况相对立,是先分别计算将各类视作Positive情况下的score,再求个平均
                binary_labels = label_binarize(all_cls_labels, classes=[i for i in range(args.n_classes)])
                valid_classes = np.where(np.any(binary_labels, axis=0))[0]   #类别索引
                '''
                from sklearn.preprocessing import label_binarize
                a=label_binarize([1, 6], classes=[1, 2, 4, 6])
                a
                array([[1, 0, 0, 0],
                    [0, 0, 0, 1]])
                valid_classes = np.where(np.any(a, axis=0))[0]
                valid_classes
                array([0, 3], dtype=int64)
                a=a[:,valid_classes]
                a
                array([[1, 0],
                    [0, 1]])
                a.ravel()
                array([1, 0, 0, 1])
                '''
                binary_labels = binary_labels[:, valid_classes]
                valid_cls_probs = all_cls_probs[:, valid_classes]
                fpr, tpr, _ = roc_curve(binary_labels.ravel(), valid_cls_probs.ravel())
                cls_auc = auc(fpr, tpr)
                print("micro_average_auc",cls_auc)
                plt.plot(fpr,tpr, lw=1.5, label="micro AUC=%.3f)" % cls_auc)

                plt.xlabel("FPR", fontsize=15)
                plt.ylabel("TPR", fontsize=15)
                plt.title("ROC")
                plt.legend(loc="lower right")
                plt.savefig("eval_results_Transformer_test37/EVAL_dummy_mtl_sex_s1_eval/all_type.png")


            else:
                cls_auc = np.nanmean(np.array(cls_aucs))
                print("macro_average_auc",cls_auc)

    
    '''
        cls_probs = Y_prob.cpu().numpy()  #预测的属于每一类别的概率
        all_cls_preds = np.argmax(all_cls_probs, axis=1)
        all_cls_labels[batch_idx] = label.item()
        
    slide_id,sex,Y,Y_hat,p_0,p_1
    C3L-02647-24,1.0,1.0,1,0.010934877209365368,0.9890650510787964
    C3L-04733-22,0.0,1.0,1,2.2340067516779527e-05,0.999977707862854

    '''
    results_dict = {'slide_id': slide_ids,  'Y': all_cls_labels, 'Y_hat': all_cls_preds}
    for c in range(args.n_classes):
        results_dict.update({'p_{}'.format(c): all_cls_probs[:,c]})


    df = pd.DataFrame(results_dict)
    '''
    patient_results--->patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'cls_prob': cls_probs, 'cls_label': label.item()}})
    cls_test_error---->cls_error = calculate_error(Y_hat, label)
                       cls_test_error += cls_error
    cls_auc:  ---->    cls_auc = roc_auc_score(all_cls_labels, all_cls_probs[:, 1])
    cls_aucs: ---->    cls_aucs = []
    cla_logger: ---->  acc, correct, count

    '''
    inference_results = {'patient_results': patient_results, 'cls_test_error': cls_test_error,
                     'cls_auc': cls_auc, 'cls_aucs': cls_aucs, 'loggers': (cls_logger), 'df':df}

    for k in range(len(topk)):
        inference_results.update({'top{}_acc'.format(topk[k]): topk_accs[k].item()})
    return inference_results

【模型并行代码参考链接】:https://zhuanlan.zhihu.com/p/87596314

3.在进行模型评估时,同样需要将输入数据data和类别标签label放置在gpu1上。【eval_utils_mtl_concat.py】

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.TransformerMIL_LeFF_MSA import TransformerMIL
#from models.model_toad import TOAD_fc_mtl_concat
import pdb
import os
import pandas as pd
from utils.utils import *
from utils.core_utils_mtl_concat import EarlyStopping,  Accuracy_Logger
from utils.file_utils import save_pkl, load_pkl
from sklearn.metrics import roc_auc_score, roc_curve, auc
from sklearn.metrics import precision_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import average_precision_score
import h5py
from models.resnet_custom import resnet50_baseline
import math
from sklearn.preprocessing import label_binarize
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt

from scipy import interp
#define the model
def initiate_model(args, ckpt_path=None):
    """Build a TransformerMIL model and optionally restore a checkpoint.

    Args:
        args: namespace providing at least ``n_classes``.
        ckpt_path: optional path to a saved state-dict checkpoint.

    Returns:
        The model in eval mode (device placement is left to the caller).
    """
    print('Init Model')

    model_dict = {'n_classes': args.n_classes}
    model = TransformerMIL(**model_dict)

    # model.relocate()  # NOTE(review): device placement is expected to happen
    # here (or in the caller); inputs fed in summary() must end up on the same
    # device(s) as the sub-modules that receive them — TODO confirm.
    print_network(model)

    if ckpt_path is not None:
        # map_location='cpu' makes loading independent of the device the
        # checkpoint was saved on (avoids crashes on CPU-only hosts and
        # accidental allocation on the wrong GPU); the model is moved to its
        # target device(s) afterwards.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        # NOTE(review): strict=False silently ignores missing/unexpected keys
        # (e.g. DataParallel 'module.' prefixes) — confirm this is intentional.
        model.load_state_dict(ckpt, strict=False)

    model.eval()
    return model

def eval(dataset, args, ckpt_path):
    """Evaluate a checkpointed model on ``dataset``.

    Builds the model from ``ckpt_path``, runs inference over the whole
    dataset (batch size defaults to 1) and prints the headline metrics.

    Returns:
        (model, results_dict) where results_dict is produced by summary().
    """
    # NOTE: this function shadows the builtin ``eval``; the name is kept for
    # backward compatibility with existing callers.
    model = initiate_model(args, ckpt_path)

    print('Init Loaders')
    eval_loader = get_simple_loader(dataset)
    results_dict = summary(model, eval_loader, args)

    print('cls_test_error: ', results_dict['cls_test_error'])
    print('cls_auc: ', results_dict['cls_auc'])

    return model, results_dict

# Code taken from pytorch/examples for evaluating topk classification on on ImageNet
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy for each requested value of k.

    Adapted from pytorch/examples ImageNet evaluation.

    Args:
        output: (batch, n_classes) tensor of class scores/probabilities.
        target: (batch,) tensor of ground-truth class indices (any numeric
            dtype; compared elementwise against integer prediction indices).
        topk: iterable of k values to evaluate.

    Returns:
        List of 1-element tensors, one fraction-correct value per k, in the
        same order as ``topk``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk highest-scoring classes: (batch, maxk) -> (maxk, batch).
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        # Broadcast-compare every top-k row against the target row.
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # A sample counts as correct if the target appears anywhere in its top-k.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(1.0 / batch_size))
        return res

def summary(model, loader, args):
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cls_logger = Accuracy_Logger(n_classes=args.n_classes)  #acc, correct, count
    #site_logger = Accuracy_Logger(n_classes=2)
    model.eval()
    cls_test_error = 0.
    cls_test_loss = 0.
    site_test_error = 0.
    site_test_loss = 0.

    '''
    all_cls_probs->  [batch size,n_classes]   会输出属于每一类别的概率
    all_cls_labels-> [[batch size]            true label only one
    自定义的数据集类MyDataset对象data的长度和 DataLoader对象data_loader的长度,我们会发现:data_loader的长度是data的长度除以batch_size。
    ref:https://blog.csdn.net/weixin_45901519/article/details/115672355
    '''
    all_cls_probs = np.zeros((len(loader), args.n_classes))  #loader-->[img, label, site, sex]  len(loader)->batch size;
    #print("all_cls_probs:",all_cls_probs.shape)
    all_cls_labels = np.zeros(len(loader))
    #print("all_cls_labels:",all_cls_labels.shape)
    all_cls_hats = np.zeros(len(loader))

    slide_ids = loader.dataset.slide_data['slide_id'] #所有的slide_id
    patient_results = {}

    for batch_idx, (data, label) in enumerate(loader):  #batch size=1   len(data_loader)=len(dataset)
        #data =  data.to(device)
        ##print("data shape:",data.shape)
        #label = label.to(device)

        
        # add the code for parellel
        data =  data.to('cuda:1')
        label = label.to('cuda:1')
        
        
        slide_id = slide_ids.iloc[batch_idx]
        with torch.no_grad(): #require_grad=false-->一是减少内存,二是可以把这个operator从computation graph中detach出来,这样就不会在BP过程中计算到。
            model_results_dict = model(data)  #组织切片提取的特征向量+sex
            #print("model_results_dict:",model_results_dict)

        '''
        logits-->   logits  = self.classifier(M[0].unsqueeze(0))   1×2  属于每一类别的概率
        Y_hat-->    Y_hat = torch.topk(logits, 1, dim = 1)[1]  #dim=0表示按照列求topn,dim=1表示按照行求topn  得到的是预测元素下标
        Y_prob-->   Y_prob = F.softmax(logits, dim = 1)        #按行softmax,行和为1     1×2  属于每一类别的概率
        '''
        logits, Y_prob, Y_hat  = model_results_dict['logits'], model_results_dict['Y_prob'], model_results_dict['Y_hat']
        del model_results_dict
        '''
        def log(self, Y_hat, Y):   
            Y_hat = int(Y_hat)
            Y = int(Y)
            self.data[Y]["count"] += 1
            self.data[Y]["correct"] += (Y_hat == Y)
        '''
        cls_logger.log(Y_hat, label)      #统计总类别数和预测正确数目
        cls_probs = Y_prob.cpu().numpy()  #预测的属于每一类别的概率
        cls_hats=Y_hat.cpu().numpy()
        all_cls_hats[batch_idx]=cls_hats
        all_cls_probs[batch_idx] = cls_probs
        all_cls_labels[batch_idx] = label.item()

        #all_sexes[batch_idx] = sex.item()

        
        patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'cls_prob': cls_probs, 'cls_label': label.item()}})
        '''
        def calculate_error(Y_hat, Y):
	        error = 1. - Y_hat.float().eq(Y.float()).float().mean().item()
	        return error
        '''
        cls_error = calculate_error(Y_hat, label)
        cls_test_error += cls_error

    cls_test_error /= len(loader)
    #print("cls_test_error:",cls_test_error)
    '''
    >>a = np.array([[1, 5, 5, 2],
               [9, 6, 2, 8],
               [3, 7, 9, 1]])
    >>np.argmax(a,axis=0)
    >>array([1, 2, 2, 1], dtype=int64)
    >>np.argmax(a,axis=1)
    >>array([1, 0, 2], dtype=int64)
    '''
    all_cls_preds = np.argmax(all_cls_probs, axis=1)
    topk=()
    if args.n_classes > 2:
        if args.n_classes > 5:
            topk = (1,3,5)
        else:
            topk = (1,3)
        
        #print("all_cls_probs shape:",all_cls_probs.shape)              # all_cls_probs shape: (215, 8)

        #print("all_cls_labels shape:",all_cls_labels.shape)            # all_cls_labels shape: (215,)
        topk_accs = accuracy(torch.from_numpy(all_cls_probs), torch.from_numpy(all_cls_labels), topk=topk)
        #print("topk_accs:",topk_accs)
        for k in range(len(topk)):
            print('top{} acc: {:.3f}'.format(topk[k], topk_accs[k].item()))

    if len(np.unique(all_cls_labels)) == 1:  # 只有一个类别时 无法计算auc
        cls_auc = -1
        cls_aucs = []
    else:
        if args.n_classes == 2:
            '''
            y_true = np.array([0, 0, 1, 1])
            y_scores = np.array([0.1, 0.4, 0.35, 0.8])
            roc_auc_score(y_true, y_scores)
            0.75
            '''
            cls_auc = roc_auc_score(all_cls_labels, all_cls_probs[:, 1])
            cls_aucs = []
        else:
            cls_aucs = []
            cls_recalls=[]
            cls_precisions=[]
            fprs=[]
            tprs=[]
            binary_labels = label_binarize(all_cls_labels, classes=[i for i in range(args.n_classes)])
            binary_hats = label_binarize(all_cls_hats, classes=[i for i in range(args.n_classes)])
            print(all_cls_labels)
            print("all_cls_labels_shape",all_cls_labels.shape)
            print(all_cls_hats)
            print("all_cls_hats_shape",all_cls_hats.shape)

            a=precision_score(all_cls_labels,all_cls_hats,average=None)
            print("precision",a)
            b=recall_score(all_cls_labels,all_cls_hats,average=None)
            print("recall",b)
            c=f1_score(all_cls_labels,all_cls_hats,average=None)
            print("f1_score",c)
            d=average_precision_score(binary_labels,all_cls_probs,average=None)
            print("average_precision_score",d)
            e=roc_auc_score(binary_labels,all_cls_probs,average=None)
            print("roc_auc_score", e)



            a1=precision_score(all_cls_labels,all_cls_hats,average='micro')
            a2=precision_score(all_cls_labels,all_cls_hats,average='macro')
            a3=precision_score(all_cls_labels,all_cls_hats,average='weighted')
            print("precision micro;macro;weighted",a1,a2,a3)

            b1=recall_score(all_cls_labels,all_cls_hats,average='micro')
            b2=recall_score(all_cls_labels,all_cls_hats,average='macro')
            b3=recall_score(all_cls_labels,all_cls_hats,average='weighted')
            print("recall micro;macro;weighted",b1,b2,b3)

            c1 = f1_score(all_cls_labels, all_cls_hats, average='micro')
            c2= f1_score(all_cls_labels, all_cls_hats, average='macro')
            c3 = f1_score(all_cls_labels, all_cls_hats, average='weighted')
            print("f1_score micro;macro;weighted", c1, c2, c3)

            d1 = average_precision_score(binary_labels, all_cls_probs, average='micro')
            d2 = average_precision_score(binary_labels, all_cls_probs, average='macro')
            d3 = average_precision_score(binary_labels, all_cls_probs, average='weighted')
            print("AP micro;macro;weighted", d1, d2, d3)

            e1 = roc_auc_score(binary_labels, all_cls_probs, average='micro')
            e2 = roc_auc_score(binary_labels, all_cls_probs, average='macro')
            e3= roc_auc_score(binary_labels, all_cls_probs, average='weighted')
            print("roc_auc micro;macro;weighted", e1, e2, e3)

            for class_idx in range(args.n_classes):
                if class_idx in all_cls_labels:
                    fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_cls_probs[:, class_idx])
                    fprs.append(fpr)
                    tprs.append(tpr)
                    cls_aucs.append(auc(fpr, tpr))       # cacluate the every type auc

                    #precision=precision_score(binary_labels[:, class_idx], binary_hats[:, class_idx],average='micro')
                    #cls_precisions.append(precision)
                    #recall=recall_score(binary_labels[:, class_idx], binary_hats[:, class_idx],average='micro')
                    #cls_recalls.append(recall)
                else:
                    cls_aucs.append(float('nan'))
                    cls_recalls.append(float('nan'))
                    cls_precisions.append(float('nan'))
            # ref:https://www.cnblogs.com/laozhanghahaha/p/12499979.html
            print(cls_recalls)
            print(cls_precisions)
            # plt.figure(figsize=(4, 4))
            plt.plot(fprs[0], tprs[0], lw=1.5, label="Lung AUC=%.3f" % cls_aucs[0])
            plt.plot(fprs[1], tprs[1], lw=1.5, label="Skin AUC=%.3f" % cls_aucs[1])
            plt.plot(fprs[2], tprs[2], lw=1.5, label="Kidney AUC=%.3f" % cls_aucs[2])
            plt.plot(fprs[3], tprs[3], lw=1.5, label="Uterus Endometrium AUC=%.3f" % cls_aucs[3])
            plt.plot(fprs[4], tprs[4], lw=1.5, label="Pancreas AUC=%.3f" % cls_aucs[4])
            plt.plot(fprs[5], tprs[5], lw=1.5, label="Soft Tissue AUC=%.3f" % cls_aucs[5])
            plt.plot(fprs[6], tprs[6], lw=1.5, label="Head Neck AUC=%.3f" % cls_aucs[6])
            plt.plot(fprs[7], tprs[7], lw=1.5, label="Brain AUC=%.3f" % cls_aucs[7])
            #micro
            fpr_micro, tpr_micro, _ = roc_curve(binary_labels.ravel(), all_cls_probs.ravel())
            roc_auc_micro = auc(fpr_micro, tpr_micro)
            plt.plot(fpr_micro, tpr_micro, lw=1.5, label="micro AUC=%.3f" % roc_auc_micro)

            #macro
            all_fpr = np.unique(np.concatenate([fprs[i] for i in range(args.n_classes)]))
            mean_tpr = np.zeros_like(all_fpr)
            for i in range(args.n_classes):
                mean_tpr += interp(all_fpr, fprs[i], tprs[i])
            mean_tpr /= args.n_classes
            roc_auc_macro = auc(all_fpr, mean_tpr)
            plt.plot(all_fpr, mean_tpr, lw=1.5, label="macro AUC=%.3f" % roc_auc_macro)
            # plt.xlabel("FPR", fontsize=15)
            # plt.ylabel("TPR", fontsize=15)
            plt.xlabel("1-Specificity")
            plt.ylabel("Sensitivity")
            # plt.title("ROC")
            plt.legend(loc="lower right")
            # plt.gca().set_aspect(1)
            plt.savefig("eval_results_Transformer_test37/EVAL_dummy_mtl_sex_s1_eval/eight_type.pdf")
            


            if args.micro_average:
                #average=micro情况,就是计算以各类作为Positve时的预测正确TP的和再除以以各类作为Positve时的TP+FP
                #average=macro情况,与average=micro情况相对立,是先分别计算将各类视作Positive情况下的score,再求个平均
                binary_labels = label_binarize(all_cls_labels, classes=[i for i in range(args.n_classes)])
                valid_classes = np.where(np.any(binary_labels, axis=0))[0]   #类别索引
                '''
                from sklearn.preprocessing import label_binarize
                a=label_binarize([1, 6], classes=[1, 2, 4, 6])
                a
                array([[1, 0, 0, 0],
                    [0, 0, 0, 1]])
                valid_classes = np.where(np.any(a, axis=0))[0]
                valid_classes
                array([0, 3], dtype=int64)
                a=a[:,valid_classes]
                a
                array([[1, 0],
                    [0, 1]])
                a.ravel()
                array([1, 0, 0, 1])
                '''
                binary_labels = binary_labels[:, valid_classes]
                valid_cls_probs = all_cls_probs[:, valid_classes]
                fpr, tpr, _ = roc_curve(binary_labels.ravel(), valid_cls_probs.ravel())
                cls_auc = auc(fpr, tpr)
                print("micro_average_auc",cls_auc)
                plt.plot(fpr,tpr, lw=1.5, label="micro AUC=%.3f)" % cls_auc)

                plt.xlabel("FPR", fontsize=15)
                plt.ylabel("TPR", fontsize=15)
                plt.title("ROC")
                plt.legend(loc="lower right")
                plt.savefig("eval_results_Transformer_test37/EVAL_dummy_mtl_sex_s1_eval/all_type.png")


            else:
                cls_auc = np.nanmean(np.array(cls_aucs))
                print("macro_average_auc",cls_auc)

    
    '''
        cls_probs = Y_prob.cpu().numpy()  #预测的属于每一类别的概率
        all_cls_preds = np.argmax(all_cls_probs, axis=1)
        all_cls_labels[batch_idx] = label.item()
        
    slide_id,sex,Y,Y_hat,p_0,p_1
    C3L-02647-24,1.0,1.0,1,0.010934877209365368,0.9890650510787964
    C3L-04733-22,0.0,1.0,1,2.2340067516779527e-05,0.999977707862854

    '''
    results_dict = {'slide_id': slide_ids,  'Y': all_cls_labels, 'Y_hat': all_cls_preds}
    for c in range(args.n_classes):
        results_dict.update({'p_{}'.format(c): all_cls_probs[:,c]})


    df = pd.DataFrame(results_dict)
    '''
    patient_results--->patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'cls_prob': cls_probs, 'cls_label': label.item()}})
    cls_test_error---->cls_error = calculate_error(Y_hat, label)
                       cls_test_error += cls_error
    cls_auc:  ---->    cls_auc = roc_auc_score(all_cls_labels, all_cls_probs[:, 1])
    cls_aucs: ---->    cls_aucs = []
    cla_logger: ---->  acc, correct, count

    '''
    inference_results = {'patient_results': patient_results, 'cls_test_error': cls_test_error,
                     'cls_auc': cls_auc, 'cls_aucs': cls_aucs, 'loggers': (cls_logger), 'df':df}

    for k in range(len(topk)):
        inference_results.update({'top{}_acc'.format(topk[k]): topk_accs[k].item()})
    return inference_results

进行修改后仍会出现错误:Traceback (most recent call last): File "eval_mtl_concat_all_type_transformer.py", line 175, in model, results_dict = eval(split_dataset, args, ckpt_paths[ckpt_idx]) # s_0_checkpoint.pt File "/data/luanhaijing/project/tissue_process_pipeline_origin/utils/eval_utils_mtl_concat.py", line 54, in eval results_dict = summary(model, loader, args) # inference_results = {'patient_results': patient_results, 'cls_test_error': cls_test_error, 'cls_auc': cls_auc, 'cls_aucs': cls_aucs, 'loggers': (cls_logger), 'df':df} File "/data/luanhaijing/project/tissue_process_pipeline_origin/utils/eval_utils_mtl_concat.py", line 121, in summary model_results_dict = model(data) #组织切片提取的特征向量+sex File "/data/luanhaijing/project/new_toad/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl return forward_call(*input, **kwargs) File "/data/luanhaijing/project/tissue_process_pipeline_origin/models/TransformerMIL_LeFF_MSA.py", line 425, in forward h = self.layer2(h) #[B, N, 512] File "/data/luanhaijing/project/new_toad/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl return forward_call(*input, **kwargs) File "/data/luanhaijing/project/tissue_process_pipeline_origin/models/TransformerMIL_LeFF_MSA.py", line 247, in forward x = x + self.attn(self.norm(x)) File "/data/luanhaijing/project/new_toad/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl return forward_call(*input, **kwargs) File "/data/luanhaijing/project/new_toad/lib/python3.7/site-packages/torch/nn/modules/normalization.py", line 191, in forward input, self.normalized_shape, self.weight, self.bias, self.eps) File "/data/luanhaijing/project/new_toad/lib/python3.7/site-packages/torch/nn/functional.py", line 2515, in layer_norm return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled) RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! 
(when checking argument for argument weight in method wrapper__native_layer_norm)。由此可见问题出在模型的放置上:仅执行 model = TransformerMIL(**model_dict) 创建模型还不够,创建之后还需要将模型的各个子模块显式放置到对应的 GPU 上(例如调用 model.relocate()),并保证输入数据 data 与标签 label 所在的设备和接收它们的子模块一致,否则就会出现上述 cuda:0 与 cuda:1 设备不一致的错误。

你可能感兴趣的:(batch,深度学习,python)