该项目分别训练八个模型并生成csv文件,并进行融合
conda create -n emotion python==3.8.0
conda activate emotion
cd {project_path}
pip install -r requirements.txt
打开train.sh
,可以看到训练的命令行,依次注释和解注释随后运行train.sh
。
因为是训练八个模型,分别是efficientnet_b2b
, efficientnet_b3b
, cbam_resnet50
, resmasking
,resmasking_dropout1
,resnest269e
,swin
,hrnet_w64
,所以要训练和测试,需要分别进行8次。
python main_fer2013.py --config ./config/efficientnet_b2b_config.json
python main_fer2013.py --config ./config/efficientnet_b3b_config.json
python main_fer2013.py --config ./config/cbam_resnet50_config.json
python main_fer2013.py --config ./config/hrnet_w64_config.json
python main_fer2013.py --config ./config/resmasking_config.json
python main_fer2013.py --config ./config/resmasking_dropout1_config.json
python main_fer2013.py --config ./config/resnest269e_config.json
python main_fer2013.py --config ./config/swin_config.json
checkpoint保存在{project_path}/checkpoint
目录下,可以在log
文件夹下查看训练的日志。
具体内容在test.sh
文件中。各个模型我们存放在百度云盘 https://pan.baidu.com/s/1mM-APWoLV5P3nvrzmG–Jg 提取码 1gyh
下载后复制到user_data/model_data下面即可运行下面的命令进行预测。
python gen_results.py --config ./config/efficientnet_b2b_config.json --model_name efficientnet_b2b --checkpoint_path efficientnet_b2b_2021Jul25_17.08
python gen_results.py --config ./config/efficientnet_b3b_config.json --model_name efficientnet_b3b --checkpoint_path efficientnet_b3b_2021Jul25_20.08
python gen_results.py --config ./config/cbam_resnet50_config.json --model_name cbam_resnet50 --checkpoint_path cbam_resnet50_test_2021Jul24_19.18
python gen_results.py --config ./config/hrnet_w64_config.json --model_name hrnet_w64 --checkpoint_path hrnet_test_2021Aug01_17.13
python gen_results.py --config ./config/resmasking_config.json --model_name resmasking --checkpoint_path resmasking_test_2021Jul26_14.33
python gen_results.py --config ./config/resmasking_dropout1_config.json --model_name resmasking_dropout1 --checkpoint_path resmasking_dropout1_test_2021Aug01_17.13
python gen_results.py --config ./config/resnest269e_config.json --model_name resnest269e --checkpoint_path resnest269e_test_2021Aug02_11.39
python gen_results.py --config ./config/swin_config.json --model_name swin_large_patch4_window7_224 --checkpoint_path swin_large_patch4_window7_224_test_2021Aug02_21.36
请注意,这里的model_name
是确定的,checkpoint_path
是你训练得到模型的名字,如果你自己训练了其中的一些模型,请将对应的名称修改为训练得到模型的名称。
上述8个模型的预测结果统一放在user_data/tmp_data里面,下面使用集成方法对上述八个模型的结果进行整合。
python gen_ensemble.py
我们将上述八个模型的结果进行集成,最终生成的文件放在prediction_result下面的result.csv文件中。