Saving the model with h5py raised an error, so the model could not be saved:
Traceback (most recent call last):
  File "INQ.py", line 258, in <module>
    verbose=2)
  File "/home/chutz/anaconda3/lib/python3.5/site-packages/keras/engine/training.py", line 1705, in fit
    validation_steps=validation_steps)
  File "/home/chutz/anaconda3/lib/python3.5/site-packages/keras/engine/training.py", line 1256, in _fit_loop
    callbacks.on_epoch_end(epoch, epoch_logs)
  File "/home/chutz/anaconda3/lib/python3.5/site-packages/keras/callbacks.py", line 77, in on_epoch_end
    callback.on_epoch_end(epoch, logs)
  File "/home/chutz/anaconda3/lib/python3.5/site-packages/keras/callbacks.py", line 458, in on_epoch_end
    self.model.save(filepath, overwrite=True)
  File "/home/chutz/anaconda3/lib/python3.5/site-packages/keras/engine/topology.py", line 2591, in save
    save_model(self, filepath, overwrite, include_optimizer)
  File "/home/chutz/anaconda3/lib/python3.5/site-packages/keras/models.py", line 185, in save_model
    optimizer_weights_group.attrs['weight_names'] = weight_names
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "/home/chutz/anaconda3/lib/python3.5/site-packages/h5py/_hl/attrs.py", line 95, in __setitem__
    self.create(name, data=value, dtype=base.guess_dtype(value))
  File "/home/chutz/anaconda3/lib/python3.5/site-packages/h5py/_hl/attrs.py", line 188, in create
    attr = h5a.create(self._id, self._e(tempname), htype, space)
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "h5py/h5a.pyx", line 47, in h5py.h5a.create
RuntimeError: Unable to create attribute (object header message is too large)
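Cause: a size limit in the HDF5 format. HDF5 stores an attribute as a message in the object header, and a header message is limited to 64 KB. In this traceback the attribute that overflows is weight_names on the optimizer-weights group, which model.save() writes because it includes the optimizer by default (include_optimizer).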
HDF5 has a header limit of 64 KB for all metadata of the columns. This includes names, types, etc. When you go above roughly 2000 columns, you will run out of space to store all the metadata.
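To see why a long list of names overflows, the payload of the weight_names attribute can be estimated. This is a rough stdlib sketch, not how h5py or Keras computes anything: it assumes the names end up as one fixed-length byte-string array (every entry padded to the longest name) and ignores HDF5's own header bookkeeping, so it is only a lower bound. The example names are hypothetical.

```python
# Assumption: the names are written as one fixed-length byte-string array
# (NumPy pads every entry to the longest name), so the payload is roughly
# count * longest_name. HDF5's own header overhead is ignored, so the
# result is a lower bound.
HDF5_HEADER_LIMIT = 64 * 1024  # the ~64 KB object-header limit quoted above


def estimate_attr_bytes(names):
    """Approximate size in bytes of a list of names stored as one attribute."""
    encoded = [n.encode("utf8") for n in names]
    if not encoded:
        return 0
    return len(encoded) * max(len(e) for e in encoded)


def fits_in_header(names, limit=HDF5_HEADER_LIMIT):
    """True if the estimated payload stays under the header limit."""
    return estimate_attr_bytes(names) < limit


# Hypothetical optimizer-weight names (Adam keeps extra slots per weight).
names = ["training/Adam/Variable_%d:0" % i for i in range(3000)]
print(estimate_attr_bytes(names), fits_in_header(names))  # prints: 87000 False
```

The estimate grows with both the number of weights and the length of the longest name, which is why a deep model with an optimizer that keeps several slots per weight hits the limit.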
Solution: save only the weights, and save the model architecture separately as a JSON object.
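A minimal sketch of this approach, assuming the Keras 2.x model API used in the post (to_json(), save_weights(), keras.models.model_from_json(), load_weights()). The helper function names and file paths are hypothetical. The save helper only relies on the two methods it calls, so any object providing them works.

```python
def save_architecture_and_weights(model, json_path, weights_path):
    """Write the architecture as JSON text and the weights as a separate file.

    Because model.save() is never called, the optimizer state (and its
    'weight_names' attribute) is not written to the HDF5 file.
    """
    with open(json_path, "w") as f:
        f.write(model.to_json())
    model.save_weights(weights_path)


def load_architecture_and_weights(json_path, weights_path):
    """Rebuild the model from the JSON file, then load the saved weights."""
    from keras.models import model_from_json  # imported lazily; needs Keras

    with open(json_path) as f:
        model = model_from_json(f.read())
    model.load_weights(weights_path)
    return model
```

Since save_weights() stores only layer weights, a model rebuilt this way has no optimizer state and must be compiled again before training continues.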
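The h5py call is made from inside the callback, so I traced it to the saving code. I was using the ModelCheckpoint() callback, whose on_epoch_end() method does the saving, and it already supports a weights-only mode through self.save_weights_only: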
def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    self.epochs_since_last_save += 1
    if self.epochs_since_last_save >= self.period:
        self.epochs_since_last_save = 0
        filepath = self.filepath.format(epoch=epoch + 1, **logs)
        if self.save_best_only:
            current = logs.get(self.monitor)
            if current is None:
                warnings.warn('Can save best model only with %s available, '
                              'skipping.' % (self.monitor), RuntimeWarning)
            else:
                if self.monitor_op(current, self.best):
                    if self.verbose > 0:
                        print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                              ' saving model to %s'
                              % (epoch + 1, self.monitor, self.best,
                                 current, filepath))
                    self.best = current
                    if self.save_weights_only:
                        self.model.save_weights(filepath, overwrite=True)
                    else:
                        self.model.save(filepath, overwrite=True)
                else:
                    if self.verbose > 0:
                        print('\nEpoch %05d: %s did not improve from %0.5f' %
                              (epoch + 1, self.monitor, self.best))
        else:
            if self.verbose > 0:
                print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
            if self.save_weights_only:
                self.model.save_weights(filepath, overwrite=True)
            else:
                self.model.save(filepath, overwrite=True)
Changes made:
(1) Redefine filepath, changing the extension from .hdf5 to .h5.
(2) Pass save_weights_only=True when creating the callbacks:
from keras.callbacks import ModelCheckpoint
checkpointCallback = ModelCheckpoint(filepath=parameters.modelSaveName, verbose=1, save_weights_only=True)
bestCheckpointCallback = ModelCheckpoint(filepath=parameters.ModelSaveName, verbose=1, save_best_only=True, save_weights_only=True)
The first callback saves a checkpoint after every epoch; the second saves only the best one.
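Whether checkpoints overwrite each other depends on filepath, because on_epoch_end builds the file name with self.filepath.format(epoch=epoch + 1, **logs). The stdlib sketch below mimics that one line; the helper name, file names, and metric values are hypothetical. A plain name is rewritten every epoch, while a template with placeholders gives one file per epoch, and giving the per-epoch and best-only callbacks different paths keeps one from overwriting the other's file.

```python
def checkpoint_name(template, epoch, logs):
    """Mimic ModelCheckpoint's file naming for a 0-based epoch index."""
    # str.format ignores keyword arguments the template does not use.
    return template.format(epoch=epoch + 1, **logs)


logs = {"loss": 0.4321, "val_loss": 0.5678}  # hypothetical epoch metrics
print(checkpoint_name("model_weights.h5", 4, logs))  # prints: model_weights.h5
print(checkpoint_name("weights.{epoch:02d}-{val_loss:.2f}.h5", 4, logs))
# prints: weights.05-0.57.h5
```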
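Note: I did not additionally save the architecture as a JSON object. If you need that, you can refer to this code and write a custom callback that saves the JSON and the weights separately.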
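A custom callback along those lines could write the architecture once at the start of training and the weights at the end of every epoch. This is a minimal sketch assuming the Keras 2.x Callback API (self.model, on_train_begin, on_epoch_end); the class name and path template are hypothetical, and it is not the code the note refers to. The ImportError fallback exists only so the sketch can be exercised without Keras installed.

```python
try:
    from keras.callbacks import Callback  # Keras 2.x, as used in the post
except ImportError:  # fallback only so the sketch runs without Keras
    Callback = object


class JsonAndWeightsCheckpoint(Callback):
    """Hypothetical callback: architecture as JSON once, weights every epoch."""

    def __init__(self, json_path, weights_path_template):
        super().__init__()
        self.json_path = json_path
        # e.g. 'weights_{epoch:02d}.h5', formatted like ModelCheckpoint.filepath
        self.weights_path_template = weights_path_template

    def on_train_begin(self, logs=None):
        # The architecture does not change during training, so write it once.
        with open(self.json_path, "w") as f:
            f.write(self.model.to_json())

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        path = self.weights_path_template.format(epoch=epoch + 1, **logs)
        # save_weights() skips the optimizer, avoiding the oversized attribute.
        self.model.save_weights(path, overwrite=True)
```

It would be passed to fit() in the callbacks list like any other callback.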