Benefits of serialization: distributed prediction on Spark, and convenient model storage.
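I ran into countless pitfalls here and dug through a lot of Chinese and English documentation. The core question is which class your custom model inherits from. If it is a Layer subclass (used inside a model), serialization calls such as model.to_json(), and likewise to_yaml(), do not fail. If it is a Model subclass, they fail with the following error: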
```
NotImplementedError                       Traceback (most recent call last)
in
----> 1 model.to_json()

/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in to_json(self, **kwargs)
   1207       A JSON string.
   1208     """
-> 1209     model_config = self._updated_config()
   1210     return json.dumps(
   1211         model_config, default=serialization.get_json_type, **kwargs)

/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in _updated_config(self)
   1185     from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
   1186 
-> 1187     config = self.get_config()
   1188     model_config = {
   1189         'class_name': self.__class__.__name__,

/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
    885     if not self._is_graph_network:
    886       raise NotImplementedError
--> 887     return copy.deepcopy(get_network_config(self))
    888 
    889   @classmethod

/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_network_config(network, serialize_layer_fn)
   1940           filtered_inbound_nodes.append(node_data)
   1941 
-> 1942     layer_config = serialize_layer_fn(layer)
   1943     layer_config['name'] = layer.name
   1944     layer_config['inbound_nodes'] = filtered_inbound_nodes

/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
    138   if hasattr(instance, 'get_config'):
    139     return serialize_keras_class_and_config(instance.__class__.__name__,
--> 140                                             instance.get_config())
    141   if hasattr(instance, '__name__'):
    142     return instance.__name__

in get_config(self)
    179 
    180     def get_config(self):
--> 181         base_config = super().get_config().copy()
    182         config = {}
    183         config['name'] = self.name

/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
    884   def get_config(self):
    885     if not self._is_graph_network:
--> 886       raise NotImplementedError
    887     return copy.deepcopy(get_network_config(self))
    888 

NotImplementedError:
```
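So what is the core problem? If our model class inherits from Model, it cannot be serialized. The last frames of the traceback show why: our get_config calls super().get_config(), which resolves to Network.get_config, and that method raises NotImplementedError because a subclassed model is not a graph network (self._is_graph_network is False).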
For example:

```python
from tensorflow import keras


class Intermitforecast(keras.Model):
    def __init__(self, pre_lens=28, window=180, name='IF', featurenum=36, **kwargs):
        super(Intermitforecast, self).__init__(name=name, **kwargs)
        self.pre_lens = pre_lens
        self.window = window
        self.units = 28
        self.ts_num = 1
        self.ft_num = featurenum
```
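The change needed is to inherit from Layer instead of Model, and to override get_config(self) in the Layer subclass. Be sure to put every argument of __init__ into get_config, keyed by its argument name.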
```python
from tensorflow.keras.layers import Layer


class Intermitforecast(Layer):
    def __init__(self, pre_lens=28, window=180, name='IF', featurenum=36, **kwargs):
        super(Intermitforecast, self).__init__(name=name, **kwargs)
        self.pre_lens = pre_lens
        self.window = window
        self.units = 28         # fixed value, not an __init__ argument
        self.ts_num = 1         # fixed value, not an __init__ argument
        self.ft_num = featurenum
        # ………………………………

    def get_config(self):
        # The base config already contains 'name', 'trainable' and 'dtype'
        base_config = super().get_config().copy()
        config = {}
        # Keys must be the __init__ argument names: the default from_config
        # rebuilds the layer with cls(**config)
        config.update({"pre_lens": self.pre_lens})
        config.update({"window": self.window})
        config.update({"featurenum": self.ft_num})  # previously "ft_num": self.pre_lens
        # units and ts_num are not __init__ arguments, so they stay out of the config.
        # If w_reg, v_reg and hidden_units are __init__ arguments in the full code,
        # add them here under those argument names.
        return dict(list(base_config.items()) + list(config.items()))
```
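Why must the keys be the __init__ argument names? Keras's default Layer.from_config essentially rebuilds the layer as cls(**config). The plain-Python sketch below imitates that contract with hypothetical classes (ConfigurableLayer, GoodForecast and BadForecast are illustrations, not Keras APIs): a config keyed by argument names round-trips, while one keyed by an attribute name such as ft_num is rejected as an unknown keyword argument, much as Keras also rejects unrecognized keyword arguments.

```python
class ConfigurableLayer:
    """Hypothetical, simplified imitation of the Keras Layer config contract."""

    def __init__(self, name=None):
        self.name = name or type(self).__name__.lower()

    def get_config(self):
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        # Like Keras's default from_config: pass the dict back as keyword arguments
        return cls(**config)


class GoodForecast(ConfigurableLayer):
    def __init__(self, pre_lens=28, window=180, featurenum=36, **kwargs):
        super().__init__(**kwargs)
        self.pre_lens = pre_lens
        self.window = window
        self.ft_num = featurenum  # attribute name differs from the argument name

    def get_config(self):
        config = super().get_config()
        # Keyed by __init__ argument names, so cls(**config) works
        config.update({"pre_lens": self.pre_lens, "window": self.window,
                       "featurenum": self.ft_num})
        return config


class BadForecast(GoodForecast):
    def get_config(self):
        config = ConfigurableLayer.get_config(self)
        # Attribute name used as key (and the wrong value): "ft_num" is not an
        # __init__ argument, so it falls through **kwargs to the base class
        config.update({"pre_lens": self.pre_lens, "window": self.window,
                       "ft_num": self.pre_lens})
        return config


def round_trip(layer):
    """Serialize to a config dict and rebuild a new instance from it."""
    return type(layer).from_config(layer.get_config())


if __name__ == "__main__":
    print(round_trip(GoodForecast(pre_lens=7, window=30, featurenum=5)).get_config())
    try:
        round_trip(BadForecast(pre_lens=7, window=30, featurenum=5))
    except TypeError as err:
        print("round trip failed:", err)
```

Even if the base class silently swallowed the extra key, featurenum would fall back to its default of 36, so the restored layer would differ from the original without any error.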
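For reference, see sinysama's CSDN blog post "解决NotImplementedError: Layer XX has arguments in `__init__` and therefore must override `get_config`" (Fixing NotImplementedError: Layer XX has arguments in `__init__` and therefore must override `get_config`).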
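One more thing: here is how I found out about the difference between the Model and Layer classes. Until I fell into this pit myself, I would never have understood it.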
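From the official documentation: "In subclassed models, the model's topology is defined as Python code (rather than as a static graph of layers). That means the model's topology cannot be inspected or serialized. As a result, the following methods and attributes are not available for subclassed models:"

- model.inputs and model.outputs
- model.to_yaml() and model.to_json()
- model.get_config() and model.save()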
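Putting the pieces together, here is a minimal end-to-end sketch, assuming TensorFlow 2.x with tf.keras. Model.to_json and keras.models.model_from_json are real APIs; IntermitForecastLayer, its Flatten + Dense body and build_model are hypothetical stand-ins for the article's Intermitforecast. The custom logic lives in a Layer subclass, that layer is wrapped in a functional keras.Model (so the outer model is a graph network whose layers to_json can walk), and the architecture is restored from JSON by passing the class through custom_objects. JSON stores only the architecture, so the weights are copied over separately.

```python
import numpy as np
from tensorflow import keras


class IntermitForecastLayer(keras.layers.Layer):
    """Hypothetical stand-in: maps (window, featurenum) inputs to pre_lens outputs."""

    def __init__(self, pre_lens=28, window=180, featurenum=36, **kwargs):
        super().__init__(**kwargs)
        self.pre_lens = pre_lens
        self.window = window
        self.featurenum = featurenum
        # Hypothetical internals: flatten the window, then project to pre_lens steps
        self.flatten = keras.layers.Flatten()
        self.dense = keras.layers.Dense(pre_lens)

    def build(self, input_shape):
        # After flattening, each sample has window * featurenum values
        self.dense.build((None, self.window * self.featurenum))
        super().build(input_shape)

    def call(self, inputs):
        return self.dense(self.flatten(inputs))

    def get_config(self):
        # Base config holds name/trainable/dtype; add every __init__ argument by name
        config = super().get_config()
        config.update({
            "pre_lens": self.pre_lens,
            "window": self.window,
            "featurenum": self.featurenum,
        })
        return config


def build_model(pre_lens=2, window=4, featurenum=3):
    # Functional wrapper around the custom layer: this makes it a graph network
    inputs = keras.Input(shape=(window, featurenum))
    outputs = IntermitForecastLayer(pre_lens, window, featurenum, name="IF")(inputs)
    return keras.Model(inputs, outputs)


model = build_model()
json_string = model.to_json()  # architecture only, no weights

# The custom class has to be supplied when deserializing
restored = keras.models.model_from_json(
    json_string, custom_objects={"IntermitForecastLayer": IntermitForecastLayer}
)
restored.set_weights(model.get_weights())  # JSON carries no weights, so copy them
```

The design choice is to let a functional Model own the topology while the custom code stays inside a Layer; the JSON string can then be shipped around (for example to Spark workers) and paired with separately saved weights, such as those written by model.save_weights.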