五、pytorch进阶训练技巧——pytorch学习

1. 自定义损失函数

pytorch 在nn.Module中提供了很多的常用的损失函数,但是有时需要提出全新的函数来提升模型的表现,这时需要自己来定义损失函数

1.1 以函数方式定义

就是自己定义一个函数,没啥好说的

1.2 以类的方式定义

以类方式定义更加常用,在以类方式定义损失函数时,我们如果看每一个损失函数的继承关系我们就可以发现Loss函数部分继承自_loss, 部分继承自_WeightedLoss, 而_WeightedLoss继承自_loss, _loss继承自 nn.Module。我们可以将其当作神经网络的一层来对待,同样地,我们的损失函数类就需要继承自nn.Module类。

如下举例IoUloss函数定义:

在自定义损失函数时,涉及到数学运算时,我们最好全程使用PyTorch提供的张量计算接口,这样就不需要我们实现自动求导功能并且我们可以直接调用cuda,使用numpy或者scipy的数学运算时,操作会有些麻烦,

from turtle import forward
import torch.nn as nn
import torch.nn.functional as F

class IoULoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(IoULoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        inputs = F.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        total = (inputs + targets).sum()

        union = total - intersection
        IoU = (intersection + smooth) / (union + smooth)
        return  1 - IoU

2. 动态调整学习率

2.1 使用官方scheduler

PyTorch已经在torch.optim.lr_scheduler为我们封装好了一些动态调整学习率的方法供我们使用,如下面列出的这些scheduler。

from torch.optim import lr_scheduler

lr_scheduler.LambdaLR                   #将每个参数组的学习率设置为初始lr乘以给定函数
lr_scheduler.StepLR                     # 在每个epoch,衰减学习率
lr_scheduler.MultiStepLR                # 一旦epoch达到一定数量,按γ衰减学习率
lr_scheduler.ExponentialLR              # 按指数筛选学习率           
lr_scheduler.CosineAnnealingLR          # 按cosine函数衰减学习率  
lr_scheduler.ReduceLROnPlateau          # 当指标停止改进时衰减学习率
lr_scheduler.CyclicLR                   # 周期性衰减学习率
lr_scheduler.CosineAnnealingWarmRestarts    # 
# 使用官方的Scheduler 
# 选择一种优化器
optimizer = torch.optim.Adam(...) 
# 选择上面提到的一种或多种动态调整学习率的方法
scheduler1 = torch.optim.lr_scheduler.... 
scheduler2 = torch.optim.lr_scheduler....
...
schedulern = torch.optim.lr_scheduler....
# 进行训练
for epoch in range(100):
    train(...)
    validate(...)
    optimizer.step()
    # 需要在优化器参数更新之后再动态调整学习率
    scheduler1.step() 
    ...
    schedulern.step()

我们在使用官方给出的torch.optim.lr_scheduler时,需要将scheduler.step()放在optimizer.step()后面进行使用。

2.2 自定义scheduler

自定义函数adjust_learning_rate来改变param_group中lr的值,在下面的叙述中会给出一个简单的实现。

需要学习率每30轮下降为原来的1/10,假设已有的官方API中没有符合我们需求的,那就需要自定义函数来实现学习率的改变。

def adjust_learning_rate(optimizer, epoch):
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

optimizer = torch.optim.SGD(model.parameters(),lr = args.lr,momentum = 0.9)
for epoch in range(10):
    train(...)
    validate(...)
    adjust_learning_rate(optimizer,epoch)

3. 模型微调-torchvision

迁移学习的一大应用场景是模型微调(finetune)。简单来说,就是我们先找到一个同类的别人训练好的模型,把别人现成的训练好了的模型拿过来,换成自己的数据,通过训练调整一下参数。 在PyTorch中提供了许多预训练好的网络模型(VGG,ResNet系列,mobilenet系列......),这些模型都是PyTorch官方在相应的大型数据集训练好的。学习如何进行模型微调,可以方便我们快速使用预训练模型完成自己的任务。

3.1 模型微调流程

  1. 在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型。
  2. 创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。
  3. 为目标模型添加一个输出⼤小为⽬标数据集类别个数的输出层,并随机初始化该层的模型参数。
  4. 在目标数据集上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的。


    模型微调.png

3.2 使用已有模型结构

这里我们以torchvision中的常见模型为例,列出了如何在图像分类任务中使用PyTorch提供的常见模型结构和参数。对于其他任务和网络结构,使用方式是类似的:实例化网络,再传递pretrained参数

  • 实例化网络
from cgitb import reset
import torchvision.models as models 

reset18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0()
mobilenet_v2 = models.mobilenet_v2()
mobilenet_v3_large = models.mobilenet_v3_large()
mobilenet_v3_small = models.mobilenet_v3_small()
resnext50_32x4d = models.resnext50_32x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()
  • 传递pretrained参数

通过True或者False来决定是否使用预训练好的权重,在默认状态下pretrained = False,意味着我们不使用预训练得到的权重,当pretrained = True,意味着我们将使用在一些数据集上预训练得到的权重。

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)

