# NOTE: the default graph must be reset with tf.reset_default_graph() before the model is (re)built.
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.ops import resources
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
from tensorflow.examples.tutorials.mnist import input_data
# Train a random forest classifier on MNIST with tensor_forest and report
# accuracy on the held-out test split.

# Load MNIST under /tmp/data/. Labels are requested as non-one-hot
# (one_hot=False), matching the rank-1 int32 label placeholder below.
mnist = input_data.read_data_sets('/tmp/data/', one_hot=False)

# Start from a clean default graph so re-running this script in the same
# process (e.g. in a notebook) does not pile ops onto a previous graph.
tf.reset_default_graph()

# Training parameters.
num_steps = 500      # number of training iterations
batch_size = 1024    # examples drawn per training iteration
num_classes = 10     # passed to ForestHParams as num_classes
num_features = 784   # width of the feature placeholder
num_trees = 10       # passed to ForestHParams as num_trees
max_nodes = 1000     # passed to ForestHParams as max_nodes

# Graph inputs: a float32 feature matrix and a vector of int32 labels.
X = tf.placeholder(tf.float32, shape=[None, num_features])
Y = tf.placeholder(tf.int32, shape=[None])

# Forest hyperparameters; fill() is called to complete the parameter set.
hparams = tensor_forest.ForestHParams(
    num_classes=num_classes,
    num_features=num_features,
    num_trees=num_trees,
    max_nodes=max_nodes,
).fill()

# Build the forest and its training / loss ops.
forest_graph = tensor_forest.RandomForestGraphs(params=hparams)
train_op = forest_graph.training_graph(X, Y)
loss_op = forest_graph.training_loss(X, Y)

# Accuracy: fraction of rows whose argmax over axis 1 of the inference output
# equals the label. Only the first of the three values returned by
# inference_graph is used here.
infer_op, _, _ = forest_graph.inference_graph(X)
correct_pre = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
accuracy_op = tf.reduce_mean(tf.cast(correct_pre, tf.float32))

# A single op that initializes both the global variables and the shared
# resources registered in the graph.
init_vars = tf.group(
    tf.global_variables_initializer(),
    resources.initialize_resources(resources.shared_resources()),
)

# Use the session as a context manager so it is closed even if training or
# evaluation raises.
with tf.Session() as sess:
    sess.run(init_vars)

    for i in range(1, num_steps + 1):
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        _, loss = sess.run(
            [train_op, loss_op], feed_dict={X: batch_x, Y: batch_y}
        )
        # Report progress on the first step and every 50th step; accuracy is
        # measured on the batch that was just used for training.
        if i % 50 == 0 or i == 1:
            acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y})
            print('Step: %i Loss: %f Accuracy: %f' % (i, loss, acc))

    # Evaluate on the full test split in a single run.
    test_x, test_y = mnist.test.images, mnist.test.labels
    print(
        "Test Accuracy:",
        sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}),
    )