pytorch-unet
来源:https://github.com/milesial/Pytorch-UNet
前两天搞了一下图像分割,用了下unet。之前没怎么用过。复现了一下18年的une pytorch 版本,记录学习一下 (//过了一年了来补充完善一下。。)
if __name__ == '__main__':
args = get_args()
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
# Change here to adapt to your data
# n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel
# 如果是RGB图像 n_channels=3,如果是医学图像(大部分是灰度图)n_channels=1。n_classes是
# 你要分割的类别数加1,比如你的前景有两类,n_classes = 3哦
net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
logging.info(f'Network:\n'
f'\t{net.n_channels} input channels\n'
f'\t{net.n_classes} output channels (classes)\n'
f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')
if args.load:
net.load_state_dict(torch.load(args.load, map_location=device))
logging.info(f'Model loaded from {args.load}')
net.to(device=device)
try:
train_net(net=net,
epochs=args.epochs,
batch_size=args.batch_size,
learning_rate=args.lr,
device=device,
img_scale=args.scale,
val_percent=args.val / 100,
amp=args.amp)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
raise
有一个定义的函数get_args():
def get_args():
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
help='Learning rate', dest='lr')
parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
help='Percent of the data that is used as validation (0-100)')
parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
return parser.parse_args()
argparser主要有三个步骤:
1. Argumentparser()对象,将命令行解析成python数据类型所需要的全部信息。
2.add_argument()方法添加函数,主要定batchsize,lr,epochs,这些乱七八糟的东西,这样方便在命令行直接修改。
3. 这个封装的函数相当于最终解析出来(parparse_args())。
创建解析器 - 添加命令行参数-解析参数
主要这个函数其实就是干了三个事:1.通过arg parser设定需要的轮次,bs,学习率等 。2. 设定输入图像,输出图像尺寸。 3. 是用cpu,一块gpu还是用多块gpu
1. 创建数据集 (这一步的作用主要是实现loading data 和 augmentation)
# 1. Create dataset
try:
dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
except (AssertionError, RuntimeError):
dataset = BasicDataset(dir_img, dir_mask, img_scale)
定义数据集的路径,mask的路径。可以在这做数据增强,下面举个例子是训练集的数据增强,一般都是用 transforms.Compose的方法,比如下图就是用了随机旋转,随即翻转,转成tensor,做标准化(如果想添加自己的数据增强方法就在transform里自己定义一个类,然后compose进来实现自定义数据增强)。
transform_train = transforms.Compose([
transforms.RandomRotation(degrees=8),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean, std), ])
2. 划分数据集(训练集,验证集,测试集,通过random_split对数据集进行一定比例的随即划分。
# 2. Split into train / validation partitions
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))
3. 创建dataloder
# 3. Create data loaders
loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)
首先简单介绍一下啥是dataLoader,它是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个一个Batch Size大小的Tensor,用于后面的训练。
例如:定义的train_loder继承了dataloder,用自己的train_set数据集,按batchsize分成一批一批的tensor去训练;shuffle是每一个epoch结束之后,是否要重新排序;num_worker这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。根据我的经验哈,一般如果出席那这些osa,显存太小这种报错,就把num_workers改成0,或者4就行了。一般gpu上跑还是8或者16差不多,这个是影响训练速度的,num_workers太小的话,gpu利用率会非常低,训练不好。(很多时候我在本地都是num = 0,然后放到服务器上忘了改num,直接显存就爆炸了。。)
4. 创建优化器,定义学习率策略,定义损失函数
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2) # goal: maximize Dice score
grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
criterion = nn.CrossEntropyLoss()
global_step = 0
optimizer能够保持当前参数状态并基于计算得到的梯度进行参数更新,可以继承网络初始参数,权重衰减,学习率策略啊一些东西。方法分为2大类:一大类方法是SGD及其改进(加Momentum);另外一大类是Per-parameter adaptive learning rate methods(逐参数适应学习率方法),包括AdaGrad、RMSProp、Adam等。这东西就跟机器学习当中选择什么算法来进行梯度更新一样。
我的经验:一般优化器就是SGD或者Adam; 学习率策略一般是 lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9 或者过多少轮减半什么的,初始学习率最多0.1,一般0.01或者0.001;bs是2,4,8,16。
5. 开始训练 begin
# 5. Begin training
for epoch in range(1, epochs+1):
net.train()
epoch_loss = 0
with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
for batch in train_loader:
images = batch['image']
true_masks = batch['mask']
assert images.shape[1] == net.n_channels, \
f'Network has been defined with {net.n_channels} input channels, ' \
f'but loaded images have {images.shape[1]} channels. Please check that ' \
'the images are loaded correctly.'
images = images.to(device=device, dtype=torch.float32)
true_masks = true_masks.to(device=device, dtype=torch.long)
with torch.cuda.amp.autocast(enabled=amp):
masks_pred = net(images)
loss = criterion(masks_pred, true_masks) \
+ dice_loss(F.softmax(masks_pred, dim=1).float(),
F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
multiclass=True)
optimizer.zero_grad(set_to_none=True)
grad_scaler.scale(loss).backward()
grad_scaler.step(optimizer)
grad_scaler.update()
pbar.update(images.shape[0])
global_step += 1
epoch_loss += loss.item()
experiment.log({
'train loss': loss.item(),
'step': global_step,
'epoch': epoch
})
pbar.set_postfix(**{'loss (batch)': loss.item()})
在每一个epoch里,通过train_loder得到多少个batch,每个batch,每个batch训练。通过网络分割得到的masks_prede和传入的true_masks进行loss计算,在优化器内不断反向传播,更新梯度,更新优化器,使得loss越来越小,并且趋于稳定。
这里主要就是loss的选择了,一般就是交叉熵和dice_loss
6. 计算
val_score = evaluate(net, val_loader, device)
scheduler.step(val_score)
logging.info('Validation Dice score: {}'.format(val_score))
通过训练的网络对验证集图片进行预测,然后与验证集的true_masks进行比较得到精度。
7. 保存权重
if save_checkpoint:
Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))
logging.info(f'Checkpoint {epoch + 1} saved!')
在这里可以将训练的每一轮参数保存下来,保存成pth文件,到时候在预测的时候直接用就可以了。
也就是训练当中所用到的net的结构是什么样子的
class UNet(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=True):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
factor = 2 if bilinear else 1
self.down4 = Down(512, 1024 // factor)
self.up1 = Up(1024, 512 // factor, bilinear)
self.up2 = Up(512, 256 // factor, bilinear)
self.up3 = Up(256, 128 // factor, bilinear)
self.up4 = Up(128, 64, bilinear)
self.outc = OutConv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
结构比较简单,在encoder阶段,主要有几个模块
第一个模块就是doublecov(两个(conv2d,bn,relu)),主要用在最开始将三通道图片转化为64通道图片。照例子来说,每一次unet conv后,图片尺寸都会下降2,但是在代码中
out_size = (in_size - K + 2P)/ S +1
特意将大小设成3,padding设成1,stride设成1,这样在做conv的时候图片尺寸就不会发生变化了
第二个模块就是down模块(maxpool2d,doubleconv),每次将图片尺寸减半,并在池化后进行conv的操作,增加通道数。
在decoder阶段:
up模块(unsample)+conv 上采样将图片尺寸增加,conv将通道数减少,并且和endocder同层的特征图进行连接
最后outc模块,看你是想输出几通道的图片,就有几个卷积核就好了。
debug:
AssertionError: Either no mask or multiple masks found for the ID 0008052191_9: []
解决方案:找到了img_file路径,mask file路径找不到。在data_loading里将mask_suffix改为空,如果你的img和mask是一摸一样的名字的话。
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`
解决方案:是分类标签越界的问题,最终mask是要分0,1的。项目当中数据集mask是0,1。但是我的image跟mask都是0-255,位深度24,所以原项目是只将img除了255,代码中只要将is_mask改成False就好了。
RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [1, 1, 256, 256]
解决方案:这个报错比较明显,可以打印一下自己在求loss的时候的target跟ground truth,将图像尺寸reshape一下(比如:true_mask.reshape(1,256,256)就好了)。同样在验证集的时候也遇到这个问题,同样的解决方法自然就完成了。
wandb.errors.CommError: check_hostname requires server_hostname
解决方案:这个我也不大懂反正大概意思因为我开了软件,可能wandb那里出现了什么问题,把他关了就好了。
BrokenPipeError: [Errno 32] Broken pipe
解决方案:好像还有是说os:显存太小一类的,将num_workers改成0就好了。