注意事项:

  1. 通常PyTorch模型的扩展为.pt或.pth,程序运行时会首先检查默认路径中是否有已经下载的模型权重,一旦权重被下载,下次加载就不需要下载了。

  2. 一般情况下预训练模型的下载会比较慢,我们可以直接通过迅雷或者其他方式去 这里 查看自己的模型里面model_urls,然后手动下载,预训练模型的权重在Linux和Mac的默认下载路径是用户根目录下的.cache文件夹。在Windows下就是C:\Users.cache\torch\hub\checkpoint。我们可以通过使用 torch.utils.model_zoo.load_url()设置权重的下载地址。

  3. 如果觉得麻烦,还可以将自己的权重下载下来放到同文件夹下,然后再将参数加载网络。

self.model = models.resnet50(pretrained=False)

self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))

  1. 如果中途强行停止下载的话,一定要去对应路径下将权重文件删除干净,要不然可能会报错。
    mnasnet = models.mnasnet1_0(pretrained=True)

3.3 训练特定层

  • 在默认情况下,参数的属性.requires_grad = True,如果我们从头开始训练或微调不需要注意这里。但如果我们正在提取特征并且只想为新初始化的层计算梯度,其他参数不进行改变。那我们就需要通过设置requires_grad = False来冻结部分层。在PyTorch官方中提供了如下set_parameter_requires_grad的样例。

  • 通过该样例,我们使用resnet18为例的将1000类改为4类,但是仅改变最后一层的模型参数,不改变特征提取的模型参数;注意我们先冻结模型参数的梯度,再对模型输出部分的全连接层进行修改,这样修改后的全连接层的参数就是可计算梯度的。

  • 之后在训练过程中,model仍会进行梯度回传,但是参数更新则只会发生在fc层。通过设定参数的requires_grad属性,我们完成了指定训练模型的特定层的目标,这对实现模型微调非常重要。

import torchvision.models as models

# 冻结参数的梯度
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


# 冻结参数的梯度
feature_extract = True
model = models.resnet18(pretrained=True)
set_parameter_requires_grad(model, feature_extract)
# 修改模型
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=4, bias=True)

4. 模型微调之timm

除了使用torchvision.models进行预训练以外,还有一个常见的预训练模型库,叫做timm,这个库是由来自加拿大温哥华Ross Wightman创建的。里面提供了许多计算机视觉的SOTA模型,可以当作是torchvision的扩充版本,并且里面的模型在准确度上也较高。

  • timm安装 pip install timm

4.1 查看/修改预训练模型种类

