mmclassification使用

安装

1. 安装python,cuda,torch

按照官方文档使用cuda按照 https://mmclassification.readthedocs.io/en/latest/install.html

(1)python3.7

(2)目前cuda支持版本为9.2,10.1,10.2,11.0,11.1

 !!!不支持cuda10,为了同时使用tf1和tf2装的cuda10不可

Windows下可以安装多个版本的cuda:https://blog.csdn.net/qq_17783559/article/details/112916708

(3)conda命令安装,目前支持版本为1.3-1.8

2. 安装mmvc

按照官方文档进行安装 https://github.com/open-mmlab/mmcv

pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7.0/index.html

其中mmvc_version写1.3.0

3. 安装mmclassification

按照官方文档  https://mmclassification.readthedocs.io/en/latest/install.html

git clone https://github.com/open-mmlab/mmclassification.git
cd mmclassification
pip install -e .

训练

使用自定义数据集

参考:https://blog.csdn.net/weixin_43216130/article/details/115312600?utm_medium=distribute.pc_relevant.none-task-blog-baidujs_title-1&spm=1001.2101.3001.4242

1. 准备好数据集,按照meta, train, val的格式组织

2. 生成文件写入其类别

3. mmcls/datasets目录下创建py文件,定义类

4. 修改configs以及一系列文件

训练

python tools/train.py configs/resnet/resnet_mydataset.py
python tools/train.py [配置文件]

Tips:

  • 训练过程中生成的checkpoint保存在work_dirs文件夹中

测试

python tools/test.py configs/resnet/resnet_mydataset.py checkpoints/resnet_mydataset.pth --out result_mydataset.pickle
python tools/test.py [配置文件] [断点文件] --out [输出文件名]

Tips:

  • checkpoint文件要放在checkpoints文件夹中才可以用

  • 测试的时候不指定metric的话输出文件中写入的是对每个样本的分类情况,指定metric的话写入的则是整体的情况

你可能感兴趣的:(科研,深度学习)