PyTorch already packages several dynamic learning-rate adjustment strategies in torch.optim.lr_scheduler. They are used as follows:

# choose an optimizer
optimizer = torch.optim.Adam(...)
# pick one or more of the schedulers mentioned above
scheduler1 = torch.optim.lr_scheduler....
scheduler2 = torch.optim.lr_scheduler....
...
schedulern = torch.optim.lr_scheduler....
# training loop
for epoch in range(100):
    train(...)
    validate(...)
    optimizer.step()
    # the schedulers must step only after the optimizer has updated the parameters
    scheduler1.step()
    ...
    schedulern.step()
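As a concrete illustration, here is a minimal sketch with a single scheduler (the placeholder model, step size, and decay factor below are illustrative choices, not from the original text):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# halve the learning rate every 30 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

for epoch in range(100):
    # train(...) and validate(...) would go here
    optimizer.step()
    scheduler.step()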
The learning rate schedule can also be defined through a custom function.
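A common pattern is a function that rewrites the learning rate stored in the optimizer's param_groups (a sketch; the initial rate and the rule of decaying by 10x every 30 epochs are illustrative assumptions):

def adjust_learning_rate(optimizer, epoch, initial_lr=1e-3):
    # decay the learning rate by a factor of 10 every 30 epochs (illustrative rule)
    lr = initial_lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

for epoch in range(100):
    train(...)
    validate(...)
    adjust_learning_rate(optimizer, epoch)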
import torch.nn as nn
import torchvision.models as models

# freeze the parameters of the original pretrained model
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

# freeze the gradients of the pretrained parameters
feature_extract = True
# load the pretrained model
model = models.resnet18(pretrained=True)
set_parameter_requires_grad(model, feature_extract)
# replace the fully connected layer at the output
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=4, bias=True)
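To confirm that only the new output layer will be trained, you can list the parameters that still require gradients (a quick sanity check, not part of the original snippet):

# only the newly created fc layer should appear here
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)  # expected: fc.weight, fc.bias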
timm is a library created by Ross Wightman that extends torchvision and provides many SOTA (state-of-the-art) computer vision models. The list of available pretrained models can be retrieved with:
import timm

avail_pretrained_models = timm.list_models(pretrained=True)  # fuzzy (wildcard) queries are also supported
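For example, a wildcard pattern narrows the list to a single model family (the pattern below is an illustrative assumption):

# all pretrained variants whose name matches the pattern
resnet_models = timm.list_models('*resnet*', pretrained=True)
print(len(resnet_models))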
import timm
import torch

# change the 1000-class output head to 10 classes
model = timm.create_model('resnet34', num_classes=10, pretrained=True)
# change the number of input channels (e.g. to single-channel input)
model = timm.create_model('resnet34', num_classes=10, pretrained=True, in_chans=1)
# save and reload the weights (assumes the ./checkpoint directory already exists)
torch.save(model.state_dict(), './checkpoint/timm_model.pth')
model.load_state_dict(torch.load('./checkpoint/timm_model.pth'))
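A quick forward pass verifies the modified input and output shapes (a sketch; the 224x224 input size is an assumed, typical resolution for ResNet-34):

# single-channel input, batch size 1
x = torch.randn(1, 1, 224, 224)
output = model(x)
print(output.shape)  # expected: torch.Size([1, 10])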
from torch.cuda.amp import autocast

# decorate the model's forward function with autocast
@autocast()
def forward(self, x):
    ...
    return x

# training
for x in train_loader:
    x = x.cuda()
    # wrap the forward pass (feeding data to the model and everything after it) in with autocast()
    with autocast():
        output = model(x)
        ...
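In practice, autocast is usually paired with torch.cuda.amp.GradScaler, which scales the loss so that float16 gradients do not underflow (a minimal sketch; model, optimizer, criterion, and train_loader are assumed to exist):

from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
for x, y in train_loader:
    x, y = x.cuda(), y.cuda()
    optimizer.zero_grad()
    with autocast():
        output = model(x)
        loss = criterion(output, y)
    # scale the loss, backpropagate, then unscale and update the weights
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()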
For image data, the imgaug and Albumentations libraries can be used for data augmentation.
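As a brief Albumentations illustration (a sketch; the transforms are arbitrary examples, and image is assumed to be a NumPy array in (H, W, C) format):

import albumentations as A

transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
])
# Albumentations operates on NumPy arrays, not PIL images
augmented = transform(image=image)
augmented_image = augmented['image']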
import argparse  # built into Python, no installation needed

# create an ArgumentParser object
parser = argparse.ArgumentParser()

# add arguments
parser.add_argument('-o', '--output', action='store_true',
                    help='shows output')
# action='store_true' records the output argument as True when the flag is given

parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=3e-5')
# type specifies the argument's type
# default specifies its default value

parser.add_argument('--batch_size', type=int, required=True, help='input batch size')
# required=True marks the argument as mandatory

# parse the arguments with parse_args()
args = parser.parse_args()

if args.output:
    print('This is some output')
    print(f'learning rate: {args.lr}')
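If the snippet above is saved as, say, demo.py (a hypothetical file name), it can be run as python demo.py --batch_size 32 --lr 3e-4 -o, which prints the output line and the chosen learning rate.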
import argparse

def get_options(parser=argparse.ArgumentParser()):
    parser.add_argument('--workers', type=int, default=0,
                        help='number of data loading workers; a common rule of thumb is 4 per GPU')
    parser.add_argument('--batch_size', type=int, default=4, help='input batch size, default=4')
    parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for, default=10')
    parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=3e-5')
    parser.add_argument('--seed', type=int, default=118, help='random seed')
    parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')
    parser.add_argument('--checkpoint_path', type=str, default='',
                        help='path to load a previously trained model if not empty (default empty)')
    parser.add_argument('--output', action='store_true', default=True, help='shows output')

    opt = parser.parse_args()

    if opt.output:
        print(f'num_workers: {opt.workers}')
        print(f'batch_size: {opt.batch_size}')
        print(f'epochs (niter): {opt.niter}')
        print(f'learning rate: {opt.lr}')
        print(f'manual_seed: {opt.seed}')
        print(f'cuda enabled: {opt.cuda}')
        print(f'checkpoint_path: {opt.checkpoint_path}')

    return opt

if __name__ == '__main__':
    opt = get_options()
If the options above are saved in a file named config.py, other scripts (such as the training script) can import and reuse them:

import config

opt = config.get_options()
manual_seed = opt.seed
num_workers = opt.workers
batch_size = opt.batch_size
lr = opt.lr
niters = opt.niter  # the argument is named --niter, so the attribute is opt.niter
checkpoint_path = opt.checkpoint_path
import random
import numpy as np
import torch

# fix the random seeds so that results can be reproduced
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

...

if __name__ == '__main__':
    set_seed(manual_seed)
    for epoch in range(niters):
        train(model, lr, batch_size, num_workers, checkpoint_path)
        val(model, lr, batch_size, num_workers, checkpoint_path)
Source: Datawhale, 深入浅出PyTorch.