import timm
avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models)
# 查询指定系列的模型时,可以在list_models()输入模型名称 
timm.list_models('resnet*', pretrained=True)
# 查看模型的具体参数,可以通过default_cfg实现,
# 创建模型时,使用num_classes可以将模型的输出进行修改
model = timm.create_model('resnet18', num_classes=10, pretrained=True)
model.default_cfg
# 输出
{'url': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'first_conv': 'conv1', 'classifier': 'fc', 'architecture': 'resnet18'}
# 查看模型的输出
import torch
x = torch.randn(1, 3, 244, 244)
output = model(x)
output.shape
# 输出
torch.Size([1, 10])

# 模型第一层
print(dict(model.named_children())['conv1'])
# 查看模型第一层的参数,以第一层卷积为例)
print(list(dict(model.named_children())['conv1'].parameters()))
# print(model)          # 查看模型的网络结构
# 改变输入通道数(比如我们传入的图片是单通道的,但是模型需要的是三通道图片) 我们可以通过添加in_chans=1来改变
model = timm.create_model('resnet34',num_classes=10,pretrained=True,in_chans=1)
x = torch.randn(1,1,224,224)
output = model(x)

4.2 模型参数的保存

timm库所创建的模型是torch.model的子类,我们可以直接使用torch库中内置的模型参数保存和加载的方法,具体操作如下方代码所示

torch.save(model.state_dict(),'./checkpoint/timm_model.pth')

model.load_state_dict(torch.load('./checkpoint/timm_model.pth'))

# 保存模型权重
torch.save(model.state_dict(), './timm_model.pth')
# 加载模型权重
model.load_state_dict(torch.load('./timm_model.pth'))

5. 半精度训练

PyTorch默认的浮点数存储方式用的是torch.float32,小数点后位数更多固然能保证数据的精确性,但绝大多数场景其实并不需要这么精确,只保留一半的信息也不会影响结果,也就是使用torch.float16格式。由于数位减了一半,因此被称为“半精度”。半精度能够减少显存占用,使得显卡可以同时加载更多数据进行计算

半精度.png

半精度训练的设置
使用autocast配置半精度训练

半精度训练主要适用于数据本身的size比较大(比如说3D图像、视频等)。当数据本身的size并不大时(比如手写数字MNIST数据集的图片尺寸只有28*28),使用半精度训练则可能不会带来显著的提升。

# autocast的导入
from torch.cuda.amp import autocast

# 模型设置
# 使用python的装饰器方法,用autocast装饰模型中的forward函数。关于装饰器的使用。
@autocast()
def forward(self, x):
    ... 
    return x

# 训练过程
# 在训练过程中,只需在将数据输入模型及其之后的部分放入“with autocast():
for  x in train_loader:
    x = x.cuda()
    with autocast():
        ouput = model(x)
        ...
        

6. 数据增强

①数据增强有什么用?

深度学习最重要的是数据。我们需要大量数据才能避免模型的过度拟合。但是我们在许多场景无法获得大量数据,例如医学图像分析。数据增强技术的存在是为了解决这个问题,这是针对有限数据问题的解决方案。数据增强一套技术,可提高训练数据集的大小和质量,以便我们可以使用它们来构建更好的深度学习模型。 在计算视觉领域,生成增强图像相对容易。即使引入噪声或裁剪图像的一部分,模型仍可以对图像进行分类,数据增强有一系列简单有效的方法可供选择,有一些机器学习库来进行计算视觉领域的数据增强,比如:imgaug 官网它封装了很多数据增强算法,给开发者提供了方便

②数据增强的怎么做?

数据扩增是对读取进行数据增强的操作,所以需要在数据读取的时候完成。

③数据增强的方法有哪些?

数据扩增方法有很多:从颜色空间、尺度空间到样本空间,同时根据不同任务数据扩增都有相应的区别。
对于图像分类,数据扩增一般不会改变标签;对于物体检测,数据扩增会改变物体坐标位置;对于图像分割,数据扩增会像素标签;

④数据增强库

  • torchvision https://github.com/pytorch/vision

pytorch官方提供的数据扩增库,提供了基本的数据数据扩增方法,可以无缝与torch进行集成;但数据扩增方法种类较少,且速度中等;

  • imgaug

