Clinical-grade computational pathology using weakly supervised deep learning on whole slide images

论文:链接:https://pan.baidu.com/s/1eUQt4h6lvRYGVfe9QoEQAw 提取码:pgte 

代码:https://github.com/MSKCC-Computational-Pathology/MIL-nature-medicine-2019

数据:https://github.com/ThoroughImages/CAMEL

MIL_train文件

参数说明

  • "slides": list of full paths to WSIs (e.g. my/full/path/slide01.svs). Size of the list equals the number of slides.
  • "grid": list of a list of tuple (x,y) coordinates. Size of the list is equal to the number of slides. Size of each sublist is equal to the number of tiles in each slide. An example grid list containing two slides, one with 3 tiles and one with 4 tiles:
grid = [
        [(x1_1, y1_1),
	 (x1_2, y1_2),
	 (x1_3, y1_3)],
	[(x2_1, y2_1),
	 (x2_2, y2_2),
	 (x2_3, y2_3),
	 (x2_4, y2_4)],
]
  • "targets": list of slide level class (0: benign slide, 1: tumor slide). Size of the list equals the number of slides.
  • "mult": scale factor (float) for achieving resolutions different than the ones saved in the WSI pyramid file. Usually 1. for no scaling.
  • "level": WSI pyramid level (integer) from which to read the tiles. Usually 0 for the highest resolution.

翻译:

"slides":存储了一系列切片的绝对路径,其本质是保存了字符串的一个列表,列表长度为切片个数。

"grid":可以理解是一个3维数组,3个维度分别表示(切片的总数量,单个切片中有多少个小块也就是上面提到的tile,每个小块左上角的坐标值)

"targets":表示每个切片的标签,阳性为1,阴性为0。具体形状为(切片的总数量,)

"mult":缩放因子,好像没用到,直接设置的为1。

"level":切片层数,具体可以参考https://blog.csdn.net/u013066730/article/details/85049240

方法介绍

介绍下做这个的思想:假设有1W张切片,每张都有阴性或阳性的标签。假设第一张切片是阳性,取出它的所有小块(取小块的方式为:首先用otsu找到有组织的区域,然后按照224*224的尺寸取小块),进行一次前向,选取阳性概率最大的前1个小块(1也可以换成2、3,表示概率最大的前k个小块),用这个小块来代表这张切片,那么该小块的标签即为1。假设第二张切片是阴性,同样取出它的所有小块,进行一次前向,选取阳性概率最大的前1个小块,用这个小块来代表这张切片,那么该小块的标签即为0。以此类推,1W张切片正好取出1W个小块,每个小块都有阴阳性标签,使用这1W个小块去训练res34。这样就完成一次迭代。

接着再按照上述方式迭代100次,即可完成模型训练。

import sys
import os
import numpy as np
import argparse
import random
import openslide
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models

# Command-line interface of the MIL tile-classifier training script.
parser = argparse.ArgumentParser(description='MIL-nature-medicine-2019 tile classifier training script')
parser.add_argument('--train_lib', type=str, default='', help='path to train MIL library binary')
parser.add_argument('--val_lib', type=str, default='', help='path to validation MIL library binary. If present.')
# --output is joined with fixed file names below, so it names a directory, not a file.
parser.add_argument('--output', type=str, default='.', help='directory in which convergence.csv and checkpoint_best.pth are written')
parser.add_argument('--batch_size', type=int, default=512, help='mini-batch size (default: 512)')
parser.add_argument('--nepochs', type=int, default=100, help='number of epochs')
parser.add_argument('--workers', default=4, type=int, help='number of data loading workers (default: 4)')
parser.add_argument('--test_every', default=10, type=int, help='test on val every (default: 10)')
parser.add_argument('--weights', default=0.5, type=float, help='unbalanced positive class weight (default: 0.5, balanced classes)')
parser.add_argument('--k', default=1, type=int, help='top k tiles are assumed to be of the same class as the slide (default: 1, standard MIL)')

