使用 Pytorch 训练深度学习模型时常用的功能代码(保持更新)

固定随机种子以确保模型可复现

import os
import torch
import random
import numpy as np
def seed_torch_everywhere(seed=24):

	random.seed(seed)
	os.environ['PYTHONHASHSEED'] = str(seed)
	np.random.seed(seed)
	torch.manual_seed(seed)
	torch.cuda.manual_seed(seed)
	torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
	torch.backends.cudnn.benchmark = False
	torch.backends.cudnn.deterministic = True

保存与加载模型参数

应用场景:模型训练意外中断后,在最后一次保存的模型参数上接续训练

import torch
import shutil
def save_ckpt(state, checkpoint_dir, best_model_dir, is_best=False,  file_name='checkpoint.pt'):
    r"""在训练时将模型参数 state 保存在 checkpoint_dir 文件夹下,
    若当前模型为迄今最优模型则将此时的参数另复制一份到 best_model_dir 下。
    除了保存模型参数外,还可保存优化器、学习率规划器的状态,以及当前 epoch 值等。
    Usage:
    >>> checkpoint = {
    >>>     'epoch': epoch + 1,
    >>>     'state_dict': model.state_dict(),
    >>>     'optimizer': optimizer.state_dict()
    >>> }
    >>> save_ckpt(checkpoint, checkpoint_dir, best_model_dir, is_best)
    """
    f_path = os.path.join(checkpoint_dir, file_name)
    torch.save(state, f_path)
    if is_best:
        best_f_path = os.path.join(best_model_dir, file_name)
        shutil.copyfile(f_path, best_f_path)
def load_ckpt(checkpoint_fpath, model, optimizer=None, lr_scheduler=None):
    r"""从 checkpoint_fpath 中加载模型、优化器、学习率规划器、epoch 值等
    Usage:
    >>> model = MyModel(**kwargs)
    >>> optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    >>> ckpt_path = "path/to/checkpoint/checkpoint.pt"
    >>> model, optimizer, start_epoch = load_ckpt(ckpt_path, model, optimizer) 
    """
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    epoch = checkpoint['epoch']
    outputs = (model, epoch)
    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer'])
        outputs += (optimizer, )
    if lr_scheduler:
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        outputs += (lr_scheduler, )

    return outputs

正则化 EarlyStopping

代码与使用方式可参考另一篇博客。

超参的随机搜索

伪代码与使用方式可参考另一篇博客。

你可能感兴趣的:(神经网络,pytorch,深度学习,python)