U-Net代码复现--train.py

本文记录自己的学习过程,内容包括:
代码解读:Pytorch-UNet
深度学习编程基础:Pytorch-深度学习(新手友好)
UNet论文解读:医学图像分割:U_Net 论文阅读
数据:https://hackernoon.com/hacking-gta-v-for-carvana-kaggle-challenge-6d0b7fb4c781

完整代码解读详见:U-Net代码复现–更新中

(还在更新中。。。。。。。。。。。。)
(还在更新中。。。。。。。。。。。。)
(还在更新中。。。。。。。。。。。。)
(还在更新中。。。。。。。。。。。。)

train.py

CarvanaDataset: 读取并创建输入数据(具体实现详见:U-Net代码复现–utils data_loading.py)

	# 1. Create dataset
    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError, IndexError):
        dataset = BasicDataset(dir_img, dir_mask, img_scale)

random_split()函数说明:

  • 这个函数的作用是划分数据集

参数说明:

  • dataset (Dataset): 划分的数据集
  • lengths (sequence): 被划分数据集的长度
	# 2. Split into train / validation partitions
    # 将数据集分为训练集和验证集
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

加载和迭代数据集
关于DataLoader参考:Pytorch:torch.utils.data.DataLoader()

	# 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

Wandb是Weights & Biases的缩写,是类似TensorBoard, visdom的一款可视化工具;是属于Python的,不是Pytorch的(大家感兴趣可以自己看看,这里就不多解释了)

	# (Initialize logging)
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    experiment.config.update(
        dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
             val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp)
    )

打印日志

	logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Mixed Precision: {amp}
    ''')
  • 关于 optim.RMSprop ,参考机器学习:优化器Optimizer的总结与比较
  • torch.optim.lr_scheduler 模块提供了一些根据epoch训练次数来调整学习率(learning rate)的方法。一般情况下我们会设置随着epoch的增大而逐渐减小学习率从而达到更好的训练效果。
  • torch.optim.lr_scheduler.ReduceLROnPlateau 则提供了基于训练中某些测量值使学习率动态下降的方法。
  • torch.cuda.amp.GradScaler 参考:PyTorch : torch.cuda.amp: 自动混合精度详解
  • nn.CrossEntropyLoss() 损失函数
	# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(model.parameters(),
                              lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
    global_step = 0

训练函数,这部分内容比较多,将拆分为多个小部分分析:
========== part 1============
迭代数据集:

  • for batch in train_loader: 对应前文中:train_loader = DataLoader (train_set, shuffle=True, **loader_args)

  • images, true_masks = batch['image'], batch['mask'] 对应U-Net代码复现–utils data_loading.py中的

    'image': torch.as_tensor(img.copy()).float().contiguous(),
    'mask': torch.as_tensor(mask.copy()).long().contiguous().

  • images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last) .
    true_masks = true_masks.to(device=device, dtype=torch.long) .
    将Tensor或模型移动到指定的设备上;关于.to()的用法详见:pytorch:to()、device()、cuda()

========== part 2============

  • with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):

    with torch.autocast: 语句块内的代码会自动进行混合精度计算,也就是根据输入数据的类型自动选择合适的精度进行计算

  • masks_pred = model(images) .
    单次预测结果

  • if model.n_classes == 1: .
    n_classes:输出图的通道数,也就是最终得到几张特征图

  • loss += dice_loss(......) .

	# 5. Begin training
	# ================================ part 1 =======================================
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images, true_masks = batch['image'], batch['mask']

                assert images.shape[1] == model.n_channels, \
                    f'Network has been defined with {model.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                true_masks = true_masks.to(device=device, dtype=torch.long)
				# ================================ part 2 =======================================
                with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
                    masks_pred = model(images)
                    if model.n_classes == 1:
                        loss = criterion(masks_pred.squeeze(1), true_masks.float())
                        loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
                    else:
                        loss = criterion(masks_pred, true_masks)
                        loss += dice_loss(
                            F.softmax(masks_pred, dim=1).float(),
                            F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                            multiclass=True
                        )

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                experiment.log({
                    'train loss': loss.item(),
                    'step': global_step,
                    'epoch': epoch
                })
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round
                division_step = (n_train // (5 * batch_size))
                if division_step > 0:
                    if global_step % division_step == 0:
                        histograms = {}
                        for tag, value in model.named_parameters():
                            tag = tag.replace('/', '.')
                            if not (torch.isinf(value) | torch.isnan(value)).any():
                                histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                            if not (torch.isinf(value.grad) | torch.isnan(value.grad)).any():
                                histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                        val_score = evaluate(model, val_loader, device, amp)
                        scheduler.step(val_score)

                        logging.info('Validation Dice score: {}'.format(val_score))
                        try:
                            experiment.log({
                                'learning rate': optimizer.param_groups[0]['lr'],
                                'validation Dice': val_score,
                                'images': wandb.Image(images[0].cpu()),
                                'masks': {
                                    'true': wandb.Image(true_masks[0].float().cpu()),
                                    'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),
                                },
                                'step': global_step,
                                'epoch': epoch,
                                **histograms
                            })
                        except:
                            pass

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            state_dict = model.state_dict()
            state_dict['mask_values'] = dataset.mask_values
            torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
            logging.info(f'Checkpoint {epoch} saved!')
  • argparse.ArgumentParser :创建 ArgumentParser() 对象
  • parser.add_argument :调用 add_argument() 方法添加参数
  • parser.parse_args() : 使用 parse_args() 解析添加的参数

其中 parser.add_argument

name or flags - 一个命名或者一个选项字符串的列表,例如 foo 或 -f, --foo。
action -当参数在命令行中出现时使用的动作基本类型。
nargs - 命令行参数应当消耗的数目。
const - 被一些 action 和 nargs选择所需求的常数。
default - 当参数未在命令行中出现时使用的值。
choices - 可用的参数的容器。
required -此命令行选项是否可省略 (仅选项可用)。
help - 一个此选项作用的简单描述。
metavar - 在使用方法消息中使用的参数值示例。
dest - 被添加到 parse_args() 所返回对象上的属性名。

def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')

    return parser.parse_args()

你可能感兴趣的:(深度学习,医学图像分割,深度学习,人工智能,python,pytorch)