https://github.com/aleju/imgaug

imgaug是常用的第三方数据扩增库,提供了多样的数据扩增方法,且组合起来非常方便,速度较快;

  • albumentations

https://albumentations.readthedocs.io

是常用的第三方数据扩增库,提供了多样的数据扩增方法,对图像分类、语义分割、物体检测和关键点检测都支持,速度较快。

6.1 torchvision中的常见数据增强方法

基础数据扩增方法指常见的数据扩增方法,且都是标签一致的数据扩增方法,大都出现在torchvision中:

  • transforms.CenterCrop
    对图片中心进行裁剪;
  • transforms.ColorJitter
    对图像颜色的对比度、饱和度和零度进行变换;
  • transforms.FiveCrop
    对图像四个角和中心进行裁剪得到五分图像;
  • transforms.Grayscale
    对图像进行灰度变换;
  • transforms.Pad
    使用固定值进行像素填充;
  • transforms.RandomAffine
    随机仿射变换;
  • transforms.RandomCrop
    随机区域裁剪;
  • transforms.RandomHorizontalFlip
    随机水平翻转;
  • transforms.RandomRotation
    随机旋转;
  • transforms.RandomVerticalFlip
    随机垂直翻转;
import torchvision.transforms 
torchvision.transforms.CenterCrop()

6.2 imgaug的安装和使用

imgaug的安装方法和其他的Python包类似,我们可以通过以下两种方式进行安装

  • conda
    (我用这个安装失败了~~~)
conda config --add channels conda-forge
conda install imgaug
  • pip
    用下面第二行安装成功了
#  install imgaug either via pypi

pip install imgaug

#  install the latest version directly from github

pip install git+https://github.com/aleju/imgaug.git
imgaug的使用

imgaug仅仅提供了图像增强的一些方法,但是并未提供图像的IO操作,因此我们需要使用一些库来对图像进行导入,建议使用imageio进行读入,如果使用的是opencv进行文件读取的时候,需要进行手动改变通道,将读取的BGR图像转换为RGB图像。除此以外,当我们用PIL.Image进行读取时,因为读取的图片没有shape的属性,所以我们需要将读取到的img转换为np.array()的形式再进行处理。因此官方的例程中也是使用imageio进行图片读取。

单张图片的处理

import imageio
import imgaug as ia 
%matplotlib inline 

# image读取
import PIL
from PIL import Image
import numpy as np

# # Image读取照片
# img2 = Image.open('car.jpg')
# image2 = np.array(img2)
# ia.imshow(image2)

#图片的读取,使用imageiod读取,imgaug展示
img = imageio.imread('car.jpg')
print(img.shape)
# 可视化
ia.imshow(img)
  • 对单张图片进行增强处理
    以下为旋转,和多种方式的组合介绍
from imgaug import augmenters as iaa 

# 设置随机数种子
ia.seed(4)
# 实例化方法
rotate= iaa.Affine(rotate=(45))
img_aug = rotate(image = img)
ia.imshow(img_aug)
rotate.png
  • 对图片进行多种组合的数据增强

使用imgaug.augmenters.Sequential()来构造数据增强的pipline,与torchvison.transforms.Compose()类似

总的来说,对单张图片处理的方式基本相同,我们可以根据实际需求,选择合适的数据增强方法来对数据进行处理。

# iaa.Sequential的参数如下
# iaa.Sequential(children=None,   # Augmenter集合
#                 random_order=False, # 是否对每个batch使用不同顺序的Augmenter list
#                 name = None,
#                 deterministic=False,
#                 random_state=None
# )

# 构建处理序列
from email.mime import image


aug_seq = iaa.Sequential([
    iaa.Affine(rotate=(-25, 25)),           # 让图片在(-25, 25)间随机旋转
    iaa.AdditiveGaussianNoise(scale=(10, 60)),      # 给图片添加高斯噪声,高斯噪声的标准差为(10, 60)
    iaa.Crop(percent=(0, 0.2))          # 对图像进行随机裁剪,裁剪范围(0, 0.2)
])

