如同传统程序项目开发一样,在深度学习项目中,有很多程式化的、可封装的代码段。将此类代码段形成可重复使用的模板,将大大开发效率。下面即推荐一种基于pyTorch的深度学习项目模板。
模板地址:https://github.com/victoresque/pytorch-template
清晰的文件夹结构,适用于许多深度学习项目。
.json
配置文件支持方便参数调整。
可定制的命令行选项,用于更方便的参数调整。
检查点(checkpoint)保存和恢复。
用于更快开发的抽象基类:
BaseTrainer
处理检查点保存/恢复、训练过程记录等。BaseDataLoader
处理batch的生成、数据清洗和训练集/验证集拆分。BaseModel
提供基本模型摘要。pytorch-template/
│
├── train.py - 启动模型训练的py文件
├── test.py - 测试模型的py文件
│
├── config.json - 控制训练的config文件
├── parse_config.py - 处理config文件及cli选项的py文件
│
├── new_project.py - 初始化新项目所需运行的py文件
│
├── base/ - 抽象基类
│ ├── base_data_loader.py
│ ├── base_model.py
│ └── base_trainer.py
│
├── data_loader/ - 数据处理与装入
│ └── data_loaders.py
│
├── data/ - 存放输入数据的默认文件夹
│
├── model/ - 模型,losses和评价指标
│ ├── model.py
│ ├── metric.py
│ └── loss.py
│
├── saved/
│ ├── models/ - 存放训练好的模型文件
│ └── log/ - 默认的tensorboard和log文件存放地址
│
├── trainer/ - trainer类
│ └── trainer.py
│
├── logger/ -tensorboard可视化及logging模块
│ ├── visualization.py
│ ├── logger.py
│ └── logger_config.json
│
└── utils/ - 其他工具函数
├── util.py
└── ...
原repo是模板的MINST示例,可直接使用python train.py -c config.json
运行。
当你需要开始一个新项目时,需要首先运行new_project.py。通过python new_project.py ../NewProject
创建一个名为“NewProject”的新项目文件夹。该脚本将过滤掉不需要的文件,如cache、git 文件和README.md。
config.json文件详细内容如下所示:
{
"name": "Mnist_LeNet", // 项目名称
"n_gpu": 1, // 用于训练的GPU数
"arch": {
"type": "MnistModel", // 模型名称
"args": {
}
},
"data_loader": {
"type": "MnistDataLoader", // 选择DataLoader
"args":{
"data_dir": "data/", // 数据集所在路径
"batch_size": 64, // batch size
"shuffle": true, // 在划分训练/验证集前是否打乱数据集
"validation_split": 0.1 // 验证集的大小
"num_workers": 2, // 装入数据集时所开进程数量
}
},
"optimizer": {
"type": "Adam",
"args":{
"lr": 0.001, // 学习率
"weight_decay": 0, // (可选)权重衰减
"amsgrad": true
}
},
"loss": "nll_loss", // loss
"metrics": [
"accuracy", "top_k_acc" // 评价指标
],
"lr_scheduler": {
"type": "StepLR", // 学习率scheduler
"args":{
"step_size": 50,
"gamma": 0.1
}
},
"trainer": {
"epochs": 100, // epochs
"save_dir": "saved/", // checkpoints文件会存储在save_dir/models/name
"save_freq": 1, //
"verbosity": 2, // 0: quiet, 1: per epoch, 2: full
"monitor": "min val_loss" // mode and metric for model performance monitoring. set 'off' to disable.
"early_stop": 10 // number of epochs to wait before early stop. set 0 to disable.
"tensorboard": true, // 是否使用tensorboard
}
}
修改好config.json文件,然后运行
python train.py --config config.json
我们可以从上一检查点继续训练:
python train.py --resume path/to/checkpoint
可以通过将配置文件的n_gpu
参数设置地更大来启用多 GPU 训练。如果配置为使用比可用数量更少的 gpu,默认情况下将使用前 n 个设备,但我们仍可以通过--device
来指定GPU。
python train.py --device 2,3 -c config.json
更改配置文件的值是调整超参数的一种干净、安全且简单的方法。但是,如果某些值需要过于频繁或快速地更改,有时最好使用命令行选项。
该模板默认使用存储在 json 文件中的配置,但你仍可以通过命令行选项的方式更改其中的一部分。
# simple class-like object having 3 attributes, `flags`, `type`, `target`.
CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
options = [
CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')),
CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size'))
# options added here can be modified by command line flags.
]
我们可以自由编写自己的DataLoader。
BaseDataLoader
BaseDataLoader
是torch.utils.data.DataLoader
的子类,您可以使用其中任何一个。BaseDataLoader
主要处理:
BaseDataLoader.split_validation()
做训练/验证集划分for batch_idx, (x_batch, y_batch) in data_loader:
pass
我们可以自由编写自己的trainer。
继承BaseTrainer
BaseTrainer
主要处理:
训练过程记录
检查点保存、恢复
可重新配置的性能监控,用于保存当前的最佳模型,并提前停止训练。
monitor
设置为max val_accuracy
,这意味着每个epoch结束后都会保存一个最佳模型model_best.pth
early_stop
被设置为true
,当模型性能在给定数量的 epoch 内没有提高时,训练将自动终止。实现抽象方法
你必须实现_train_epoch()
。如果需要验证功能,则需要进一步实现_valid_epoch()
。上述两种方法都在trainer/trainer.py
里。
我们可以自由编写自己的Model。
继承BaseModel
BaseModel
主要处理:
torch.nn.Module
__str__
:修改print
函数以打印可训练参数的数量。自定义损失函数可以在“model/loss.py”中实现。通过将config文件中“loss”更改为相应的名称来使用它们。
自定义的metrics可在“model/metrics.py”,方法同上。