编码器通常以 RNN 的方式把网络结构进行编码,然后评测器把编码的结果拿去进行训练和评测,拿到包括准确率、模型大小在内的一些指标,反馈给编码器,编码器进行修改,再次编码,如此迭代。经过若干次迭代以后,最终得到一个设计好的模型。
为了性能考虑,迭代中用到的训练数据通常是几万张规模的数据集(比如 CIFAR-10),模型设计结束后,会用大规模数据(比如 ImageNet)重新训练一次,进一步优化参数。具体原理可以参考以下链接:解读百度AutoDL
本项目主要是使用搜索出来的模型结构在CIFAR-10数据上进行训练和验证主要的目录结构如下:
|--root
|--|--build # 该目录下的文件用于根据不同的配置构建网络
|--|--|--layers.py # 网络中各种层的实现
|--|--|--resnet_base.py # 带残差的结构
|--|--|--ops.py # 调用layers.py中实现的各种层组成op
|--|--|--vgg_base.py # 不带残差的结构
|--|--tokens # 通过二进制存储的各种模型的配置
|--|--dataset # cifar数据集
|--|--model # 训练完成后保存的可以用于infer的固化模型
|--|--test # 用于存放需要测试的图像
|--|--reader.py # 数据集读取部分
|--|--train_hinas_res.py # 用于训练带残差结构的网络
|--|--train_hinas.py # 用于训练不带残差结构的网络
|--|--nn_paddle.py # 具体的训练逻辑以及模型保存写在这个文件中
|--|--infer.py # 用于对test中的图片进行预测,需要修改文件中图片的路径
In[6]
# 从work中把代码解压出来
!tar xzf data/data9705/cifar-10-python.tar.gz -C dataset/cifar/
mv: cannot move '/home/aistudio/HiNAS_models/build' to '/home/aistudio/build': Directory not empty
mv: cannot move '/home/aistudio/HiNAS_models/tokens' to '/home/aistudio/tokens': Directory not empty
rm: cannot remove 'HiNAS_models': Is a directory
In[1]
# 安装程序依赖的库文件
!pip install absl-py
DEPRECATION: Python 2.7 will reach the end of its life on January 1st, 2020. Please upgrade your Python as Python 2.7 won't be maintained after that date. A future version of pip will drop support for Python 2.7.
Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/
Collecting absl-py
Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/da/3f/9b0355080b81b15ba6a9ffcf1f5ea39e307a2778b2f2dc8694724e8abd5b/absl-py-0.7.1.tar.gz (99kB)
100% |████████████████████████████████| 102kB 9.9MB/s ta 0:00:01
Requirement already satisfied: six in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from absl-py) (1.12.0)
Requirement already satisfied: enum34 in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from absl-py) (1.1.6)
Building wheels for collected packages: absl-py
Building wheel for absl-py (setup.py) ... done
Stored in directory: /home/aistudio/.cache/pip/wheels/cc/27/b8/80769636fbf30d2fddba4c6e149163c0a319ba2dfc73f6e660
Successfully built absl-py
Installing collected packages: absl-py
Successfully installed absl-py-0.7.1
本目录下包含6个图像分类模型,都是百度大数据实验室 Hierarchical Neural Architecture Search (HiNAS) 项目通过机器自动发现的模型,在CIFAR-10数据集上达到96.1%的准确率。这6个模型分为两类,前3个没有skip link,分别命名为 HiNAS 0-2号,后三个网络带有skip link,功能类似于Resnet中的shortcut connection,分别命名 HiNAS 3-5号优品拍拍
使用train_hinas.py --model=model_id来训练没有skip link的HiNAS 0-2号网络模型,model_id代表0,1,2中的一个