tensorflow如何保存和读取模型参数

#定义保存的路径
checkpoint_path = "zzw/zzw.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)


cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)


# 在model.fit里面添加callbacks ,在每个epochs生成一个ckpt
model.fit(train_images, train_labels,  epochs = 10,
          validation_data = (test_images,test_labels),
          callbacks = [cp_callback])  
#读取模型参数
model.load_weights('zzw/zzw.cpkt')

 

你可能感兴趣的:(深度学习)