tensorflow模型自定义层保存问题Unknown layer:xx

错误内容

昨天在使用tensorflow2.0进行transformer模型保存时,出现了无法保存和读取模型的错误。
使用model.save()进行模型保存时出现这样的错误:NotImplementedError: Layer XX has arguments in init and therefore must override get_config.
我就依照网络上的方法,通过添加get_config()方法对模型的自定义层的类中进行修改,然后保存成功了。
但在读取模型时,还是出现了错误:ValueError: Unknown layer: TokenEmbedding。

解决方法

使用model.save_weights()进行模型参数的保存,在要使用模型时,创建一个模型结构,再使用model.load_weights()读取模型参数

你可能感兴趣的:(tensorflow,深度学习,transformer)