# Best balanced validation accuracy seen so far; updated by main().
best_acc = 0
def main():
    """Train the tile-level MIL classifier.

    Every epoch runs three steps:
      1. score every tile of every training slide with the current model;
      2. keep the ``args.k`` highest-scoring tiles of each slide and label
         them with the slide-level target;
      3. train the network on that selected subset.

    Every ``args.test_every`` epochs the validation library (if given) is
    scored, a slide is called positive when its highest tile probability is
    >= 0.5, and the checkpoint with the best balanced accuracy
    (1 - (FPR + FNR) / 2) is written to ``checkpoint_best.pth``.

    Side effects: sets the module globals ``args`` and ``best_acc`` and writes
    ``convergence.csv`` / ``checkpoint_best.pth`` into ``args.output``.
    """
    global args, best_acc
    args = parser.parse_args()
    if not args.train_lib:
        # Fail with a usage message instead of an opaque error from torch.load('').
        parser.error('--train_lib is required')

    #cnn
    model = models.resnet34(True)
    model.fc = nn.Linear(model.fc.in_features, 2)
    model.cuda()

    if args.weights==0.5:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        # Class weights are [negative, positive].
        w = torch.Tensor([1-args.weights,args.weights])
        criterion = nn.CrossEntropyLoss(w).cuda()
    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

    cudnn.benchmark = True

    #normalization
    normalize = transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.1,0.1,0.1])
    trans = transforms.Compose([transforms.ToTensor(), normalize])

    #load data
    # shuffle stays False: inference relies on tile order, and the training
    # subset is shuffled explicitly through the dataset itself.
    train_dset = MILdataset(args.train_lib, trans)
    train_loader = torch.utils.data.DataLoader(
        train_dset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=False)
    if args.val_lib:
        val_dset = MILdataset(args.val_lib, trans)
        val_loader = torch.utils.data.DataLoader(
            val_dset,
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=False)

    #open output file
    conv_path = os.path.join(args.output, 'convergence.csv')
    with open(conv_path, 'w') as fconv:
        fconv.write('epoch,metric,value\n')

    #loop through epochs
    for epoch in range(args.nepochs):
        # Mode 1: iterate over every tile to obtain its tumor probability.
        train_dset.setmode(1)
        probs = inference(epoch, train_loader, model)
        # Flat indices of the k most suspicious tiles of each slide.
        topk = group_argtopk(np.array(train_dset.slideIDX), probs, args.k)
        train_dset.maketraindata(topk)
        train_dset.shuffletraindata()
        # Mode 2: iterate over the selected (tile, slide label) pairs only.
        train_dset.setmode(2)
        loss = train(epoch, train_loader, model, criterion, optimizer)
        print('Training\tEpoch: [{}/{}]\tLoss: {}'.format(epoch+1, args.nepochs, loss))
        with open(conv_path, 'a') as fconv:
            fconv.write('{},loss,{}\n'.format(epoch+1,loss))

        #Validation
        if args.val_lib and (epoch+1) % args.test_every == 0:
            val_dset.setmode(1)
            probs = inference(epoch, val_loader, model)
            # Slide-level score = maximum tile probability of the slide.
            maxs = group_max(np.array(val_dset.slideIDX), probs, len(val_dset.targets))
            pred = [1 if x >= 0.5 else 0 for x in maxs]
            err,fpr,fnr = calc_err(pred, val_dset.targets)
            print('Validation\tEpoch: [{}/{}]\tError: {}\tFPR: {}\tFNR: {}'.format(epoch+1, args.nepochs, err, fpr, fnr))
            with open(conv_path, 'a') as fconv:
                fconv.write('{},error,{}\n'.format(epoch+1, err))
                fconv.write('{},fpr,{}\n'.format(epoch+1, fpr))
                fconv.write('{},fnr,{}\n'.format(epoch+1, fnr))
            #Save best model
            # Model selection uses the balanced error, not the raw error above.
            balanced_err = (fpr+fnr)/2.
            if 1-balanced_err >= best_acc:
                best_acc = 1-balanced_err
                obj = {
                    'epoch': epoch+1,
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer' : optimizer.state_dict()
                }
                torch.save(obj, os.path.join(args.output,'checkpoint_best.pth'))

