自动网络搜索AutoDL之PaddlePaddle实现

  编码器通常以 RNN 的方式把网络结构进行编码,然后评测器把编码的结果拿去进行训练和评测,拿到包括准确率、模型大小在内的一些指标,反馈给编码器,编码器进行修改,再次编码,如此迭代。经过若干次迭代以后,最终得到一个设计好的模型。

  为了性能考虑,迭代中用到的训练数据通常是几万张规模的数据集(比如 CIFAR-10),模型设计结束后,会用大规模数据(比如 ImageNet)重新训练一次,进一步优化参数。具体原理可以参考以下链接:解读百度AutoDL

  本项目主要是使用搜索出来的模型结构在CIFAR-10数据上进行训练和验证主要的目录结构如下:

  |--root

  |--|--build # 该目录下的文件用于根据不同的配置构建网络

  |--|--|--layers.py # 网络中各种层的实现

  |--|--|--resnet_base.py # 带残差的结构

  |--|--|--ops.py # 调用layers.py中实现的各种层组成op

  |--|--|--vgg_base.py # 不带残差的结构

  |--|--tokens # 通过二进制存储的各种模型的配置

  |--|--dataset # cifar数据集

  |--|--model # 训练完成后保存的可以用于infer的固化模型

  |--|--test # 用于存放需要测试的图像

  |--|--reader.py # 数据集读取部分

  |--|--train_hinas_res.py # 用于训练带残差结构的网络

  |--|--train_hinas.py # 用于训练不带残差结构的网络

  |--|--nn_paddle.py # 具体的训练逻辑以及模型保存写在这个文件中

  |--|--infer.py # 用于对test中的图片进行预测,需要修改文件中图片的路径

  In[6]

  # 从work中把代码解压出来

  !tar xzf data/data9705/cifar-10-python.tar.gz -C dataset/cifar/

  mv: cannot move '/home/aistudio/HiNAS_models/build' to '/home/aistudio/build': Directory not empty

  mv: cannot move '/home/aistudio/HiNAS_models/tokens' to '/home/aistudio/tokens': Directory not empty

  rm: cannot remove 'HiNAS_models': Is a directory

  In[1]

  # 安装程序依赖的库文件

  !pip install absl-py

  DEPRECATION: Python 2.7 will reach the end of its life on January 1st, 2020. Please upgrade your Python as Python 2.7 won't be maintained after that date. A future version of pip will drop support for Python 2.7.

  Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/

  Collecting absl-py

  Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/da/3f/9b0355080b81b15ba6a9ffcf1f5ea39e307a2778b2f2dc8694724e8abd5b/absl-py-0.7.1.tar.gz (99kB)

  100% |████████████████████████████████| 102kB 9.9MB/s ta 0:00:01

  Requirement already satisfied: six in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from absl-py) (1.12.0)

  Requirement already satisfied: enum34 in /opt/conda/envs/python27-paddle120-env/lib/python2.7/site-packages (from absl-py) (1.1.6)

  Building wheels for collected packages: absl-py

  Building wheel for absl-py (setup.py) ... done

  Stored in directory: /home/aistudio/.cache/pip/wheels/cc/27/b8/80769636fbf30d2fddba4c6e149163c0a319ba2dfc73f6e660

  Successfully built absl-py

  Installing collected packages: absl-py

  Successfully installed absl-py-0.7.1

  本目录下包含6个图像分类模型,都是百度大数据实验室 Hierarchical Neural Architecture Search (HiNAS) 项目通过机器自动发现的模型,在CIFAR-10数据集上达到96.1%的准确率。这6个模型分为两类,前3个没有skip link,分别命名为 HiNAS 0-2号,后三个网络带有skip link,功能类似于Resnet中的shortcut connection,分别命名 HiNAS 3-5号优品拍拍

  使用train_hinas.py --model=model_id来训练没有skip link的HiNAS 0-2号网络模型,model_id代表0,1,2中的一个

你可能感兴趣的:(自动网络搜索AutoDL之PaddlePaddle实现)