打卡资料《深入浅出pytorch第六章》:https://datawhalechina.github.io/thorough-pytorch/%E7%AC%AC%E5%85%AD%E7%AB%A0
1.自定义损失函数
-
以函数方式
def my_loss(output, target): loss = torch.mean((output - target)**2) return loss
-
以类方式
如Dice loss:
class DiceLoss(nn.Module): def __init__(self,weight=None,size_average=True): super(DiceLoss,self).__init__() def forward(self,inputs,targets,smooth=1): inputs = F.sigmoid(inputs) inputs = inputs.view(-1) targets = targets.view(-1) intersection = (inputs * targets).sum() dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) return 1 - dice # 使用方法 criterion = DiceLoss() loss = criterion(input,targets)
2.动态调整学习率
使用官方scheduler api: torch.optim.lr_scheduler
-
自定义scheduler
例如 学习率每30轮下降为原来的1/10
def adjust_learning_rate(optimizer, epoch): lr = args.lr * (0.1 ** (epoch // 30)) for param_group in optimizer.param_groups: param_group['lr'] = lr
3.模型微调-torchvision
3.1 微调的流程
- 源数据集上预训练一个神经网络模型,即源模型
- 创建一个新的神经网络模型(可在源模型前n层或整个网络后添加新的层),即目标模型
- 为目标模型添加输出大小为目标数据集类别个数的输出层
- 目标数据集上训练模型,其中输出层从头训练,其余层基于源模型的参数微调
3.2 load 已有模型结构
-
实例化网络并传递pretrained参数
import torchvision.models as models resnet18 = models.resnet18(pretrained=True) alexnet = models.alexnet(pretrained=True) # more见pytorch文档
注意事项:
- 通常PyTorch模型的扩展为
.pt
或.pth
,程序运行时会首先检查默认路径中是否有已经下载的模型权重,一旦权重被下载,下次加载就不需要下载了。 - 一般情况下预训练模型的下载会比较慢,可以直接去 这里 查看模型的
model_urls
,然后手动下载,预训练模型的权重在Linux
和Mac
的默认下载路径是用户根目录下的.cache
文件夹。在Windows
下就是C:\Users\\.cache\torch\hub\checkpoint
。可以通过使用torch.utils.model_zoo.load_url()
设置权重的下载地址。 - 中途强行停止下载的话,一定要去对应路径下将权重文件删除干净,要不然可能会报错。
3.3 训练特定层
- 冻结固定层的参数
def set_parameter_requires_grad(model, feature_extracting):
if feature_extracting:
for param in model.parameters():
param.requires_grad = False
- 训练
import torchvision.models as models
# 冻结参数的梯度
feature_extract = True
model = models.resnet18(pretrained=True)
set_parameter_requires_grad(model, feature_extract)
# 修改模型
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=4, bias=True)
另一个微调模块timm
4.半精度训练
PyTorch默认的浮点数存储方式是torch.float32,小数点后位数更多固然能保证数据的精确性,但绝大多数场景其实并不需要这么精确,只保留一半的信息也不会影响结果,也就是使用torch.float16格式。由于数位减了一半,因此被称为“半精度” 。
-
半精度训练的设置
import autocast
from torch.cuda.amp import autocast
-
模型设置
# 使用装饰器方法 @autocast() def forward(self, x): ... return x
-
训练过程
for x in train_loader: x = x.cuda() with autocast(): output = model(x) ...
5.数据增强-imgaug
暂不做记录,数据增强库可看imgaug、Albumentations
6.使用argparse进行调参
6.1 使用步骤
- 创建ArgumentParser()对象
- 调用add_argument()方法添加参数
- 使用parse_args()解析参数
# demo.py
import argparse
# 创建ArgumentParser()对象
parser = argparse.ArgumentParser()
# 添加参数
parser.add_argument('-o', '--output', action='store_true',
help="shows output")
# action = `store_true` 会将output参数记录为True
# type 规定了参数的格式
# default 规定了默认值
parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=1e-3')
parser.add_argument('--batch_size', type=int, required=True, help='input batch size')
# 使用parse_args()解析函数
args = parser.parse_args()
if args.output:
print("This is some output")
print(f"learning rate:{args.lr} ")
argparse的参数主要可以分为可选参数和必选参数(required)。
6.2 利用config.py高效使用argparse修改超参数
import argparse
def get_options(parser=argparse.ArgumentParser()):
parser.add_argument('--workers', type=int, default=0,
help='number of data loading workers, you had better put it '
'4 times of your gpu')
parser.add_argument('--batch_size', type=int, default=4, help='input batch size, default=64')
parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for, default=10')
parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=1e-3')
parser.add_argument('--seed', type=int, default=118, help="random seed")
parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')
parser.add_argument('--checkpoint_path',type=str,default='',
help='Path to load a previous trained model if not empty (default empty)')
parser.add_argument('--output',action='store_true',default=True,help="shows output")
opt = parser.parse_args()
if opt.output:
print(f'num_workers: {opt.workers}')
print(f'batch_size: {opt.batch_size}')
print(f'epochs (niters) : {opt.niter}')
print(f'learning rate : {opt.lr}')
print(f'manual_seed: {opt.seed}')
print(f'cuda enable: {opt.cuda}')
print(f'checkpoint_path: {opt.checkpoint_path}')
return opt
if __name__ == '__main__':
opt = get_options()
# test.py
# 导入必要库
import config
opt = config.get_options()
manual_seed = opt.seed
num_workers = opt.workers
batch_size = opt.batch_size
...
进一步可参考:
- Python argparse 教程
- argparse 官方教程ps://geek-docs.com/python/python-tutorial/python-argparse.html)