tensorflow2解析模型参数

tensorflow2 解析模型参数

introduction

CNN模型是计算机视觉里面常用的工具了,训模师训好模型可能还需要其他的操作,比如可能做剪枝,或者量化,需要对模型的参数做一些操作。这时候就需要解析模型的参数了。这篇文章主要叙述一下,在tensorflow2 下怎么解析模型参数。

step by step

1. 先构建并保存一个模型

tensorflow2 构建模型还是首选keras接口, 在模型保存方面有好几个接口可以选择,触类旁通, 此处我习惯使用tf.save_model。

class PlainCNN(tf.keras.Model):
    def __init__(self,
                 kernel_initializer='glorot_normal'):
        super(PlainCNN, self).__init__()


        self.conv1= tf.keras.layers.Conv2D( filters=32,
                                           kernel_size=(3, 3),
                                           strides=2,
                                           padding='same',
                                           use_bias=False,
                                           kernel_initializer=kernel_initializer)

        self.conv2 = tf.keras.layers.Conv2D(filters=64,
                                            kernel_size=(3, 3),
                                            strides=2,
                                            padding='same',
                                            use_bias=False,
                                            kernel_initializer=kernel_initializer)

        self.conv3 = tf.keras.layers.Conv2D(filters=128,
                                            kernel_size=(3, 3),
                                            strides=2,
                                            padding='same',
                                            use_bias=False,
                                            kernel_initializer=kernel_initializer)
    @tf.function(input_signature=[tf.TensorSpec([None, None, None, 3], tf.float32)])
    def call(self, inputs, training=False):
        x1 = self.conv1(inputs)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        return x3
import numpy as np
model = PlainCNN()
image=np.zeros(shape=(1,48,48,3),dtype=np.float32)
x=model(image)

# 将会打印出模型的信息:
# the result shape is : (1, 6, 6, 128)
print('the result shape is :', x.shape)
##保存模型为 tmp_model
tf.saved_model.save(model,'tmp_model')

此时我们有了tmp_model 模型目录

2. 加载模型
model=tf.saved_model.load('./tmp_model')
###这时model可以理解为上面的代码,可以直接inference, 也可以依次获取每个变量的variables,####或者trainable_variables,等等
print(model.conv1.variables)

输出实例
‘ListWrapper([ array([[[[ 8.40592012e-02, -1.82091370e-02, 6.67047426e-02,
-1.57765336e-02, 5.53220101e-02, 6.69927672e-02,
1.18893180e-02, -3.55695821e-02, 2.96269413e-02,
6.24449886e-02, 8.93794820e-02, 1.48759335e-01,
-6.80063153e-03, -2.44185757e-02, -1.68685019e-01,


然后就可以想做什么就做什么了。

你可能感兴趣的:(tensorflow2,tensorflow2,模型参数解析,tensorflow2)