def inference(run, loader, model):
    """Return the positive-class probability of every sample in ``loader``.

    Args:
        run: zero-based epoch index, used only for the progress message.
        loader: DataLoader yielding image batches in dataset order
            (it must not shuffle, results are stored by position).
        model: classifier whose output has the positive class in column 1.

    Returns:
        1-D numpy array with one probability per item of ``loader.dataset``.
    """
    model.eval()
    # One slot per tile in the whole dataset (not per batch).
    probs = torch.FloatTensor(len(loader.dataset))
    # Track the write position from the batches actually received instead of
    # assuming every batch has args.batch_size items.
    offset = 0
    with torch.no_grad():
        for i, input in enumerate(loader):
            print('Inference\tEpoch: [{}/{}]\tBatch: [{}/{}]'.format(run+1, args.nepochs, i+1, len(loader)))
            input = input.cuda()
            output = F.softmax(model(input), dim=1)
            count = input.size(0)
            # Column 1 is the probability of the positive (tumor) class.
            probs[offset:offset+count] = output.detach()[:,1].clone()
            offset += count
    return probs.cpu().numpy()

def train(run, loader, model, criterion, optimizer):
    """Run one optimisation pass over ``loader``.

    Args:
        run: epoch index (unused here, kept for a uniform call signature).
        loader: DataLoader yielding ``(images, labels)`` batches.
        model: network to optimise.
        criterion: loss taking ``(logits, labels)``.
        optimizer: optimiser bound to the model parameters.

    Returns:
        Sample-weighted mean loss over ``loader.dataset``.
    """
    model.train()
    total_loss = 0.
    for images, labels in loader:
        images, labels = images.cuda(), labels.cuda()
        batch_loss = criterion(model(images), labels)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller last batch does not skew the mean.
        total_loss += batch_loss.item()*images.size(0)
    return total_loss/len(loader.dataset)

def calc_err(pred,real):
    """Compute error rate, false-positive rate and false-negative rate.

    Args:
        pred: sequence of predicted labels (0 or 1).
        real: sequence of ground-truth labels (0 or 1), same length as ``pred``.

    Returns:
        Tuple ``(err, fpr, fnr)`` of floats. A rate whose denominator is zero
        (empty input, or no negatives / no positives in ``real``) is reported
        as 0.0, so it never becomes nan/inf and poisons the model-selection
        comparison that consumes these values.
    """
    pred = np.asarray(pred)
    real = np.asarray(real)
    neq = np.not_equal(pred, real)

    def _rate(count, total):
        # An empty class cannot produce errors of that kind.
        return float(count)/float(total) if total else 0.0

    err = _rate(neq.sum(), pred.shape[0])
    fpr = _rate(np.logical_and(pred==1,neq).sum(), (real==0).sum())
    fnr = _rate(np.logical_and(pred==0,neq).sum(), (real==1).sum())
    return err, fpr, fnr

