使用C# 调用tensorflow和keras 训练样本

使用C# 调用tensorflow和keras 训练样本

有些样本比较小、算力要求不高的项目,我们可以使用个人电脑的CPU来进行学习和训练。在工业自动化或者一些特殊场合,我们有时习惯于用C#等来做人机交互的前端。对于这样的项目,我们如何调用tensorflow或者keras来训练模型呢?结合之前发布的《C#来部署tensorflow的训练模型》,我们就可以利用C#完成从图片加载、分类、训练到部署的所有操作,完成一个完整的AI应用项目。

应用准备

本示例使用 VS 2015、Python 3.6、Django 2.1、TensorFlow 2.0。
实现方法
使用python构建一个django后台,通过后台对样本进行学习训练,生成模型。C#通过System.Net以HTTP协议与后台传递数据和通讯。

django 后台

配置路由

# URL routing table: a request whose path matches ^train/$ is handled by
# views.train_start; the route is registered under the name 'train_start'.
urlpatterns = [
    url(r'^train/$', views.train_start, name='train_start'),
]

配置视图

务必新建一个线程来开始训练样本。因为对于http协议,最好尽快响应,避免占用路由资源。

def train_start(request):
    """Entry point for training requests sent by the front end.

    The JSON payload is taken from the request body (POST) or from the
    ``data`` query parameter (GET). Its ``Event`` field selects the action:

    * ``'TRAIN'`` starts ``Manager`` on a new thread and answers
      ``{"return": 0}`` immediately, so the HTTP request is not held open
      while the model trains.
    * any other event is passed to ``Manager`` synchronously and the value
      it returns is sent back as JSON.

    Returns:
        HttpResponse: ``'access deny'`` for methods other than GET/POST,
        a 400 response when the payload is missing, is not valid JSON or
        has no ``Event`` field, otherwise the JSON-encoded result.
    """
    if request.method == 'POST':
        payload = request.body
    elif request.method == 'GET':
        payload = request.GET.get('data')
    else:
        return HttpResponse('access deny')

    try:
        rev = json.loads(payload)
        event = rev['Event']
    except (TypeError, ValueError, KeyError):
        # TypeError: missing ``data`` parameter (None) or a non-object payload;
        # ValueError: malformed JSON or undecodable bytes;
        # KeyError: JSON object without an 'Event' field.
        return HttpResponse('bad request', status=400)

    if event == 'TRAIN':
        # Train in the background and reply right away.
        threading.Thread(target=Manager, args=[rev, event]).start()
        res = {'return': 0}
    else:
        res = Manager(rev, event)
    return HttpResponse(json.dumps(res))

执行训练

任务分配

from AICore.MainTest.Train import Train
from AICore.MainTest.Msg import GetMsg
def Manager(rec, event):
    """Run the handler registered for ``event`` on ``rec`` and return its result.

    ``'TRAIN'`` maps to ``Train`` and ``"CHECK"`` maps to ``GetMsg``; any
    other event name raises ``KeyError``.
    """
    handlers = {
        "CHECK": GetMsg,
        'TRAIN': Train,
    }
    return handlers[event](rec)

调取训练

def Train(rec):
    """Load the training set and train the model type chosen by the request.

    With a non-empty ``rec`` the image size, model name, batch size, epoch
    count and model type are read from ``rec['project']`` and ``rec['Mode']``
    and stored in the module-level settings; the data directory is
    ``server_dir`` followed by the project id. With an empty ``rec`` the
    current module-level settings are used and the data directory is
    ``train_dir + '4'``.

    Nothing is trained when the model type is neither 'SEQModel' nor 'VGG19'.
    """
    global train_dir, labels, IMG_W, IMG_H
    global BATCH_SIZE, ModelName, Echos, ModelType

    if len(rec) == 0:
        Dir = train_dir + '4'
    else:
        project = rec['project']
        Dir = server_dir + str(project['iID'])
        IMG_W = project['fWidth']
        IMG_H = project['fHeighth']  # key is spelled this way where the payload is read
        ModelName = project['strPN']
        mode = rec['Mode']
        BATCH_SIZE = mode['iBatchSize']
        Echos = mode['iEchos']
        ModelType = mode['strName']

    samples, sample_labels = get_file(Dir, labels)
    # Training/validation data and their labels.
    X_train, X_val, Y_train, Y_val = get_batch(samples, sample_labels, IMG_W, IMG_H)

    train_args = (X_train, X_val, Y_train, Y_val,
                  IMG_W, IMG_H, BATCH_SIZE, Dir, ModelName, Echos)
    if ModelType == 'SEQModel':
        seq_model(*train_args)
    elif ModelType == 'VGG19':
        VGG19Model(*train_args)

