TensorFlow 2.0 保存、读取、绘制模型

1. 直接保存和读取模型

  需要注意的是,在序贯模型中需要指定input_shape,如:

model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
model.add(Dense(units = 10, activation = 'softmax'))

sgd = SGD(lr = 0.3)
model.compile(sgd, loss = 'mse', metrics = ['acc'])

保存模型:

model.save('model.h5')

from tensorflow.keras.models import load_model
model = load_model('model.h5')

2. 保存和读取模型结构和权重

2.1 保存模型结构和权重

  其中json文件保存的是网络的结构,hdf5文件保存的是网络结构中的参数。

import json

json_string = model.to_json() #得到json字符串,其中json字符串表示的是网络的基本结构
with open('model.json', 'w') as f:
    json.dump(json_string, f)

filepath = 'model.h5'
model.save_weights(filepath) #saves the weights of the model as a HDF5 file.

2.2 读取模型结构和权重

with open('model.json', 'r') as f:
    json_string = json.load(f)

model = model_from_json(json_string)
model.load_weights(filepath)

model.predict(x_test)

  需要注意的是,如果是evaluate就需要再次compile model。

3. 绘制模型

需要安装pydot and graphviz。

import numpy as np
from tensorflow.keras import Sequential
from tensorflow.keras import utils
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense,Dropout,Convolution2D,MaxPooling2D,Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model
import matplotlib.pyplot as plt 
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)->(60000, 28, 28, 1)
x_train = x_train.reshape(-1, 28, 28, 1)/255.0
x_test = x_test.reshape(-1, 28, 28, 1)/255.0
# 换one hot格式
y_train = utils.to_categorical(y_train, num_classes=10)
y_test = utils.to_categorical(y_test, num_classes=10)

# 定义顺序模型
model = Sequential()
# 第一个卷积层
# input_shape 输入平面
# filters 卷积核/滤波器个数
# kernel_size 卷积窗口大小
# strides 步长
# padding padding方式 same/valid
# activation 激活函数


model.add(Convolution2D(
    input_shape = (28, 28, 1),
    filters = 32,
    kernel_size = 5,
    strides = 1,
    padding = 'same',
    activation = 'relu',
    name = 'conv1'
))
# 第一个池化层
model.add(MaxPooling2D(
    pool_size = 2,
    strides = 2,
    padding = 'same',
    name = 'pool1'
))
# 第二个卷积层
model.add(Convolution2D(64, 5,strides=1,padding='same',activation = 'relu',name='conv2'))
# 第二个池化层
model.add(MaxPooling2D(2,2,'same',name='pool2'))
# 把第二个池化层的输出扁平化为1维
model.add(Flatten())
# 第一个全连接层
model.add(Dense(1024,activation = 'relu'))
# Dropout
model.add(Dropout(0.5))
# 第二个全连接层
model.add(Dense(10,activation='softmax'))
plot_model(model,to_file="model.png",show_shapes=True,show_layer_names=True,rankdir='TB')
plt.figure(figsize=(15,15))
img = plt.imread("model.png")
plt.imshow(img)
plt.axis('off')
plt.show()

TensorFlow 2.0 保存、读取、绘制模型_第1张图片

你可能感兴趣的:(TensorFlow 2.0 保存、读取、绘制模型)