tensorflow 加载预训练模型

加载预训练模型需要具备两个条件:1.框架结构(知道每一层的名字),2. 预训练好的模型文件.ckpt

加载预训练模型代码如下:

import tensorflow as tf
import numpy as np
weights_1 = tf.Variable(tf.zeros([ 3 , 4 ]))
# weights_2 = tf.Variable(tf.zeros([ 4 , 3 ]))

sess = tf.InteractiveSession()
saver = tf.train.Saver()

saver.restore(sess, '/tmp/checkpoint/model.ckpt' )
o_test = np.array([[ 4.0 , 3.0 , 2.0 ]], dtype = 'float32' )
label = tf.matmul(o_test, weights_1)
# label = tf.matmul(label, weights_2)
print sess.run(label)




以上为加载模型代码,可以写全变量名,也可只写一部分。可根据输出来定。


其中,model.ckpt训练模型代码如下:

import tensorflow as tf
import numpy as np

i_data = np.array([[ 5.0 , 3.0 , 2.0 ]], dtype = 'float32' )
i_label = np.array([[ 15.0 , 10.0 , 22.0 ]], dtype = 'float32' )

weights_1 = tf.Variable(tf.zeros([ 3 , 4 ]))
out_1 = tf.matmul(i_data, weights_1)

weights_2 = tf.Variable(tf.zeros([ 4 , 3 ]))
out = tf.matmul(out_1, weights_2)

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

loss = tf.reduce_mean(tf.square(out - i_label))
training = tf.train.GradientDescentOptimizer( 0.01 ).minimize(loss)

sess = tf.Session()
sess.run(init_op)
for i in range ( 20000 ):
sess.run(training)
save_path = saver.save(sess, '/tmp/checkpoint/model.ckpt' )

你可能感兴趣的:(tensorflow学习)