目录
代码一(Using TensorFlow):
代码二(Using TensorFlow):
代码三(Using PyTorch):
参考:
本人在网上找了三个相关的代码,但是都有问题,这里记录一下修改哪些地方之后可以跑通。
代码地址:
https://github.com/wallarm/nascell-automl
这个代码有详细的说明:
The First Step-by-Step Guide for Implementing Neural Architecture Search with Reinforcement Learning Using TensorFlow
代码一和代码二使用TensorFlow,有很多版本的问题,基本上可以用以下代码解决:
import tensorflow.compat.v1 as tf
还有些小的问题:
需要在环境中的tensorflow文件夹下加一个example文件夹:
链接: https://pan.baidu.com/s/1mjIsDxr2TCh6wop0-99Cyw?pwd=gb6s
提取码: gb6s
可以在TensorFlow官网搜索相关函数,查看正确用法,比如查看NASCell。
可以发现NASCell(4 * max_layers)的用法在tensorflow-addons,需要安装这个包,并且导入。
import tensorflow_addons as tfa
cnn.py中的这行初始化的代码也进行了修改。
tf.initializers.glorot_normal()
修改后代码:NAS with RL(Using TensorFlow)-CSDN博客
代码地址:
https://github.com/titu1994/neural-architecture-search?tab=readme-ov-file
也是有很多版本的问题,解决思路和代码一类似,除了上面的操作,还有:
在train.py中代码开始的地方添加了下面的代码:
tf.compat.v1.disable_eager_execution()
controller.py中下面两行代码有问题:
_, loss, summary, global_step = self.policy_session.run([self.train_op, self.total_loss, self.summaries_op,
self.global_step],
feed_dict=feed_dict)
self.summary_writer.add_summary(self.summaries_op, global_step)
self.summaries_op这个变量好像有问题,无法fetch到,我也没看懂,好像是跟可视化有关(tensorboard),索性直接把这块删了,改成下面这块:
_, loss, global_step = self.policy_session.run([self.train_op, self.total_loss,
self.global_step],
feed_dict=feed_dict)
# self.summary_writer.add_summary(self.summaries_op, global_step)
修改后代码:NAS with RL(Using TensorFlow)-CSDN博客
代码地址:
https://github.com/Longcodedao/NAS-With-RL
修改1:
把controller类下forward中注释掉的
self.total_layer = torch.randint(1, self.max_layer, (1,)).item()
移到初始化__init__下面了。
修改2:
将代码中play_episode函数下的unsqueeze和squeeze的参数由1改为0。
对其整理和解析的博客如下:
NAS with RL(使用强化学习进行神经网络架构搜索,基于pytorch框架)-CSDN博客
超详细No module named ‘tensorflow.examples’报错解决方法,详细有效!_no module named 'tensorflow.examples-CSDN博客
【问题解决】pytorch: RuntimeError: DataLoader worker (pid(s) 27292) exited unexpectedly_runtimeerror: dataloader worker (pid(s) 25676, 116-CSDN博客
Python-squeeze()、unsqueeze()函数的理解_python squeeze-CSDN博客
torch.gather/torch.scatter_size does not match previous size-CSDN博客
Tensorflow报错:TypeError: Fetch argument None has invalid type class ‘NoneType’_typeerror: fetch argument none has invalid type TensorFlow报错:tf.placeholder() is not compatible with eager execution.-CSDN博客 tensorflow_addons(tfa)安装与使用-CSDN博客 AttributeError: module ‘keras.backend‘ has no attribute ‘set_session‘_module 'keras.backend' has no attribute 'set_sessi-CSDN博客