C#调用Tensorflow进行模板加载和预测

C# 利用tensorflowsharp调用Tesnorflow模型预测

  • 训练模型的生成
  • C#中部署.pb训练模型

有时我们会用到C#来部署tensorflow或者keras的训练模型,其中会就涉及到模型的加载和调用。C#已经开发了相关支持包TensorFlowSharp。该插件可以在NuGet中获取并安装。当前使用的VS2015, Python 3.6,Pycharm2018.1 ,tensorflow2.0

训练模型的生成

利用python训练模型,这里不再详细介绍如何训练模型。后续会专门来做一章节来介绍这一块内容。当前应用是用以简单的贯序模型做图片的二分类。由于用keras生成的模型是.h5的文件,需要转换成可以用tensorflow调用的.pb文件。转换的相关代码如下。

from keras import layers, models
from keras.preprocessing.image import ImageDataGenerator
import numpy as np
import tensorflow.compat.v1 as tf1
import tensorflow as tf

def seq_model(X_train, X_val, Y_train, Y_val, img_width, img_height, batch_size):
    nb_train_samples = len(X_train)
    nb_validation_samples = len(X_val)

    model = models.Sequential()

    model.add(layers.Conv2D(32, (3, 3), input_shape=(img_width, img_height, 3)))
    model.add(layers.Activation('relu'))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))

    model.add(layers.Conv2D(32, (3, 3)))
    model.add(layers.Activation('relu'))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))

    model.add(layers.Conv2D(64, (3, 3)))
    model.add(layers.Activation('relu'))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))

    model.add(layers.Flatten())
    model.add(layers.Dense(64))
    model.add(layers.Activation('relu'))
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(1))
    model.add(layers.Activation('sigmoid'))

    model.compile(loss='binary_crossentropy',
                  optimizer='rmsprop',
                  metrics=['accuracy'])
    model.summary()
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

    val_datagen = ImageDataGenerator(
        rescale=1. / 255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

    train_generator = train_datagen.flow(np.array(X_train), Y_train, batch_size=batch_size)
    validation_generator = val_datagen.flow(np.array(X_val), Y_val, batch_size=batch_size)
    history = model.fit_generator(
        train_generator,
        steps_per_epoch=nb_train_samples // batch_size,
        epochs=100,
        validation_data=validation_generator,
        validation_steps=nb_validation_samples // batch_size
    )
    model.save_weights('model_wieghts.h5')
    model.save('model_keras.h5')

    print('input is :', model.input.name)
    print('output is:', model.output.name)

    tf1.reset_default_graph()
    tf1.keras.backend.set_learning_phase(0)  # 调用模型前一定要执行该命令
    tf1.disable_v2_behavior()  # 禁止tensorflow2.0的行为
    network = tf.keras.models.load_model('model_keras.h5')
    frozen_graph = freeze_session(tf1.keras.backend.get_session(),
                                  output_names=[out.op.name for out in network.outputs])
    tf1.train.write_graph(frozen_graph, "D:\AI Vision\AICore", "saved2pb.pb", as_text=False)

在训练完之后,请一定要记下他的输入层和输出层名称,在C#中调用时需要用作配置参数。
C#调用Tensorflow进行模板加载和预测_第1张图片

C#中部署.pb训练模型

注意事项
在C#中如果需要调用tesorflow加载训练模型。

  1. .net版本在4.6以上,本版本使用4.7.2
    C#调用Tensorflow进行模板加载和预测_第2张图片
    2.使用64位系统,去掉勾选Prefer 32-bit。.C#调用Tensorflow进行模板加载和预测_第3张图片引用命名空间
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Drawing.Imaging;
using System.IO;
using TensorFlow;
using LT.ST;
using System.Drawing;
using LT.ST.Element;

读取模型

//modelFile为.pb训练文件路径
private TFGraph LoadModel(string modelFile)
        {
            var g = new TFGraph();
            var model = File.ReadAllBytes(modelFile);
            g.Import(model, "");
            return g;
        }

获取输入图形张量

//inputFileName  输入图片的路径
public unsafe TFTensor CreateTensorFromImageFileAlt(string inputFileName, TFDataType destinationDataType = TFDataType.Float)
        {
            Bitmap bitmap = new Bitmap(inputFileName);
			//需要将图片处理成训练样本的相同大小
            int w = dataDirectorLocal.CurrProject.fWidth;//训练时图片宽度
            int h = dataDirectorLocal.CurrProject.fHeighth;//训练时图片高度
            Bitmap resizedBmp = new Bitmap(w, h);
            Graphics g = Graphics.FromImage(resizedBmp);
            //设置高质量插值法  
            g.InterpolationMode = System.Drawing.Drawing2D.InterpolationMode.High;
            //设置高质量,低速度呈现平滑程度  
            g.SmoothingMode = System.Drawing.Drawing2D.SmoothingMode.HighQuality;
            g.CompositingQuality = System.Drawing.Drawing2D.CompositingQuality.HighQuality;
            //消除锯齿
            g.SmoothingMode = System.Drawing.Drawing2D.SmoothingMode.AntiAlias;
            g.DrawImage(bitmap, new Rectangle(0, 0, w, h), new Rectangle(0, 0, bitmap.Width, bitmap.Height), GraphicsUnit.Pixel);


            BitmapData data = resizedBmp.LockBits(new Rectangle(0, 0, resizedBmp.Width, resizedBmp.Height), ImageLockMode.ReadOnly, PixelFormat.Format24bppRgb);

            var matrix = new float[1, resizedBmp.Width, resizedBmp.Height, 3];

            byte* scan0 = (byte*)data.Scan0.ToPointer();

            for (int i = 0; i < data.Height; i++)
            {
                for (int j = 0; j < data.Width; j++)
                {
                    byte* pixelData = scan0 + i * data.Stride + j * 3;
                    matrix[0, i, j, 0] = pixelData[2];
                    matrix[0, i, j, 1] = pixelData[1];
                    matrix[0, i, j, 2] = pixelData[0];
                }
            }

            resizedBmp.UnlockBits(data);

            TFTensor tensor = matrix;

            g.Dispose();
            resizedBmp.Dispose();
            bitmap.Dispose();
            return tensor;
        }

开始测试和预测

public bool TensorAnalysisFromFile(string modelFile, string[] labelsFiles)
        {
            // 创建图
            try
            {
                var g = new TFGraph();
                g = LoadModel(modelFile);
                dataDirectorLocal.Images.Clear();
                foreach (var labelsFile in labelsFiles)
                {
                    // 定义常量
                    using (var session = new TFSession(g))
                    {
                        var tensor = CreateTensorFromImageFileAlt(labelsFile);
                        var runner = session.GetRunner();
 						//训练时的输入和输出名称
                        runner.AddInput(g["conv2d_1_input"][0], tensor);
                        runner.Fetch(g["activation_5/Sigmoid"][0]);

                        var output = runner.Run();

                        var result = output[0];

                        var rshape = result.Shape;

                        if (result.NumDims != 2 || rshape[0] != 1)
                        {
                            var shape = "";
                            foreach (var d in rshape)
                            {
                                shape += $"{d} ";
                            }
                            shape = shape.Trim();
                            Console.WriteLine($"Error: expected to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape [{shape}]");
                            Environment.Exit(1);
                        }
                        bool jagged = true;

                        var bestIdx = 0;
                        float p = 0, best = 0;
                        var probabilities = ((float[][])result.GetValue(jagged: true))[0];
                        ImageStruct img = new ImageStruct();
                        img.strSource = labelsFile;
                        img.Factor = (float)probabilities[0];
                        img.Name = img.Factor < 0.5 ? dataDirectorLocal.LabelCatetorys[1].Name : dataDirectorLocal.LabelCatetorys[2].Name;
                        img.backgroud2 = img.Factor < 0.5 ? dataDirectorLocal.LabelCatetorys[1].Color : dataDirectorLocal.LabelCatetorys[2].Color;
                        img.iTested = img.Factor < 0.5 ? dataDirectorLocal.LabelCatetorys[1].Key : dataDirectorLocal.LabelCatetorys[2].Key;
                        dataDirectorLocal.Images.Add(img);
                    }
                }
            }
            catch (Exception e)
            {
                return false;
            }
            return true;
        }

程序测试成功如下C#调用Tensorflow进行模板加载和预测_第4张图片
C#调用Tensorflow进行模板加载和预测_第5张图片

你可能感兴趣的:(C#,人工智能,python,深度学习,tensorflow,c#)