TensorFlow2.0:模型的保存与加载

**一、权重参数的保存与加载**

# Write only the model's weights to the path 'weights.ckpt'.
network.save_weights('weights.ckpt')
# Read the weights back from the same path into an existing model object.
network.load_weights('weights.ckpt')

权重参数的保存与加载可以针对任何模型,包括自定义的。
但是在加载权重参数时,其模型的结构需要与原来的完全一致。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,Sequential,optimizers,datasets,metrics

def preprocess(x,y):
    """Convert one (image, label) pair into model-ready tensors.

    Args:
        x: Image tensor; it is flattened to 1-D, cast to float32 and
            divided by 255.
        y: Integer class label; it is one-hot encoded with depth 10 and
            cast to int32.

    Returns:
        Tuple ``(pixels, label)`` of the transformed tensors.
    """
    flattened = tf.reshape(x, [-1])
    pixels = tf.cast(flattened, dtype=tf.float32) / 255.

    one_hot = tf.one_hot(y, depth=10)
    label = tf.cast(one_hot, dtype=tf.int32)

    return pixels, label

# Load the MNIST train/validation arrays and report their shapes.
(x_train, y_train), (x_val, y_val) = datasets.mnist.load_data()
print('data: ', x_train.shape, y_train.shape, x_val.shape, y_val.shape)

# Training pipeline: per-sample preprocessing, shuffle with a 60000-element
# buffer, then group into batches of 128.
db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db = db.map(preprocess)
db = db.shuffle(60000)
db = db.batch(128)

# Validation pipeline: same preprocessing and batch size, no shuffling.
db_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
db_val = db_val.map(preprocess)
db_val = db_val.batch(128)

#self def layer
class MyDense(layers.Layer):#inherit layers.Layer
    """Minimal fully connected layer computing ``inputs @ kernal + bias``.

    Args:
        input_dim: Number of rows of the weight matrix (must match the size
            of the last axis of the incoming tensor for the matmul to work).
        output_dim: Number of columns of the weight matrix and length of
            the bias vector.
    """

    def __init__(self,input_dim,output_dim):#init
        super().__init__()

        # `add_weight` replaces the deprecated `add_variable` alias. Keyword
        # arguments are used so the call does not depend on the positional
        # parameter order of `add_weight`. No initializer is passed, so the
        # framework default is used for both variables, as before.
        # NOTE: the attribute name `kernal` (sic) is intentionally kept so
        # code that accesses it by name keeps working.
        self.kernal = self.add_weight(name='w',shape=[input_dim,output_dim])
        self.bias = self.add_weight(name='b',shape=[output_dim])

    def call(self,inputs,training=None):#compute
        """Return ``inputs @ kernal + bias``; `training` is accepted but unused."""
        out = inputs @ self.kernal + self.bias
        return out

#self def network
class MyModel(keras.Model):#inherit keras.Model
    """Six stacked MyDense layers: 28*28 -> 512 -> 256 -> 128 -> 64 -> 32 -> 10.

    ReLU is applied after each of the first five layers; the output of the
    last layer (`fc6`) is returned without an activation.
    """

    def __init__(self):#init
        super(MyModel, self).__init__()
        # Each layer's input width equals the previous layer's output width.
        self.fc1 = MyDense(input_dim=28 * 28, output_dim=512)
        self.fc2 = MyDense(input_dim=512, output_dim=256)
        self.fc3 = MyDense(input_dim=256, output_dim=128)
        self.fc4 = MyDense(input_dim=128, output_dim=64)
        self.fc5 = MyDense(input_dim=64, output_dim=32)
        self.fc6 = MyDense(input_dim=32, output_dim=10)

    def call(self,inputs,training=None):#compute inputs.shape = [b,28*28]
        """Run the forward pass; `training` is accepted but unused."""
        hidden = inputs
        # A local tuple (not an attribute) is used so no extra objects are
        # attached to the model.
        for dense in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            hidden = tf.nn.relu(dense(hidden))
        return self.fc6(hidden)