# 对图片进行处理
image_aug = aug_seq(image=img)
ia.imshow(image_aug)

对批次图片进行处理

可以将图形数据按照NHWC(N:batch H:height W:Width C:channel)的形式或者由列表组成的HWC的形式对批量的图像进行处理。主要分为以下两部分:

  • 对批次的图片以同一种方式处理
  • 对批次的图片进行分部分处理。

import os
import imageio.v2 as imageio
import imgaug as ia 
%matplotlib inline 

root_path = '/Users/anker/Desktop/python_code/datasets/test'
image_list = []
# os.listdir(root_path)
for i in os.listdir(root_path):
    image_path = os.path.join(root_path, i)
    img = imageio.imread(image_path)
    image_list.append(img)
    print(img.shape)

# 对一批次的图片进行处理时,只需要将待处理的图片放在一个list中,并将image改为image即可进行数据增强操作,具体实际操作如下:
images = [image_list[0], image_list[0], image_list[0]]
# 传参时需要指明是images参数
images_aug = rotate(images = images)
# ia.imshow图片时,输入的图片必须是相同的大小
ia.imshow(np.hstack(images_aug))

输出


批次1.png
# 对批次图片使用多种增强方法,传参时注意传的是images参数
images_aug_seq = aug_seq.augment_images(images = images)
# images_aug_seq = aug_seq(images = images)         # 可以用上面的写法,也可以用本行的方法传参
ia.imshow(np.hstack(images_aug_seq))

输出

批次输出2.png

对批次的图片分部分处理

imgaug.augmenters.Sometimes()对batch中的一部分图片应用一部分Augmenters,剩下的图片应用另外的Augmenters。

aug_sometimes = iaa.Sometimes(0.5, iaa.GaussianBlur(0.7), iaa.Fliplr(1.0))
images_aug_sometimes = aug_sometimes(images = images)
ia.imshow(np.hstack(images_aug_sometimes))

输出

批次输出3.png

对不同大小的图片进行处理

除了可视化与其他不同外,其他都相同

image_list
# 构建pipline
seq = iaa.Sequential([
    iaa.CropAndPad(percent=(-0.2, 0.2), pad_mode='edge'),       # 对图片进行剪切和填充
    iaa.AddToHueAndSaturation((-60, 60)),               # 对图片的饱和度和色调进行调整
    iaa.ElasticTransformation(alpha=0.9, sigma=9),      # 对图片进行像素调整,产生水波纹的效果
    iaa.Cutout()          # 填充图像
])

# 对图像进行增强
images_seq = seq(images= image_list)

for i in range(len(image_list)):
    print("Image %d (input shape: %s, output shape: %s)" % (i, image_list[i].shape, images_seq[i].shape))
    ia.imshow(np.hstack([image_list[i], images_seq[i]]))

输出


批次输出4.png

6.3 imgaug在PyTorch的应用

import numpy as np
import imgaug
from imgaug import augmenters as iaa
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# 构建pipline
tfs = transforms.Compose([
    iaa.Sequential([
        iaa.flip.Fliplr(p=0.5),
        iaa.flip.Flipud(p=0.5),
        iaa.GaussianBlur(sigma=(0.0, 0.1)),
        iaa.MultiplyBrightness(mul=(0.65, 1.35)),
    ]).augment_image,
    # 不要忘记了使用ToTensor()
    transforms.ToTensor()
])

# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self, n_images, n_classes, transform=None):
        # 图片的读取,建议使用imageio
        self.images = np.random.randint(0, 255,
                                        (n_images, 224, 224, 3),
                                        dtype=np.uint8)
        self.targets = np.random.randn(n_images, n_classes)
        self.transform = transform

    def __getitem__(self, item):
        image = self.images[item]
        target = self.targets[item]

        if self.transform:
            image = self.transform(image)

        return image, target

    def __len__(self):
        return len(self.images)


