This post documents my own learning process. It covers:
Code walkthrough: Pytorch-UNet
Deep-learning programming basics: Pytorch-深度学习(新手友好)
U-Net paper notes: 医学图像分割:U_Net 论文阅读
Data: https://hackernoon.com/hacking-gta-v-for-carvana-kaggle-challenge-6d0b7fb4c781
For the complete code walkthrough, see: U-Net代码复现–更新中
(Still being updated...)
CarvanaDataset: reads the data and builds the input dataset (for the implementation, see: U-Net代码复现–utils data_loading.py).
# 1. Create dataset
try:
    dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
except (AssertionError, RuntimeError, IndexError):
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
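For orientation, the items these dataset classes produce are dicts with an 'image' tensor and a 'mask' tensor (see data_loading.py). The toy dataset below is only a sketch of that interface, not the real implementation:
import torch
from torch.utils.data import Dataset

class ToyMaskDataset(Dataset):
    """Illustrative stand-in for CarvanaDataset/BasicDataset: same output format, random data."""
    def __init__(self, n_items=8, size=64):
        self.n_items, self.size = n_items, size

    def __len__(self):
        return self.n_items

    def __getitem__(self, idx):
        img = torch.rand(3, self.size, self.size)            # fake RGB image
        mask = torch.randint(0, 2, (self.size, self.size))    # fake integer mask
        return {'image': img.float().contiguous(), 'mask': mask.long().contiguous()}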
Notes on the random_split() function:
Parameters: dataset is the dataset to split, lengths gives the sizes of the resulting subsets, and generator is an optional random number generator that makes the split reproducible.
# 2. Split into train / validation partitions
# Split the dataset into a training set and a validation set
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))
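A self-contained example of random_split on a toy dataset (the 10% validation fraction mirrors the script's default val_percent):
import torch
from torch.utils.data import TensorDataset, random_split

toy = TensorDataset(torch.arange(100).float())     # 100 dummy samples
n_val = int(len(toy) * 0.1)                        # e.g. val_percent = 0.1
n_train = len(toy) - n_val
train_set, val_set = random_split(toy, [n_train, n_val],
                                  generator=torch.Generator().manual_seed(0))
print(len(train_set), len(val_set))                # 90 10
Fixing the generator seed keeps the split identical across runs.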
Loading and iterating over the dataset.
For DataLoader, see: Pytorch:torch.utils.data.DataLoader()
# 3. Create data loaders
loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)
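A quick sanity check of what the loaders yield (the shapes shown are examples only; they depend on img_scale and the actual dataset):
batch = next(iter(train_loader))                    # one batch, returned as a dict
print(batch['image'].shape, batch['image'].dtype)   # e.g. torch.Size([B, 3, H, W]) torch.float32
print(batch['mask'].shape, batch['mask'].dtype)     # e.g. torch.Size([B, H, W]) torch.int64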
wandb is short for Weights & Biases, an experiment-tracking and visualization tool similar to TensorBoard or visdom. It is a general Python library, not part of PyTorch (worth exploring on its own; it is not covered in detail here).
# (Initialize logging)
experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
experiment.config.update(
    dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
         val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp)
)
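If you want to run the script without a wandb account, one option (my own suggestion, not part of the original code) is offline mode, which writes the logs locally instead of syncing them:
import wandb

experiment = wandb.init(project='U-Net', mode='offline')   # no account needed; logs stay on disk
experiment.config.update(dict(epochs=5, batch_size=1))
experiment.log({'train loss': 0.5, 'step': 1})
experiment.finish()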
Log the training configuration:
logging.info(f'''Starting training:
    Epochs:          {epochs}
    Batch size:      {batch_size}
    Learning rate:   {learning_rate}
    Training size:   {n_train}
    Validation size: {n_val}
    Checkpoints:     {save_checkpoint}
    Device:          {device.type}
    Images scaling:  {img_scale}
    Mixed Precision: {amp}
''')
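For these logging.info calls to actually print anything, the root logger has to be configured once (the repo's train.py does this in its __main__ block); the snippet below shows the general pattern, with an illustrative format string:
import logging

# The default level is WARNING, so INFO messages would otherwise be silent.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logging.info('Logging is configured')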
optim.RMSprop: see 机器学习:优化器Optimizer的总结与比较 for a comparison of optimizers.
torch.optim.lr_scheduler: this module provides ways to adjust the learning rate based on the number of training epochs. Usually the learning rate is gradually decreased as the epoch count grows, which tends to give better training results.
torch.optim.lr_scheduler.ReduceLROnPlateau: lowers the learning rate dynamically based on some metric measured during training.
torch.cuda.amp.GradScaler: see PyTorch : torch.cuda.amp: 自动混合精度详解 for automatic mixed precision.
nn.CrossEntropyLoss(): the loss function.
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
optimizer = optim.RMSprop(model.parameters(),
                          lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize Dice score
grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
global_step = 0
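A toy run of ReduceLROnPlateau with mode='max', just to show the mechanics: the learning rate is multiplied by the default factor of 0.1 only after the monitored score (a made-up Dice value here) has failed to improve for more than patience scheduler steps.
import torch
from torch import nn, optim

model = nn.Linear(4, 2)                                        # stand-in model
optimizer = optim.RMSprop(model.parameters(), lr=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)

fake_dice = [0.50, 0.60, 0.60, 0.60, 0.60, 0.60, 0.60, 0.60]   # plateaus after the 2nd value
for score in fake_dice:
    scheduler.step(score)                                      # pass the monitored metric
    print(optimizer.param_groups[0]['lr'])                     # drops to 1e-6 at the last step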
The training loop. There is a lot going on here, so it is broken into several smaller parts for analysis:
========== part 1============
Iterating over the dataset:
for batch in train_loader:
This corresponds to the earlier train_loader = DataLoader(train_set, shuffle=True, **loader_args).
images, true_masks = batch['image'], batch['mask']
This corresponds to the following entries in U-Net代码复现–utils data_loading.py:
'image': torch.as_tensor(img.copy()).float().contiguous(),
'mask': torch.as_tensor(mask.copy()).long().contiguous()
images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
true_masks = true_masks.to(device=device, dtype=torch.long)
.to() moves a Tensor (or a model) to the specified device; for details on .to(), see: pytorch:to()、device()、cuda()
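A small illustration of .to() with an explicit dtype and memory_format (channels_last changes only the memory layout, not the tensor shape):
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.rand(2, 3, 64, 64)
x = x.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
print(x.shape)                                               # still torch.Size([2, 3, 64, 64])
print(x.is_contiguous(memory_format=torch.channels_last))    # True: NHWC layout in memory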
========== part 2============
with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
Code inside the with torch.autocast block runs with automatic mixed precision, i.e. a suitable numeric precision is chosen automatically based on the input data types.
masks_pred = model(images)
The prediction for the current batch.
if model.n_classes == 1:
n_classes is the number of output channels, i.e. how many feature maps are ultimately produced.
loss += dice_loss(......)
The Dice loss is added on top of the base criterion (see the full code below).
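For reference, a minimal single-class Dice loss sketch. The repo's real implementation lives in utils/dice_score.py and also handles the multiclass case; this version only shows the idea behind the loss += dice_loss(...) line above:
import torch

def dice_loss_sketch(pred, target, eps=1e-6):
    """Illustrative Dice loss for (N, H, W) probability maps and binary targets in [0, 1]."""
    dims = (-1, -2)                                   # sum over the spatial dimensions
    inter = 2 * (pred * target).sum(dim=dims)
    union = pred.sum(dim=dims) + target.sum(dim=dims)
    dice = (inter + eps) / (union + eps)              # per-image Dice coefficient
    return 1 - dice.mean()                            # loss: 1 - mean Dice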
# 5. Begin training
# ================================ part 1 =======================================
for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0
    with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
        for batch in train_loader:
            images, true_masks = batch['image'], batch['mask']

            assert images.shape[1] == model.n_channels, \
                f'Network has been defined with {model.n_channels} input channels, ' \
                f'but loaded images have {images.shape[1]} channels. Please check that ' \
                'the images are loaded correctly.'

            images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            true_masks = true_masks.to(device=device, dtype=torch.long)

            # ================================ part 2 =======================================
            with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
                masks_pred = model(images)
                if model.n_classes == 1:
                    loss = criterion(masks_pred.squeeze(1), true_masks.float())
                    loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
                else:
                    loss = criterion(masks_pred, true_masks)
                    loss += dice_loss(
                        F.softmax(masks_pred, dim=1).float(),
                        F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                        multiclass=True
                    )

            optimizer.zero_grad(set_to_none=True)
            grad_scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
            grad_scaler.step(optimizer)
            grad_scaler.update()

            pbar.update(images.shape[0])
            global_step += 1
            epoch_loss += loss.item()
            experiment.log({
                'train loss': loss.item(),
                'step': global_step,
                'epoch': epoch
            })
            pbar.set_postfix(**{'loss (batch)': loss.item()})

            # Evaluation round
            division_step = (n_train // (5 * batch_size))
            if division_step > 0:
                if global_step % division_step == 0:
                    histograms = {}
                    for tag, value in model.named_parameters():
                        tag = tag.replace('/', '.')
                        if not (torch.isinf(value) | torch.isnan(value)).any():
                            histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                        if not (torch.isinf(value.grad) | torch.isnan(value.grad)).any():
                            histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                    val_score = evaluate(model, val_loader, device, amp)
                    scheduler.step(val_score)

                    logging.info('Validation Dice score: {}'.format(val_score))
                    try:
                        experiment.log({
                            'learning rate': optimizer.param_groups[0]['lr'],
                            'validation Dice': val_score,
                            'images': wandb.Image(images[0].cpu()),
                            'masks': {
                                'true': wandb.Image(true_masks[0].float().cpu()),
                                'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),
                            },
                            'step': global_step,
                            'epoch': epoch,
                            **histograms
                        })
                    except:
                        pass

    if save_checkpoint:
        Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
        state_dict = model.state_dict()
        state_dict['mask_values'] = dataset.mask_values
        torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
        logging.info(f'Checkpoint {epoch} saved!')
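Because the checkpoint stores 'mask_values' next to the model weights, that entry has to be popped before calling load_state_dict when the checkpoint is loaded later (this mirrors what the repo's predict.py does; the file path below is just an example):
import torch
from unet import UNet                                   # assumes the repo's package layout

net = UNet(n_channels=3, n_classes=2, bilinear=False)
state_dict = torch.load('checkpoints/checkpoint_epoch5.pth', map_location='cpu')  # example path
mask_values = state_dict.pop('mask_values', [0, 1])     # remove the extra entry saved above
net.load_state_dict(state_dict)
net.eval()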
argparse.ArgumentParser: creates an ArgumentParser() object.
parser.add_argument: calls add_argument() to register command-line arguments.
parser.parse_args(): parses the registered arguments.
The parameters of parser.add_argument:
name or flags - a name or a list of option strings, e.g. foo or -f, --foo.
action - the basic type of action to be taken when this argument is encountered on the command line.
nargs - the number of command-line arguments that should be consumed.
const - a constant value required by some action and nargs selections.
default - the value used if the argument is absent from the command line.
choices - a container of the allowable values for the argument.
required - whether or not the command-line option may be omitted (optionals only).
help - a brief description of what the argument does.
metavar - the name used for the argument value in usage messages.
dest - the name of the attribute added to the object returned by parse_args().
def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')

    return parser.parse_args()
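A quick way to see the parser in action without touching the real command line (the argument values are arbitrary examples; note that --learning-rate ends up in args.lr because of dest='lr'):
import sys

# Simulates: python train.py --epochs 5 --batch-size 2 --amp
sys.argv = ['train.py', '--epochs', '5', '--batch-size', '2', '--amp']
args = get_args()
print(args.epochs, args.batch_size, args.lr, args.amp)   # 5 2 1e-05 True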