【炼丹手记】在ModelArts上使用AI芯片Ascend训练基于MindSpore的DeepFM

论文

DeepFM: A Factorization-Machine based Neural Network for CTR Prediction
https://arxiv.org/abs/1703.04247

CTR预估是目前推荐系统的核心技术,其目标是预估用户点击推荐内容的概率。
在CTR预估任务中,特征非常重要。
这篇论文提出的DeepFM模型是一种可以从原始特征中抽取到各种复杂度特征的端到端模型,可以有效避免人工特征工程的困扰。

数据集

criteo是非常经典的点击率预估比赛数据集。
下载地址如下:
http://go.criteo.net/criteo-research-kaggle-display-advertising-challenge-dataset.tar.gz

数据集预处理

准备数据和代码

把准备好的数据集从OBS复制到ModelArts。

import moxing as mox
mox.file.copy_parallel("obs://dataset-city/recommend-criteo/data", "/cache/criteo_ori")

把准备好的代码也从OBS复制到ModelArts。

mox.file.copy_parallel("obs://2021-ms-models/deepfm/", "/home/ma-user/work/deepfm")

预处理

用下面的命令,进行数据预处理。

python -m src.preprocess_data --data_path=/cache/criteo_ori/ --dense_dim=13 --slot_dim=26 --threshold=100 --train_line_count=45840617 --skip_id_convert=0 '''

坑一

ModdelArts中普通的CodeLab环境默认是不支持MindSpore的。
需要换成自己指定的Nodebook环境

sh-4.3$python -m src.preprocess_data --data_path=/cache/ --dense_dim=13 --slot_dim=26 --threshold=100 --train_line_count=45840617 --skip_id_convert=0
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/conda/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/ma-user/work/deepfm/src/preprocess_data.py", line 20, in 
    from mindspore.mindrecord import FileWriter
ModuleNotFoundError: No module named 'mindspore'

坑二

脚本中默认数据要放在origin_data

FileNotFoundError: [Errno 2] No such file or directory: '/cache/criteo_ori/origin_data/train.txt'
sh-4.4$cp /cache/criteo_ori/origin_data/train_small.txt /cache/criteo_ori/origin_data/train.txt
sh-4.4$cp /cache/criteo_ori/origin_data/train.txt /cache/criteo_ori/origin_data/val.txt 

移动数据

sh-4.4$mv /cache/criteo_ori/*.txt /cache/criteo_ori/origin_data/
sh-4.4$ls /cache/criteo_ori/
origin_data  stats_dict
sh-4.4$ls /cache/criteo_ori/origin_data/
readme.txt  train_small.txt  train_very_small.txt  val_small.txt

预处理结果

生成了mindrecord

sh-4.4$ls /cache/criteo_ori/
mindrecord  origin_data  stats_dict
sh-4.4$ls /cache/criteo_ori/stats_dict/
cat_count_dict.pkl  val_max_dict.pkl  val_min_dict.pkl
sh-4.4$ls /cache/criteo_ori/mindrecord/
test_input_part.mindrecord0       train_input_part.mindrecord02     train_input_part.mindrecord07     train_input_part.mindrecord12     train_input_part.mindrecord17
test_input_part.mindrecord0.db    train_input_part.mindrecord02.db  train_input_part.mindrecord07.db  train_input_part.mindrecord12.db  train_input_part.mindrecord17.db
test_input_part.mindrecord1       train_input_part.mindrecord03     train_input_part.mindrecord08     train_input_part.mindrecord13     train_input_part.mindrecord18
test_input_part.mindrecord1.db    train_input_part.mindrecord03.db  train_input_part.mindrecord08.db  train_input_part.mindrecord13.db  train_input_part.mindrecord18.db
test_input_part.mindrecord2       train_input_part.mindrecord04     train_input_part.mindrecord09     train_input_part.mindrecord14     train_input_part.mindrecord19
test_input_part.mindrecord2.db    train_input_part.mindrecord04.db  train_input_part.mindrecord09.db  train_input_part.mindrecord14.db  train_input_part.mindrecord19.db
train_input_part.mindrecord00     train_input_part.mindrecord05     train_input_part.mindrecord10     train_input_part.mindrecord15     train_input_part.mindrecord20
train_input_part.mindrecord00.db  train_input_part.mindrecord05.db  train_input_part.mindrecord10.db  train_input_part.mindrecord15.db  train_input_part.mindrecord20.db
train_input_part.mindrecord01     train_input_part.mindrecord06     train_input_part.mindrecord11     train_input_part.mindrecord16
train_input_part.mindrecord01.db  train_input_part.mindrecord06.db  train_input_part.mindrecord11.db  train_input_part.mindrecord16.db
sh-4.4$ls /cache/criteo_ori/
mindrecord  origin_data  stats_dict

结果保存

把处理好的数据COPY到OBS,保存起来。

mox.file.copy_parallel("/cache/criteo_ori/", "obs://dataset-city/recommend-criteo/")

训练

训练用代码

把代码复制到OBS的这个路径 obs://2021-ms-models/deepfm/

创建算法

【炼丹手记】在ModelArts上使用AI芯片Ascend训练基于MindSpore的DeepFM_第1张图片

【炼丹手记】在ModelArts上使用AI芯片Ascend训练基于MindSpore的DeepFM_第2张图片

创建训练作业

【炼丹手记】在ModelArts上使用AI芯片Ascend训练基于MindSpore的DeepFM_第3张图片
【炼丹手记】在ModelArts上使用AI芯片Ascend训练基于MindSpore的DeepFM_第4张图片

训练完成

【炼丹手记】在ModelArts上使用AI芯片Ascend训练基于MindSpore的DeepFM_第5张图片

算法发布在 AI Gallery

上面介绍的算法已经发布到了AI Gallery。链接如下:

https://developer.huaweicloud.com/develop/aigallery/algorithm/detail?id=ce2013a6-b5da-4616-a553-39362be1b38c

你可能感兴趣的:(昇思,人工智能,机器学习,深度学习)