load_model加载使用leaky_relu激活函数的网络报错

这个问题在之前已经有讲解如何解决了,但是今天准备修改源码的方式修复bug。

之前文章

load_model加载使用'leaky_relu'激活bug处理

报错的主要因为这个方法tensorflow.python.keras.activations.deserialize


这个方法最后调用方法tensorflow.python.keras.utils.generic_utils.deserialize_keras_object

红色虚线处就是报错的地方,代码是执行到obj = module_objects.get(object_name)赋值obj变量为None,进而导致了这个bug
所以问题就出在module_objects这个字典对象中,回到上一个方法这个对象主要是

  globs = globals()

  # only replace missing activations
  advanced_activations_globs = advanced_activations.get_globals()
  for key, val in advanced_activations_globs.items():
    if key not in globs:
      globs[key] = val

globals()是获取该文件的所有实例对象,按字典返回。
执行该方法的文件是activations.py,按文件名这个应该包含leaky_relu激活函数,但是这个方法并没有。下面的advanced_activations.get_globals()是执行advanced_activations.py中的globals()方法,下面是源码截图(有折叠):


这里并没有leaky_relu激活函数,只有LeakyRelu层类,所以上面的globs字典没有leaky_relu激活函数,因为advanced_activations.py文件是专门存放层的。所以这里我们需要修改源码让,activations.py中的globals()获取到leaky_relu激活函数,可以在activations.py编写leaky_relu激活函数方法,但是为了简单起见,这里选择直接导入其他文件中的leaky_relu方法,如下:

导入上面两个文件中的任意一个文件中的方法都行。

你可能感兴趣的:(load_model加载使用leaky_relu激活函数的网络报错)