def group_argtopk(groups, data,k=1):
    """Return the indices of the ``k`` largest ``data`` values of every group.

    Args:
        groups: 1-D array-like of integer group ids (here: slide index of each
            tile), e.g. ``[0, 0, 0, 1, 1, 2]``.
        data: 1-D array-like of scores, same length as ``groups``.
        k: number of top entries to keep per group (must be >= 1). Groups with
            fewer than ``k`` members contribute all of their members.

    Returns:
        List of indices into the original arrays, ordered by group id and,
        inside a group, by ascending score.

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError('k must be >= 1, got {}'.format(k))
    groups = np.asarray(groups)
    data = np.asarray(data)
    # Sort by group first and by score inside each group, so the best k
    # entries of a group are the last k entries of its run.
    order = np.lexsort((data, groups))
    sorted_groups = groups[order]
    # The last k entries always belong to the top-k of the final group.
    keep = np.ones(len(sorted_groups), dtype=bool)
    # An entry is within the top-k of its group exactly when the entry
    # k positions later already belongs to another group.
    keep[:-k] = sorted_groups[k:] != sorted_groups[:-k]
    return list(order[keep])

def group_max(groups, data, nmax):
    """Return the maximum ``data`` value of every group.

    Args:
        groups: 1-D array-like of integer group ids in ``[0, nmax)``.
        data: 1-D array-like of scores, same length as ``groups``.
        nmax: total number of groups; length of the result.

    Returns:
        Float array of length ``nmax``; position ``g`` holds the largest score
        of group ``g`` or nan when the group has no entries.
    """
    out = np.full(nmax, np.nan)
    groups = np.asarray(groups)
    data = np.asarray(data)
    if groups.size == 0:
        # Nothing to aggregate: every group stays nan.
        return out
    # After sorting by (group, score) the maximum of a group is the last
    # element of its run.
    order = np.lexsort((data, groups))
    groups = groups[order]
    data = data[order]
    is_last = np.ones(len(groups), dtype=bool)
    is_last[:-1] = groups[1:] != groups[:-1]
    out[groups[is_last]] = data[is_last]
    return out

class MILdataset(data.Dataset):
    """Tile dataset over a library of whole-slide images.

    The library file is a ``torch.save``'d dict with the keys ``slides``
    (paths), ``grid`` (per-slide lists of tile coordinates), ``targets``
    (per-slide labels), ``mult`` (scale factor) and ``level`` (pyramid level).

    The dataset has two modes, selected with :meth:`setmode`:
      * mode 1 - inference: iterates over every tile, yields the image only;
      * mode 2 - training: iterates over the subset chosen with
        :meth:`maketraindata`, yields ``(image, slide target)``.
    """
    def __init__(self, libraryfile='', transform=None):
        lib = torch.load(libraryfile)
        slides = []
        for i,name in enumerate(lib['slides']):
            sys.stdout.write('Opening SVS headers: [{}/{}]\r'.format(i+1, len(lib['slides'])))
            sys.stdout.flush()
            slides.append(openslide.OpenSlide(name))
        print('')
        #Flatten grid
        grid = []
        slideIDX = []
        for i,g in enumerate(lib['grid']):
            # All tile coordinates of all slides in one flat list ...
            grid.extend(g)
            # ... and, in parallel, the index of the slide each tile belongs to.
            slideIDX.extend([i]*len(g))

        print('Number of tiles: {}'.format(len(grid)))
        self.slidenames = lib['slides']
        self.slides = slides
        self.targets = lib['targets']
        self.grid = grid
        self.slideIDX = slideIDX
        self.transform = transform
        self.mode = None
        # Scale factor; tiles are read at 224*mult pixels and resized to 224.
        self.mult = lib['mult']
        self.size = int(np.round(224*lib['mult']))
        # Pyramid level passed to read_region.
        self.level = lib['level']
    def setmode(self,mode):
        """Select inference (1) or training (2) iteration."""
        self.mode = mode
    def maketraindata(self, idxs):
        """Build the training subset from flat tile indices ``idxs``.

        Each entry becomes ``(slide index, tile coordinate, slide target)``.
        """
        self.t_data = [(self.slideIDX[x],self.grid[x],self.targets[self.slideIDX[x]]) for x in idxs]
    def shuffletraindata(self):
        """Randomly permute the training subset."""
        self.t_data = random.sample(self.t_data, len(self.t_data))
    def _read_tile(self, slide_idx, coord):
        """Read, optionally rescale and transform one tile of one slide."""
        img = self.slides[slide_idx].read_region(coord,self.level,(self.size,self.size)).convert('RGB')
        if self.mult != 1:
            img = img.resize((224,224),Image.BILINEAR)
        if self.transform is not None:
            img = self.transform(img)
        return img
    def __getitem__(self,index):
        if self.mode == 1:
            return self._read_tile(self.slideIDX[index], self.grid[index])
        if self.mode == 2:
            slide_idx, coord, target = self.t_data[index]
            return self._read_tile(slide_idx, coord), target
        # Previously fell through and returned None, which only failed later
        # during batch collation with an unrelated message.
        raise ValueError('MILdataset mode must be 1 or 2 (call setmode first), got {!r}'.format(self.mode))
    def __len__(self):
        if self.mode == 1:
            return len(self.grid)
        if self.mode == 2:
            return len(self.t_data)
        # len() raised TypeError here before (None is not an int); keep the
        # exception type but say what is wrong.
        raise TypeError('MILdataset has no length until setmode(1) or setmode(2) is called')

# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

RNN_train文件

参数说明

  • "slides": list of full paths to WSIs (e.g. my/full/path/slide01.svs). Size of the list equals the number of slides.
  • "grid": list of a list of tuple (x,y) coordinates. Size of the list is equal to the number of slides. Size of each sublist is equal to the number of maximum number of recurrent steps (we used 10). Each sublist is in decreasing order of tumor probability.
  • "targets": list of slide level class (0: benign slide, 1: tumor slide). Size of the list equals the number of slides.
  • "mult": scale factor (float) for achieving resolutions different than the ones saved in the WSI pyramid file. Usually 1. for no scaling.
  • "level": WSI pyramid level (integer) from which to read the tiles. Usually 0 for the highest resolution.

翻译:

"slides":存储了一系列切片的绝对路径,其本质是保存了字符串的一个列表,列表长度为切片个数。

"grid":可以理解是一个3维数组,3个维度分别表示(切片的总数量,单个切片中参与循环的小块数量(即最大循环步数,文中取10,小块按肿瘤概率从高到低排列),每个小块左上角的坐标值)

"targets":表示每个切片的标签,阳性为1,阴性为0。具体形状为(切片的总数量,)

"mult":缩放因子,好像没用到,直接设置的为1。

"level":切片层数,具体可以参考https://blog.csdn.net/u013066730/article/details/85049240

方法介绍

介绍下做这个的思想,假设1W张切片,都各自有阴阳性的标签。使用MIL_train训练好的模型对每张切片进行测试,将每个切片概率最高的前10个小块取出来(这个10就是代码中的s,这里只是为了方便说明),就用这10个小块来代表这个切片,这时代表每个切片的小块组成的数据格式为(10,3,224,224)。

然后选取batchsize个切片的数据,那么此时数据形状为(batchsize,10,3,224,224)。当batchsize=1时,即表示将一个切片中的10个小块进行rnn操作,也就是将这10个的特征进行了选择性融合,融合后进行最终的分类,类别数为2,即阴阳性。

这里重点提一下rnn_train训练时为啥数据形状为(s,batchsize,3,height,width)。

首先pytorch在读取数据的时候,会进入torch\utils\data\dataloader.py文件下的_worker_loop函数,每次取出batchsize个索引;可取索引的总数(即数据集的长度)由RNN_train文件中rnndata类的__len__函数给出:

def __len__(self):
    
    return len(self.targets)

然后根据索引读取图片,到_worker_loop中的去读取数据,具体代码如下:

try:
    samples = collate_fn([dataset[i] for i in batch_indices])
except Exception:
    # It is important that we don't store exc_info in a variable,
    # see NOTE [ Python Traceback Reference Cycle Problem ]
    data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
else:
    data_queue.put((idx, samples))
    del samples

其中dataset就是你自己定义的数据类rnndata,dataset[i]这时调用的就是rnndata类中的__getitem__函数用于读取数据。__getitem__函数返回的小块列表长度为s,每个小块的形状为(3,height,width),即整体形状为(s,3,height,width),那么[dataset[i] for i in batch_indices]这句代码最终得到的形状为(batchsize, s, 3, height, width),此时将这个list送入到collate_fn函数,其源代码如下:

    elif isinstance(batch[0], container_abcs.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

因为他是一个列表,所以进入到的是这个判断条件,然后通过zip(*)实现了维度的调换(这里zip不理解可以参考https://blog.csdn.net/u013066730/article/details/59006113),使得形状变为(s, batchsize, 3, height, width)。

这时再理解rnn_train中的train_single中的

for s in range(len(inputs)):

就好理解多了。

import os
import sys
import openslide
from PIL import Image
import numpy as np
import random
import argparse
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models

# Command-line interface of the RNN aggregator training script.
parser = argparse.ArgumentParser(description='MIL-nature-medicine-2019 RNN aggregator training script')
parser.add_argument('--train_lib', type=str, default='', help='path to train MIL library binary')
parser.add_argument('--val_lib', type=str, default='', help='path to validation MIL library binary. If present.')
# --output is joined with fixed file names below, so it names a directory, not a file.
parser.add_argument('--output', type=str, default='.', help='directory in which convergence.csv and rnn_checkpoint_best.pth are written')
parser.add_argument('--batch_size', type=int, default=128, help='mini-batch size (default: 128)')
parser.add_argument('--nepochs', type=int, default=100, help='number of epochs')
parser.add_argument('--workers', default=4, type=int, help='number of data loading workers (default: 4)')
parser.add_argument('--s', default=10, type=int, help='how many top k tiles to consider (default: 10)')
parser.add_argument('--ndims', default=128, type=int, help='length of hidden representation (default: 128)')
parser.add_argument('--model', type=str, help='path to trained model checkpoint')
parser.add_argument('--weights', default=0.5, type=float, help='unbalanced positive class weight (default: 0.5, balanced classes)')
parser.add_argument('--shuffle', action='store_true', help='to shuffle order of sequence')

# Best balanced validation accuracy seen so far; updated by main().
best_acc = 0
def main():
    """Train the RNN slide-level aggregator on top of a frozen tile encoder.

    Loads the training and validation libraries, wraps the trained tile
    classifier given by ``--model`` as a frozen feature extractor, and trains
    ``rnn_single`` for ``args.nepochs`` epochs. After every epoch the metrics
    are appended to ``convergence.csv`` and the RNN weights with the best
    balanced validation accuracy (1 - (FPR + FNR) / 2) are written to
    ``rnn_checkpoint_best.pth``.

    Side effects: sets the module globals ``args`` and ``best_acc`` and writes
    files into ``args.output``.
    """
    global args, best_acc
    args = parser.parse_args()

    # All three paths are used unconditionally below; report missing ones with
    # a usage message before spending time opening slide files.
    missing = [flag for flag, value in (('--train_lib', args.train_lib),
                                        ('--val_lib', args.val_lib),
                                        ('--model', args.model)) if not value]
    if missing:
        parser.error('missing required argument(s): {}'.format(', '.join(missing)))

    #load libraries
    normalize = transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.1,0.1,0.1])
    trans = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])
    train_dset = rnndata(args.train_lib, args.s, args.shuffle, trans)
    train_loader = torch.utils.data.DataLoader(
        train_dset,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=False)
    # Validation never shuffles the tile order inside a slide.
    val_dset = rnndata(args.val_lib, args.s, False, trans)
    val_loader = torch.utils.data.DataLoader(
        val_dset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=False)

    #make model
    # The tile encoder is frozen: only the RNN is optimised.
    embedder = ResNetEncoder(args.model)
    for param in embedder.parameters():
        param.requires_grad = False
    embedder = embedder.cuda()
    embedder.eval()

    rnn = rnn_single(args.ndims)
    rnn = rnn.cuda()

    #optimization
    if args.weights==0.5:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        # Class weights are [negative, positive].
        w = torch.Tensor([1-args.weights,args.weights])
        criterion = nn.CrossEntropyLoss(w).cuda()
    optimizer = optim.SGD(rnn.parameters(), 0.1, momentum=0.9, dampening=0, weight_decay=1e-4, nesterov=True)
    cudnn.benchmark = True

    conv_path = os.path.join(args.output, 'convergence.csv')
    with open(conv_path, 'w') as fconv:
        fconv.write('epoch,train.loss,train.fpr,train.fnr,val.loss,val.fpr,val.fnr\n')

    for epoch in range(args.nepochs):

        train_loss, train_fpr, train_fnr = train_single(epoch, embedder, rnn, train_loader, criterion, optimizer)
        val_loss, val_fpr, val_fnr = test_single(epoch, embedder, rnn, val_loader, criterion)

        with open(conv_path, 'a') as fconv:
            fconv.write('{},{},{},{},{},{},{}\n'.format(epoch+1, train_loss, train_fpr, train_fnr, val_loss, val_fpr, val_fnr))

        # Keep the checkpoint with the best balanced validation accuracy.
        val_err = (val_fpr + val_fnr)/2
        if 1-val_err >= best_acc:
            best_acc = 1-val_err
            obj = {
                'epoch': epoch+1,
                'state_dict': rnn.state_dict()
            }
            torch.save(obj, os.path.join(args.output,'rnn_checkpoint_best.pth'))

def train_single(epoch, embedder, rnn, loader, criterion, optimizer):
    """Train the RNN aggregator for one epoch.

    Args:
        epoch: zero-based epoch index (progress messages only).
        embedder: frozen tile encoder returning ``(logits, features)``.
        rnn: recurrent aggregator returning ``(output, state)``.
        loader: DataLoader yielding ``(inputs, target)`` where ``inputs`` is a
            sequence of per-step tile batches, i.e. laid out as
            (s, batchsize, 3, height, width).
        criterion: loss taking ``(logits, target)``.
        optimizer: optimiser bound to the RNN parameters.

    Returns:
        ``(mean loss, false-positive rate, false-negative rate)`` over the
        whole dataset.
    """
    rnn.train()
    loss_sum = 0.
    fp_total = 0.
    fn_total = 0.

    for step, (inputs, target) in enumerate(loader):
        print('Training - Epoch: [{}/{}]\tBatch: [{}/{}]'.format(epoch+1, args.nepochs, step+1, len(loader)))

        rnn.zero_grad()

        # Feed the tiles of every slide one step at a time; the hidden state
        # accumulates information across the steps.
        state = rnn.init_hidden(inputs[0].size(0)).cuda()
        for tiles in inputs:
            _, features = embedder(tiles.cuda())
            # Only the output of the final step is used for the loss; earlier
            # steps contribute through the state, shaped (batchsize, ndims).
            output, state = rnn(features, state)

        target = target.cuda()
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()*target.size(0)
        batch_fps, batch_fns = errors(output.detach(), target.cpu())
        fp_total += batch_fps
        fn_total += batch_fns

    mean_loss = loss_sum/len(loader.dataset)
    fpr = fp_total/(np.array(loader.dataset.targets)==0).sum()
    fnr = fn_total/(np.array(loader.dataset.targets)==1).sum()
    print('Training - Epoch: [{}/{}]\tLoss: {}\tFPR: {}\tFNR: {}'.format(epoch+1, args.nepochs, mean_loss, fpr, fnr))
    return mean_loss, fpr, fnr

def test_single(epoch, embedder, rnn, loader, criterion):
    """Evaluate the RNN aggregator on ``loader`` without updating weights.

    Args:
        epoch: zero-based epoch index (progress messages only).
        embedder: frozen tile encoder returning ``(logits, features)``.
        rnn: recurrent aggregator returning ``(output, state)``.
        loader: DataLoader yielding ``(inputs, target)`` with ``inputs`` being
            a sequence of per-step tile batches.
        criterion: loss taking ``(logits, target)``.

    Returns:
        ``(mean loss, false-positive rate, false-negative rate)`` over the
        whole dataset.
    """
    rnn.eval()
    loss_sum = 0.
    fp_total = 0.
    fn_total = 0.

    with torch.no_grad():
        for step, (inputs, target) in enumerate(loader):
            print('Validating - Epoch: [{}/{}]\tBatch: [{}/{}]'.format(epoch+1,args.nepochs,step+1,len(loader)))

            state = rnn.init_hidden(inputs[0].size(0)).cuda()
            for tiles in inputs:
                _, features = embedder(tiles.cuda())
                output, state = rnn(features, state)

            target = target.cuda()
            loss = criterion(output,target)

            loss_sum += loss.item()*target.size(0)
            batch_fps, batch_fns = errors(output.detach(), target.cpu())
            fp_total += batch_fps
            fn_total += batch_fns

    mean_loss = loss_sum/len(loader.dataset)
    fpr = fp_total/(np.array(loader.dataset.targets)==0).sum()
    fnr = fn_total/(np.array(loader.dataset.targets)==1).sum()
    print('Validating - Epoch: [{}/{}]\tLoss: {}\tFPR: {}\tFNR: {}'.format(epoch+1, args.nepochs, mean_loss, fpr, fnr))
    return mean_loss, fpr, fnr

def errors(output, target):
    """Count false positives and false negatives in one batch.

    Args:
        output: 2-D tensor of class scores, one row per sample.
        target: CPU tensor of ground-truth labels (0 or 1).

    Returns:
        ``(false positives, false negatives)`` as floats.
    """
    # Index of the highest-scoring class of every row.
    top_idx = output.topk(k=1, dim=1, largest=True, sorted=True)[1]
    predicted = top_idx.squeeze().cpu().numpy()
    actual = target.numpy()
    wrong = predicted != actual
    false_pos = np.logical_and(predicted == 1, wrong).sum()
    false_neg = np.logical_and(predicted == 0, wrong).sum()
    return float(false_pos), float(false_neg)

class ResNetEncoder(nn.Module):
    """ResNet-34 tile classifier exposed as a feature extractor.

    Restores the two-class ResNet-34 stored under ``'state_dict'`` in the
    checkpoint at ``path`` and splits it into the convolutional trunk
    (``features``) and the final linear layer (``fc``).
    """

    def __init__(self, path):
        super(ResNetEncoder, self).__init__()

        temp = models.resnet34()
        temp.fc = nn.Linear(temp.fc.in_features, 2)
        # The weights are copied into a CPU model and the caller moves the
        # module to its device afterwards, so load onto the CPU: this avoids
        # requiring the GPU the checkpoint was saved on.
        ch = torch.load(path, map_location='cpu')
        temp.load_state_dict(ch['state_dict'])
        # Everything except the last child (the fc layer).
        self.features = nn.Sequential(*list(temp.children())[:-1])
        self.fc = temp.fc

    def forward(self,x):
        """Return ``(class logits, flattened features)`` for a batch ``x``."""
        x = self.features(x)
        x = x.view(x.size(0),-1)
        return self.fc(x), x

class rnn_single(nn.Module):
    """Single-layer recurrent aggregator over 512-dimensional tile features.

    One step computes ``state = ReLU(fc2(state) + fc1(input))`` and
    ``output = fc3(state)`` (two class scores).
    """

    def __init__(self, ndims):
        super().__init__()
        self.ndims = ndims

        # input -> hidden, hidden -> hidden, hidden -> class scores
        self.fc1 = nn.Linear(512, ndims)
        self.fc2 = nn.Linear(ndims, ndims)
        self.fc3 = nn.Linear(ndims, 2)

        self.activation = nn.ReLU()

    def forward(self, input, state):
        """Advance one step; return ``(output, new state)``."""
        projected = self.fc1(input)
        carried = self.fc2(state)
        new_state = self.activation(carried + projected)
        return self.fc3(new_state), new_state

    def init_hidden(self, batch_size):
        """Return an all-zero initial state of shape (batch_size, ndims)."""
        return torch.zeros((batch_size, self.ndims))

class rnndata(data.Dataset):
    """Slide-level dataset yielding a short sequence of tiles per slide.

    The library file is a ``torch.save``'d dict with the keys ``slides``
    (paths), ``grid`` (per-slide lists of tile coordinates), ``targets``
    (per-slide labels), ``mult`` (scale factor) and ``level`` (pyramid level).
    Each item is ``(list of up to s tile images, slide target)``.
    """

    def __init__(self, path, s, shuffle=False, transform=None):
        """
        Args:
            path: path to the library file.
            s: maximum number of tiles returned per slide.
            shuffle: if True, randomise the tile order on every access.
            transform: optional callable applied to every tile image.
        """
        lib = torch.load(path)
        self.s = s
        self.transform = transform
        self.slidenames = lib['slides']
        self.targets = lib['targets']
        self.grid = lib['grid']
        self.level = lib['level']
        self.mult = lib['mult']
        # Round (not truncate) so the tile size matches the one MILdataset
        # uses for the same library, e.g. mult=0.57 -> 128 rather than 127.
        self.size = int(np.round(224*lib['mult']))
        self.shuffle = shuffle

        slides = []
        for i, name in enumerate(lib['slides']):
            sys.stdout.write('Opening SVS headers: [{}/{}]\r'.format(i+1, len(lib['slides'])))
            sys.stdout.flush()
            slides.append(openslide.OpenSlide(name))
        print('')
        self.slides = slides

    def __getitem__(self,index):
        """Return ``(tiles, target)`` for slide ``index``."""
        slide = self.slides[index]
        grid = self.grid[index]
        if self.shuffle:
            grid = random.sample(grid,len(grid))

        out = []
        # NOTE(review): slides with fewer than s tiles yield shorter lists;
        # presumably every slide is expected to provide at least s tiles so
        # batches collate evenly - confirm against the library builder.
        s = min(self.s, len(grid))
        for i in range(s):
            img = slide.read_region(grid[i], self.level, (self.size, self.size)).convert('RGB')
            if self.mult != 1:
                img = img.resize((224,224), Image.BILINEAR)
            if self.transform is not None:
                img = self.transform(img)
            out.append(img)

        return out, self.targets[index]

    def __len__(self):
        """Number of slides."""
        return len(self.targets)

# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

 

你可能感兴趣的:(论文笔记,Pytorch,Python)