def worker_init_fn(worker_id):
    imgaug.seed(np.random.get_state()[1][0] + worker_id)


custom_ds = CustomDataset(n_images=50, n_classes=10, transform=tfs)
custom_dl = DataLoader(custom_ds, batch_size=64,
                       num_workers=4, pin_memory=True, 
                       worker_init_fn=worker_init_fn)

关于num_workers在Windows系统上只能设置成0,但是当我们使用Linux远程服务器时,可能使用不同的num_workers的数量,这是我们就需要注意worker_init_fn()函数的作用了。它保证了我们使用的数据增强在num_workers>0时是对数据的增强是随机的。

除去imgaug以外,还可以学习下Albumentations,因为Albumentations跟imgaug都有着丰富的教程资源,这个以后再看,先学完教程再说。

7. 使用argparse进行调参

argparse的作用就是将命令行传入的其他参数进行解析、保存和使用。在使用argparse后,我们在命令行输入的参数就可以以这种形式python file.py --lr 1e-4 --batch_size 32来完成对常见超参数的设置。

argparse的使用

  • 创建ArgumentParser()对象
  • 调用add_argument()方法添加参数
  • 使用parse_args()解析参数
# 简单demo
import argparse

# 创建ArgumentParse()对象
parser = argparse.ArgumentParser()

# 添加参数
parser.add_argument('-o', '--output', action='store_true', help='shows output')
# action = `store_true` 会将output参数记录为True
# type 规定了参数的格式
# default 规定了默认值
parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=1e-3')
parser.add_argument('--batch_size', type=int, required=True, help='input batch size')

# 使用parse_args()解析函数
args = parser.parse_args()

if args.output:  
    print(f"learning rate:{args.lr} ")

更加高效的使用

一种方式是将超参数的设置写在单独的config.py文件中,然后在调用使用

另外一种是封装为函数,调用的时候进行使用

# 将超参数设置写在单独的config.py文件中
import argparse  
  
def get_options(parser=argparse.ArgumentParser()):  
  
    parser.add_argument('--workers', type=int, default=0,  
                        help='number of data loading workers, you had better put it '  
                              '4 times of your gpu')  
  
    parser.add_argument('--batch_size', type=int, default=4, help='input batch size, default=64')  
  
    parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for, default=10')  
  
    parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=1e-3')  
  
    parser.add_argument('--seed', type=int, default=118, help="random seed")  
  
    parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')  
    parser.add_argument('--checkpoint_path',type=str,default='',  
                        help='Path to load a previous trained model if not empty (default empty)')  
    parser.add_argument('--output',action='store_true',default=True,help="shows output")  
  
    opt = parser.parse_args()  
  
    if opt.output:  
        print(f'num_workers: {opt.workers}')  
        print(f'batch_size: {opt.batch_size}')  
        print(f'epochs (niters) : {opt.niter}')  
        print(f'learning rate : {opt.lr}')  
        print(f'manual_seed: {opt.seed}')  
        print(f'cuda enable: {opt.cuda}')  
        print(f'checkpoint_path: {opt.checkpoint_path}')  
  
    return opt  
  
if __name__ == '__main__':  
    opt = get_options()
# 在随后的train.py等文件中,单独使用
# 导入必要库
...
import config

opt = config.get_options()

manual_seed = opt.seed
num_workers = opt.workers
batch_size = opt.batch_size
lr = opt.lr
niters = opt.niters
checkpoint_path = opt.checkpoint_path

# 随机数的设置,保证复现结果
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

...


if __name__ == '__main__':
    set_seed(manual_seed)
    for epoch in range(niters):
        train(model,lr,batch_size,num_workers,checkpoint_path)
        val(model,lr,batch_size,num_workers,checkpoint_path)

你可能感兴趣的:(五、pytorch进阶训练技巧——pytorch学习)