CNN模型是计算机视觉里面常用的工具了,训模师训好模型可能还需要其他的操作,比如可能做剪枝,或者量化,需要对模型的参数做一些操作。这时候就需要解析模型的参数了。这篇文章主要叙述一下,在tensorflow2 下怎么解析模型参数。
tensorflow2 构建模型还是首选keras接口, 在模型保存方面有好几个接口可以选择,触类旁通, 此处我习惯使用tf.save_model。
class PlainCNN(tf.keras.Model):
def __init__(self,
kernel_initializer='glorot_normal'):
super(PlainCNN, self).__init__()
self.conv1= tf.keras.layers.Conv2D( filters=32,
kernel_size=(3, 3),
strides=2,
padding='same',
use_bias=False,
kernel_initializer=kernel_initializer)
self.conv2 = tf.keras.layers.Conv2D(filters=64,
kernel_size=(3, 3),
strides=2,
padding='same',
use_bias=False,
kernel_initializer=kernel_initializer)
self.conv3 = tf.keras.layers.Conv2D(filters=128,
kernel_size=(3, 3),
strides=2,
padding='same',
use_bias=False,
kernel_initializer=kernel_initializer)
@tf.function(input_signature=[tf.TensorSpec([None, None, None, 3], tf.float32)])
def call(self, inputs, training=False):
x1 = self.conv1(inputs)
x2 = self.conv2(x1)
x3 = self.conv3(x2)
return x3
import numpy as np
model = PlainCNN()
image=np.zeros(shape=(1,48,48,3),dtype=np.float32)
x=model(image)
# 将会打印出模型的信息:
# the result shape is : (1, 6, 6, 128)
print('the result shape is :', x.shape)
##保存模型为 tmp_model
tf.saved_model.save(model,'tmp_model')
此时我们有了tmp_model 模型目录
model=tf.saved_model.load('./tmp_model')
###这时model可以理解为上面的代码,可以直接inference, 也可以依次获取每个变量的variables,####或者trainable_variables,等等
print(model.conv1.variables)
输出实例
‘ListWrapper([
-1.57765336e-02, 5.53220101e-02, 6.69927672e-02,
1.18893180e-02, -3.55695821e-02, 2.96269413e-02,
6.24449886e-02, 8.93794820e-02, 1.48759335e-01,
-6.80063153e-03, -2.44185757e-02, -1.68685019e-01,
…
…
’
然后就可以想做什么就做什么了。