C# tensorflow.NET调用训练好的frozen模型

我们经常使用Python版的TensorFlow训练深度学习模型,但有些项目需要在C#中部署和调用这些模型,TensorFlow.NET为此提供了一个不错的选择。

本文介绍如何利用TensorFlow.NET调用用Python训练好的模型。本文使用的版本为:tensorflow == 1.14.0、TensorFlow.NET == 0.21.0、SciSharp.TensorFlow.Redist == 2.3.1。各版本之间的兼容关系请参见下面的网站:

https://github.com/SciSharp/TensorFlow.NET

1)用Python版的TensorFlow训练模型,并将其导出为冻结图(frozen graph,即.pb文件),后面的C#代码直接导入这个.pb文件;

2)在VS2019中新建C#工程,通过NuGet安装相应版本的TensorFlow.NET;

     注意事项:需使用VS2019和.NET Core 3.1;TensorFlow.NET和SciSharp.TensorFlow.Redist两个包都必须安装;如果使用GPU,还需安装相应的GPU版本库。详情见:

     https://tensorflownet.readthedocs.io/en/latest/HelloWorld.html

3)参考下面的代码。注意:代码中的输入、输出节点名(x 和 cnn/output)以及输入图像尺寸(240×240)都需要与训练时的模型保持一致。

/// <summary>
/// Minimal demo of running a frozen TensorFlow graph (.pb) exported from Python
/// through TensorFlow.NET: reads the first image in <c>Images/</c>, feeds it to the
/// graph's input node and thresholds the sigmoid of the output node into a mask.
/// </summary>
class UNet_tf
{
    // Directory (relative to the current working directory) the demo image is read from.
    string dir = "Images";

    // Path to the frozen model; relative paths are resolved against the application base directory.
    private string model_filepath = "";

    // NOTE(review): never read inside this class — kept only so the member still exists.
    List<NDArray> file_ndarrays = new List<NDArray>();

    /// <summary>Creates the demo for a given frozen model file.</summary>
    /// <param name="model_filepath">Path to the frozen graph (.pb); absolute, or relative to the application base directory.</param>
    public UNet_tf(String model_filepath)
    {
        this.model_filepath = model_filepath;
    }

    /// <summary>
    /// Imports the frozen graph, runs one inference on the first file in <c>dir</c>
    /// and builds a square mask holding 255 where the sigmoid output exceeds 0.5.
    /// </summary>
    /// <exception cref="FileNotFoundException"><c>dir</c> contains no files.</exception>
    public void load_graph()
    {
        // Side length the image is resized to; must match the model's expected input size.
        const int ImageSize = 240;

        // Switch to graph mode: the Graph/Session code below depends on it (this is important).
        tf.compat.v1.disable_eager_execution();

        var files = Directory.GetFiles(dir);
        if (files.Length == 0)
        {
            throw new FileNotFoundException($"No image files found in '{dir}'.");
        }

        var nd = ReadTensorFromImageFile(files[0],
            input_height: ImageSize,
            input_width: ImageSize,
            input_mean: 128,
            input_std: 255);
        Console.WriteLine("img loaded");

        // Path.Combine returns model_filepath unchanged when it is already absolute.
        var modelPath = Path.Combine(AppContext.BaseDirectory, this.model_filepath);
        var graph = new Graph();
        graph.Import(modelPath);
        Console.WriteLine("model loaded");

        // Node names must match those given to the graph when it was built in Python.
        var input_name = "x";
        var output_name = "cnn/output";

        var input_operation = graph.OperationByName(input_name);
        var output_operation = graph.OperationByName(output_name);
        // NOTE(review): assumes the output node emits logits, so sigmoid maps them to (0, 1) — confirm.
        var binary = tf.nn.sigmoid(output_operation.outputs[0]);

        var sw = new Stopwatch();

        using (var sess = tf.Session(graph))
        {
            sw.Restart();
            var results = sess.run(binary, (input_operation.outputs[0], nd));
            sw.Stop();
            Console.WriteLine($"inference took {sw.ElapsedMilliseconds} ms");

            results = np.squeeze(results);
            // NOTE(review): assumes the leading axis after squeezing is the class channel and
            // channel 0 is the "needle" class — TODO confirm against the model's output layout.
            results = results[0];

            var needle = np.zeros((ImageSize, ImageSize));
            // Sigmoid output is a float in (0, 1); exact equality with 1 only hits saturated
            // values, so threshold at 0.5 instead.
            needle[results > 0.5f] = 255;
        }
    }

    /// <summary>
    /// Decodes a JPEG as a single-channel image in a fresh graph and returns it as a
    /// normalized 4-D tensor: batch axis added, bilinearly resized, then (x - mean) / std.
    /// </summary>
    /// <param name="file_name">Path of the JPEG file to read.</param>
    /// <param name="input_height">Target height after resizing.</param>
    /// <param name="input_width">Target width after resizing.</param>
    /// <param name="input_mean">Value subtracted from every pixel.</param>
    /// <param name="input_std">Divisor applied after the mean is subtracted.</param>
    /// <returns>The evaluated, normalized image tensor.</returns>
    private NDArray ReadTensorFromImageFile(string file_name,
                            int input_height = 299,
                            int input_width = 299,
                            int input_mean = 0,
                            int input_std = 255)
    {
        var graph = tf.Graph().as_default();

        var file_reader = tf.io.read_file(file_name, "file_reader");
        var decodeJpeg = tf.image.decode_jpeg(file_reader, channels: 1, name: "DecodeJpeg");
        var cast = tf.cast(decodeJpeg, tf.float32);
        // Add a leading batch dimension so the image can be fed as a batch of one.
        var dims_expander = tf.expand_dims(cast, 0);
        var resize = tf.constant(new int[] { input_height, input_width });
        var bilinear = tf.image.resize_bilinear(dims_expander, resize);
        var sub = tf.subtract(bilinear, new float[] { input_mean });
        var normalized = tf.divide(sub, new float[] { input_std });

        using (var sess = tf.Session(graph))
        {
            return sess.run(normalized);
        }
    }
}

 

你可能感兴趣的:(tensorflow,C#,tensorflow,界面设计)