# Instantiate the custom model and create its variables for inputs whose
# last axis has 28*28 elements.
network = MyModel()
network.build(input_shape=[None,28*28])
network.summary()

#input para
# `learning_rate` is the supported argument name; `lr` is a deprecated alias.
network.compile(optimizer=optimizers.Adam(learning_rate=1e-2),
                loss = tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics = ['accuracy'])

#run network
# Train for 5 epochs, running validation on db_val after every epoch.
network.fit(db,epochs=5,validation_data=db_val,validation_freq=1)

#save weights
network.save_weights('weights.ckpt')
print('saved weights')
# Drop the trained model so the reload below starts from a fresh object.
del network

#new network
# The model must be rebuilt with the same structure before its weights can
# be loaded from the checkpoint.
network = MyModel()
network.build(input_shape=[None,28*28])
# `learning_rate` is the supported argument name; `lr` is a deprecated alias.
network.compile(optimizer=optimizers.Adam(learning_rate=1e-2),
                loss = tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics = ['accuracy'])

network.load_weights('weights.ckpt')
print('loaded weights')

# Evaluate the restored weights on the validation set.
network.evaluate(db_val)

**二、完整模型的保存与加载**

# Save the complete model to the file 'model.h5'.
network.save('model.h5')
# Recreate the model from that file; no model-building code is called here.
network = tf.keras.models.load_model('model.h5')

该方法(保存为 HDF5 格式的 `.h5` 文件)适用于 Sequential 模型以及用函数式(Functional)API 构建的模型,不适用于通过继承 keras.Model 实现的自定义模型。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,Sequential,optimizers,datasets,metrics

def preprocess(x,y):
    """Convert one (image, label) pair into model-ready tensors.

    Args:
        x: Image tensor; it is flattened to 1-D, cast to float32 and
            divided by 255.
        y: Integer class label; it is one-hot encoded with depth 10 and
            cast to int32.

    Returns:
        Tuple ``(pixels, label)`` of the transformed tensors.
    """
    flattened = tf.reshape(x, [-1])
    pixels = tf.cast(flattened, dtype=tf.float32) / 255.

    one_hot = tf.one_hot(y, depth=10)
    label = tf.cast(one_hot, dtype=tf.int32)

    return pixels, label

# Load the MNIST train/validation arrays and report their shapes.
(x_train, y_train), (x_val, y_val) = datasets.mnist.load_data()
print('data: ', x_train.shape, y_train.shape, x_val.shape, y_val.shape)

# Training pipeline: per-sample preprocessing, shuffle with a 60000-element
# buffer, then group into batches of 128.
db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db = db.map(preprocess)
db = db.shuffle(60000)
db = db.batch(128)

# Validation pipeline: same preprocessing and batch size, no shuffling.
db_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
db_val = db_val.map(preprocess)
db_val = db_val.batch(128)

# Sequential stack of Dense layers: five ReLU layers followed by a 10-unit
# output layer with no activation (the loss below uses from_logits=True).
network = Sequential([
    layers.Dense(512,activation=tf.nn.relu),
    layers.Dense(256,activation=tf.nn.relu),
    layers.Dense(128,activation=tf.nn.relu),
    layers.Dense(64,activation=tf.nn.relu),
    layers.Dense(32,activation=tf.nn.relu),
    layers.Dense(10)
])
network.build(input_shape=[None,28*28])
network.summary()

#input para
# `learning_rate` is the supported argument name; `lr` is a deprecated alias.
network.compile(optimizer=optimizers.Adam(learning_rate=1e-2),
                loss = tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics = ['accuracy'])

#run network
# Train for 5 epochs, running validation on db_val after every epoch.
network.fit(db,epochs=5,validation_data=db_val,validation_freq=1)

#save all model
network.save('model.h5')
print('saved total model.')
# Drop the trained model so the reload that follows starts from the file only.
del network

#new network
# Recreate the model from 'model.h5' without calling any model-building
# code, then evaluate it on the validation set.
network = keras.models.load_model('model.h5')
print('loaded total model')
network.evaluate(db_val)

你可能感兴趣的:(tensorflow2.0)