关键代码,以序贯模型为例。

import tensorflow.compat.v1 as tf1
import numpy as np
import matplotlib.pyplot as plt

# Folder the trained Keras weights/model files are saved to. Every backslash
# is escaped explicitly: "\B" is not a valid escape sequence and triggers a
# DeprecationWarning (SyntaxWarning on newer Python versions).
ModelPath = "D:\\AI_Vision\\AIServer\\AICore\\BaseModel\\"
# Progress text produced during training and read back by SEQModel_Msg.
Msg = ''


def SEQModel_Msg(MsgEx):
    """Return the current progress message, clearing it if already seen.

    ``MsgEx`` is the message the caller already holds. When it equals the
    module-level ``Msg``, ``Msg`` is reset to an empty string first, so the
    caller gets ``''``; otherwise ``Msg`` is returned unchanged.
    """
    global Msg
    if MsgEx != Msg:
        return Msg
    # The caller already has this text: clear it so it is not reported twice.
    Msg = ''
    return Msg


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Return a frozen, inference-only GraphDef for the graph of ``session``.

    Variables are converted to constants and training-only nodes are removed.

    Args:
        session: session whose graph is frozen.
        keep_var_names: accepted but not used by this function.
        output_names: names of the output nodes to keep; treated as an
            empty list when falsy.
        clear_devices: when true, the device assignment of every node is
            cleared before freezing.

    Returns:
        The GraphDef produced by ``remove_training_nodes``.
    """
    graph = session.graph
    with graph.as_default():
        if not output_names:
            output_names = []
        print("output_names", output_names)

        graph_def = graph.as_graph_def()
        print("len node1", len(graph_def.node))

        if clear_devices:
            for graph_node in graph_def.node:
                graph_node.device = ""

        constant_graph = tf1.graph_util.convert_variables_to_constants(
            session, graph_def, output_names)

        # Drop everything that is not needed for inference.
        inference_graph = tf1.graph_util.remove_training_nodes(constant_graph)
        print("##################################################################")
        for graph_node in inference_graph.node:
            print('node:', graph_node.name)
        print("len node1", len(inference_graph.node))
        return inference_graph


def showCurve(dir,history):
    """Save the accuracy and loss curves of a training run as PNG images.

    Two files, ``accuracy.png`` and ``loss.png``, are written to the
    ``TrainModel`` sub-folder of ``dir``.

    Args:
        dir: project folder; the images go to its ``TrainModel`` sub-folder.
        history: object whose ``history`` mapping holds the per-epoch
            ``loss``/``val_loss`` lists and either ``accuracy``/``val_accuracy``
            or ``acc``/``val_acc``.

    Raises:
        KeyError: if an expected metric is missing from ``history.history``.
    """
    metrics = history.history
    # The accuracy metric is stored under 'accuracy' or 'acc'; accept both.
    if 'accuracy' in metrics:
        acc_key, val_acc_key = 'accuracy', 'val_accuracy'
    else:
        acc_key, val_acc_key = 'acc', 'val_acc'

    fig = plt.figure()  # new figure for the accuracy curves
    plt.plot(metrics[acc_key], label='training acc')
    plt.plot(metrics[val_acc_key], label='val acc')
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(loc='lower right')
    fig.savefig(dir + "\\TrainModel" + "\\accuracy.png")
    # pyplot keeps every figure alive until it is closed; release it so
    # repeated training runs do not accumulate figures in memory.
    plt.close(fig)

    fig = plt.figure()  # new figure for the loss curves
    plt.plot(metrics['loss'], label='training loss')
    plt.plot(metrics['val_loss'], label='val loss')
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(loc='upper right')
    fig.savefig(dir + "\\TrainModel" + "\\loss.png")
    plt.close(fig)


def seq_model(X_train, X_val, Y_train, Y_val, img_width, img_height, batch_size, dir, modelName,echos):
    """Train a small sequential CNN, export it as a frozen graph and plot its curves.

    Progress text (model summary, loss/accuracy histories, input/output tensor
    names) is appended to the module-level ``Msg``; when everything is done
    ``Msg`` is replaced by ``'Finish'``.

    Args:
        X_train, X_val: training / validation samples (converted with
            ``np.array`` before use) — presumably images of shape
            (img_width, img_height, 3); TODO confirm against the loader.
        Y_train, Y_val: labels for the samples — presumably 0/1 given the
            single sigmoid output; TODO confirm.
        img_width, img_height: the first two dimensions of the model's
            ``input_shape``.
        batch_size: batch size of both data generators; also used to derive
            the number of steps per epoch.
        dir: project folder; the ``.pb`` file and the curve images are
            written to its ``TrainModel`` sub-folder.
        modelName: file name (without extension) of the exported ``.pb`` file.
        echos: number of training epochs.
    """
    global Msg
    nb_train_samples = len(X_train)
    nb_validation_samples = len(X_val)

    model = tf1.keras.models.Sequential()

    # Three convolution blocks: Conv2D(3x3) -> ReLU -> 2x2 max pooling.
    # NOTE(review): input_shape is given as (width, height, 3); verify this
    # matches the axis order of the arrays in X_train.
    model.add(tf1.keras.layers.Conv2D(32, (3, 3), input_shape=(img_width, img_height, 3)))
    model.add(tf1.keras.layers.Activation('relu'))
    model.add(tf1.keras.layers.MaxPooling2D(pool_size=(2, 2)))

    model.add(tf1.keras.layers.Conv2D(32, (3, 3)))
    model.add(tf1.keras.layers.Activation('relu'))
    model.add(tf1.keras.layers.MaxPooling2D(pool_size=(2, 2)))

    model.add(tf1.keras.layers.Conv2D(64, (3, 3)))
    model.add(tf1.keras.layers.Activation('relu'))
    model.add(tf1.keras.layers.MaxPooling2D(pool_size=(2, 2)))

    # Classifier head: a 64-unit dense layer with dropout, then one sigmoid unit.
    model.add(tf1.keras.layers.Flatten())
    model.add(tf1.keras.layers.Dense(64))
    model.add(tf1.keras.layers.Activation('relu'))
    model.add(tf1.keras.layers.Dropout(0.5))
    model.add(tf1.keras.layers.Dense(1))
    model.add(tf1.keras.layers.Activation('sigmoid'))

    model.compile(loss='binary_crossentropy',
                  optimizer='rmsprop',
                  metrics=['accuracy'])
    #model.summary()
    # Capture the summary lines instead of printing them, so they can be
    # appended to Msg.
    stringlist = []
    model.summary(print_fn=lambda x: stringlist.append(x))
    Msg += "\n".join(stringlist)+"\r\n"
    # Both generators rescale pixel values by 1/255 and apply the same random
    # shear / zoom / horizontal-flip augmentation.
    train_datagen = tf1.keras.preprocessing.image.ImageDataGenerator(
        rescale=1. / 255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

    val_datagen = tf1.keras.preprocessing.image.ImageDataGenerator(
        rescale=1. / 255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

    train_generator = train_datagen.flow(np.array(X_train), Y_train, batch_size=batch_size)
    validation_generator = val_datagen.flow(np.array(X_val), Y_val, batch_size=batch_size)
    # NOTE(review): the integer divisions below yield 0 steps when a set has
    # fewer samples than batch_size — confirm callers never pass such sets.
    history = model.fit_generator(
        train_generator,
        steps_per_epoch=nb_train_samples // batch_size,
        epochs=echos,
        validation_data=validation_generator,
        validation_steps=nb_validation_samples // batch_size
    )
    # Append the per-epoch metrics to Msg. The accuracy metric is stored
    # under either 'accuracy' or 'acc', so both key names are handled.
    Msg += "loss" + str(history.history['loss']) + "\r\n"
    Msg += "val_loss" + str(history.history['val_loss']) + "\r\n"
    if 'accuracy' in dict(history.history).keys():
        Msg += "accuracy" + str(history.history['accuracy']) + "\r\n"
        Msg += "val_accuracy" + str(history.history['val_accuracy']) + "\r\n"
    else:
        Msg += "accuracy" + str(history.history['acc']) + "\r\n"
        Msg += "val_accuracy" + str(history.history['val_acc']) + "\r\n"

    # Save the weights and the complete Keras model under ModelPath.
    model.save_weights(ModelPath+'model_wieghts.h5')
    model.save(ModelPath+'model_keras.h5')

    Msg += 'input is :'+ model.input.name + "\r\n" + 'output is:'+ model.output.name + "\r\n"
    #saved_model = tf.keras.models.load_model("D:\AI Vision\AICore\model_wieghts.h5")
    #saved_model.save("D:\AI Vision\AICore\mode_test")

    # Reload the saved model into a fresh default graph, freeze it and write
    # it as a binary .pb file.
    tf1.reset_default_graph()
    tf1.keras.backend.set_learning_phase(0)  # must be executed before the model is loaded/called
    tf1.disable_v2_behavior()  # disable TensorFlow 2.x behavior
    network = tf1.keras.models.load_model(ModelPath+'model_keras.h5')
    frozen_graph = freeze_session(tf1.keras.backend.get_session(),
                                  output_names=[out.op.name for out in network.outputs])
    tf1.train.write_graph(frozen_graph, dir+"\\TrainModel", modelName+".pb", as_text=False)
    showCurve(dir,history)
    # Replaces (does not append to) the text accumulated above.
    Msg = 'Finish'

前端代码

创建一个PostMan类,用于向后台发送请求并获取返回信息。

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Net;
using System.IO;
using Newtonsoft.Json;

namespace LT.Device
{
    /// <summary>
    /// Small synchronous HTTP helper used to exchange text payloads with the backend.
    /// Both methods return the response body on success and the exception
    /// message on failure; they never throw.
    /// </summary>
    class PostMan
    {
        /// <summary>Request timeout in milliseconds.</summary>
        private const int TimeoutMs = 5000;

        /// <summary>
        /// Sends <paramref name="paramData"/> as the body of a POST request.
        /// </summary>
        /// <param name="postUrl">Absolute URL of the endpoint.</param>
        /// <param name="paramData">Text written to the request body.</param>
        /// <param name="dataEncode">Encoding used to turn the text into bytes.</param>
        /// <returns>The response body, or the exception message if the request failed.</returns>
        public string PostWebRequest(string postUrl, string paramData, Encoding dataEncode)
        {
            try
            {
                byte[] byteArray = dataEncode.GetBytes(paramData);
                HttpWebRequest webReq = (HttpWebRequest)WebRequest.Create(new Uri(postUrl));
                webReq.Method = "POST";
                webReq.ContentType = "application/x-www-form-urlencoded";
                webReq.Timeout = TimeoutMs;
                webReq.ContentLength = byteArray.Length;

                // 'using' closes the request stream exactly once, even when Write throws.
                using (Stream requestStream = webReq.GetRequestStream())
                {
                    requestStream.Write(byteArray, 0, byteArray.Length);
                }
                return ReadResponse(webReq);
            }
            catch (Exception ex)
            {
                return ex.Message;
            }
        }

        /// <summary>
        /// Sends <paramref name="paramData"/> with a GET request as the
        /// <c>data</c> query parameter (the name the Django view reads).
        /// </summary>
        /// <remarks>
        /// A GET request cannot carry a body: calling GetRequestStream on it
        /// throws ProtocolViolationException, so the payload is placed in the
        /// query string instead.
        /// </remarks>
        /// <param name="postUrl">Absolute URL of the endpoint.</param>
        /// <param name="paramData">Text to send; omitted from the URL when null or empty.</param>
        /// <param name="dataEncode">
        /// Kept for signature compatibility; the query string is always
        /// percent-encoded as UTF-8 by <see cref="Uri.EscapeDataString"/>.
        /// </param>
        /// <returns>The response body, or the exception message if the request failed.</returns>
        public string GetWebRequest(string postUrl, string paramData, Encoding dataEncode)
        {
            try
            {
                string url = postUrl;
                if (!string.IsNullOrEmpty(paramData))
                {
                    // Append to an existing query string if the URL already has one.
                    string separator = postUrl.Contains("?") ? "&" : "?";
                    url = postUrl + separator + "data=" + Uri.EscapeDataString(paramData);
                }

                HttpWebRequest webReq = (HttpWebRequest)WebRequest.Create(new Uri(url));
                webReq.Method = "GET";
                webReq.Timeout = TimeoutMs;
                return ReadResponse(webReq);
            }
            catch (Exception ex)
            {
                return ex.Message;
            }
        }

        /// <summary>Executes the request and returns the whole response body as text.</summary>
        private static string ReadResponse(HttpWebRequest webReq)
        {
            // NOTE(review): the body is decoded with the system default encoding,
            // as before; confirm this matches the charset the server sends.
            using (HttpWebResponse response = (HttpWebResponse)webReq.GetResponse())
            using (StreamReader reader = new StreamReader(response.GetResponseStream(), Encoding.Default))
            {
                return reader.ReadToEnd();
            }
        }
    }
}

通过一个线程调取上面编写的类来发送信息

#region Method StepsCode
        /// <summary>
        /// One tick of the training workflow, selected by <c>numberActiveStep</c>:
        /// 0 initialise, 1 copy the project folder, 2 send the TRAIN request,
        /// 3 poll with CHECK, 4-8 pass-through steps that lead back to step 3,
        /// 9 and 10 disable this sequence.
        /// </summary>
        protected override void StepsCode(ref DataBaseStruct dataDirector, ref ChainsElement chains)
        {
            const string trainUrl = "http://127.0.0.1:8000/ai/train/";
            PostMan postMan = new PostMan();
            PostParamStruct request = new PostParamStruct();

            switch (numberActiveStep)
            {
                case 0: // Initialise: resolve the project folder and clear the message view.
                    RootDir = BasePath + dataDirector.CurrProject.iID.ToString();
                    viewModeLocal.UpdateMessage("", true);
                    GoToNextStep(0.2);
                    break;

                case 1: // Copy the project folder; go to step 9 if the copy does not return 1.
                    if (CopyFolder(RootDir, Properties.Settings.Default.FTPPath) == 1)
                    {
                        GoToNextStep(0.1);
                    }
                    else
                    {
                        GoToStep(9, 0.1);
                    }
                    break;

                case 2: // Send the TRAIN request.
                    request.Event = "TRAIN";
                    request.project = dataDirector.CurrProject;
                    request.Labels = dataDirector.LabelCatetorys;
                    request.Mode = dataDirector.CurrMode;
                    strPram = JsonConvert.SerializeObject(request);
                    try
                    {
                        string reply = postMan.PostWebRequest(trainUrl, strPram, Encoding.UTF8);
                        // The parsed value is not used; parsing throws when the
                        // reply is not JSON, which routes to the catch below.
                        JObject.Parse(reply);
                        viewModeLocal.UpdateMessage(strPram);
                        GoToNextStep(0.5);
                    }
                    catch (Exception ex)
                    {
                        viewModeLocal.UpdateMessage(ex.Message.ToString());
                        GoToStep(9, 0.5);
                    }
                    break;

                case 3: // Poll the backend for progress text.
                    request.Event = "CHECK";
                    strPram = JsonConvert.SerializeObject(request);
                    try
                    {
                        string reply = postMan.PostWebRequest(trainUrl, strPram, Encoding.UTF8);
                        string status = JObject.Parse(reply)["return"].ToString();
                        if (status != "Finish")
                        {
                            // Show new progress text, if any, then keep cycling.
                            if (status != "")
                                viewModeLocal.UpdateMessage(status);
                            GoToNextStep(0.5);
                        }
                        else
                        {
                            GoToStep(10, 0.5);
                        }
                    }
                    catch (Exception ex)
                    {
                        viewModeLocal.UpdateMessage(ex.Message.ToString());
                        GoToStep(9, 0.5);
                    }
                    break;

                case 4:
                case 5: // Pass-through steps.
                    GoToNextStep(0.1);
                    break;

                case 6:
                case 7: // Pass-through steps.
                    GoToNextStep(0.5);
                    break;

                case 8: // Loop back to the polling step.
                    GoToStep(3, 0.5);
                    break;

                case 9:
                case 10: // Both end states disable this sequence.
                    this.Enable = false;
                    break;
            }
        }
        #endregion

最终效果

c# 显示训练的过程信息
使用C# 调用tensorflow和keras 训练样本_第1张图片
显示损失-精度图谱
使用C# 调用tensorflow和keras 训练样本_第2张图片

你可能感兴趣的:(C#,人工智能,神经网络,tensorflow,机器学习,c#)