对于样本量较小、算力要求不高的项目,我们可以直接使用个人电脑的CPU来训练模型。在工业自动化或一些特殊场合,我们常习惯用C#等语言开发人机交互前端。对于这样的项目,如何调用TensorFlow或Keras来训练模型呢?结合之前发布的《C#来部署tensorflow的培训模型》,我们就可以用C#完成从图片加载、分类、训练到部署的全部操作,实现一个完整的AI应用项目。
本示例使用 VS 2015、Python 3.6、Django 2.1 和 TensorFlow 2.0。
实现方法
使用Python构建一个Django后台,由后台对样本进行训练并生成模型。C#通过System.Net以HTTP协议与后台通信、传递数据。
# URL routes of this app. The C# client posts to http://127.0.0.1:8000/ai/train/,
# so this module is presumably included under an "ai/" prefix -- confirm in the
# project-level urls.py.
urlpatterns = [
    # Both the TRAIN and the CHECK events arrive on this single endpoint.
    url(r'^train/$', views.train_start, name='train_start'),
]
务必新建一个线程来执行训练:HTTP请求应尽快返回响应,而训练往往耗时很长,若在请求中同步执行,会长时间占用连接资源,客户端也会因超时而失败(下文C#客户端的超时时间为5秒)。
def train_start(request):
    """Django view: single entry point of the train/poll protocol used by the C# client.

    The request carries a JSON object, either as the raw POST body or (for GET)
    in the ``data`` query parameter. ``Event == 'TRAIN'`` starts training on a
    background thread and answers immediately with ``{"return": 0}``; any other
    event is dispatched synchronously through ``Manager`` and its result is
    returned JSON-encoded.

    Malformed requests (GET without ``data``, body that is not valid JSON,
    payload without an ``Event`` field) are answered with a JSON error body and
    HTTP status 400 instead of an unhandled exception (HTTP 500).
    Methods other than GET/POST get the plain-text answer ``'access deny'``.
    """
    if request.method == 'POST':
        raw = request.body
    elif request.method == 'GET':
        raw = request.GET.get('data')
    else:
        return HttpResponse('access deny')
    try:
        rev = json.loads(raw)
    except (TypeError, ValueError):
        # TypeError: GET without a ``data`` parameter (raw is None).
        # ValueError (incl. JSONDecodeError / UnicodeDecodeError): not valid JSON.
        return _bad_request('request data is not valid JSON')
    if not isinstance(rev, dict) or 'Event' not in rev:
        return _bad_request('missing "Event" field')
    event = rev['Event']
    if event == 'TRAIN':
        # Training is long-running: answer the HTTP request right away and let
        # the client poll for progress with CHECK events.
        threading.Thread(target=Manager, args=[rev, event]).start()
        res = {'return': 0}
    else:
        res = Manager(rev, event)
    return HttpResponse(json.dumps(res))


def _bad_request(message):
    """Build the HTTP 400 JSON response ``train_start`` returns for malformed input."""
    # NOTE: the C# client's HttpWebRequest raises on a 400 status, so these
    # requests end up on the client's error path rather than being parsed as data.
    return HttpResponse(json.dumps({'return': -1, 'error': message}), status=400)
任务分配
from AICore.MainTest.Train import Train
from AICore.MainTest.Msg import GetMsg
def Manager(rec, event):
    """Dispatch a request payload to the handler registered for ``event``.

    'TRAIN' is handled by ``Train(rec)``, 'CHECK' by ``GetMsg(rec)``; any other
    event raises KeyError. Returns whatever the selected handler returns.
    """
    handler = {
        'TRAIN': Train,
        "CHECK": GetMsg,
    }[event]
    return handler(rec)
调用训练
def Train(rec):
    """Configure the training run from a client request, then train the chosen model.

    Runs synchronously (``train_start`` launches it on a background thread for
    the TRAIN event). The request values are stored in module-level settings
    (``IMG_W``, ``IMG_H``, ``ModelName``, ``BATCH_SIZE``, ``Echos``,
    ``ModelType``); ``server_dir``, ``train_dir`` and ``labels`` are module-level
    settings that are only read here.

    Args:
        rec: decoded request payload. When non-empty it must provide
            ``rec['project']`` with ``iID``, ``fWidth``, ``fHeighth`` and
            ``strPN``, and ``rec['Mode']`` with ``iBatchSize``, ``iEchos`` and
            ``strName``. When empty (or None) the current module-level settings
            are reused and the samples are read from ``train_dir + '4'``.

    Raises:
        ValueError: if the model type is neither 'SEQModel' nor 'VGG19'. This is
            checked before any sample is loaded.
        KeyError: if a non-empty ``rec`` lacks one of the keys listed above.
    """
    global IMG_W, IMG_H, BATCH_SIZE, ModelName, Echos, ModelType
    if rec:
        # Read both sections first so a missing section fails before any
        # module-level setting is overwritten.
        project = rec['project']
        mode = rec['Mode']
        Dir = server_dir + str(project['iID'])
        IMG_W = project['fWidth']
        IMG_H = project['fHeighth']  # key spelling is defined by the C# client
        ModelName = project['strPN']
        BATCH_SIZE = mode['iBatchSize']
        Echos = mode['iEchos']
        ModelType = mode['strName']
    else:
        Dir = train_dir + '4'
    # Fail fast: previously an unknown type loaded all samples and then silently
    # did nothing, so the client's progress poll never saw a result.
    if ModelType not in ('SEQModel', 'VGG19'):
        raise ValueError('Unsupported model type: %r' % (ModelType,))
    train, train_label = get_file(Dir, labels)
    # Split into training / validation data and labels.
    X_train, X_val, Y_train, Y_val = get_batch(train, train_label, IMG_W, IMG_H)
    # Conditional expression: only the selected trainer name is looked up.
    trainer = seq_model if ModelType == 'SEQModel' else VGG19Model
    trainer(X_train, X_val, Y_train, Y_val, IMG_W, IMG_H, BATCH_SIZE, Dir, ModelName, Echos)
关键代码,以顺序(Sequential)模型为例。
import tensorflow.compat.v1 as tf1
import numpy as np
import matplotlib.pyplot as plt
# Directory (Windows path) where the intermediate Keras .h5 files are written.
# Every backslash is escaped: a bare "\B" would be an invalid escape sequence
# (DeprecationWarning, SyntaxWarning on Python 3.12+).
ModelPath = "D:\\AI_Vision\\AIServer\\AICore\\BaseModel\\"
# Progress text: appended to by the training code and read/cleared by the
# status poller; set to 'Finish' when a training run completes.
Msg = ''
def SEQModel_Msg(MsgEx):
    """Return pending progress text, clearing it once the caller has seen it.

    Args:
        MsgEx: the text the caller received on its previous poll.

    Returns:
        ``''`` when ``MsgEx`` equals the current ``Msg`` (which is reset to
        ``''``); otherwise the current ``Msg``, left unchanged.
    """
    global Msg
    if MsgEx == Msg:
        # The caller already has this text: acknowledge it by clearing the buffer.
        Msg = ''
        return ''
    return Msg
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze the graph of a TF1 session into a GraphDef with variables as constants.

    Args:
        session: TF1-style session whose graph holds the trained model.
        keep_var_names: optional names of variables that should NOT be converted
            to constants (passed as ``variable_names_blacklist``). None converts
            every variable.
        output_names: names of the output ops; only the subgraph feeding them is
            kept, and these nodes are protected from training-node removal.
        clear_devices: when True, clear every node's device assignment before
            freezing.

    Returns:
        The frozen GraphDef with training-only nodes removed.
    """
    graph = session.graph
    with graph.as_default():
        output_names = output_names or []
        print("output_names", output_names)
        input_graph_def = graph.as_graph_def()
        print("len node1", len(input_graph_def.node))
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf1.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names,
            variable_names_blacklist=keep_var_names)
        # Remove nodes that are irrelevant for inference; protect the requested
        # outputs so an output that is e.g. an Identity op is not stripped.
        outgraph = tf1.graph_util.remove_training_nodes(
            frozen_graph, protected_nodes=output_names)
        print("##################################################################")
        for node in outgraph.node:
            print('node:', node.name)
        print("len node1", len(outgraph.node))
        return outgraph
def showCurve(dir, history):
    """Save the accuracy and loss curves of a Keras training run as PNG files.

    Writes ``accuracy.png`` and ``loss.png`` into ``dir + "\\TrainModel"``;
    that folder must already exist (``savefig`` does not create it).

    Args:
        dir: project directory.
        history: the ``History`` object returned by ``model.fit_generator``.
    """
    hist = history.history
    # Depending on the Keras version the accuracy metric is logged as
    # 'accuracy' or as 'acc'.
    acc_key = 'accuracy' if 'accuracy' in hist else 'acc'
    _save_curve(dir + "\\TrainModel" + "\\accuracy.png",
                hist[acc_key], hist['val_' + acc_key],
                ('training acc', 'val acc'), 'model accuracy', 'accuracy', 'lower right')
    _save_curve(dir + "\\TrainModel" + "\\loss.png",
                hist['loss'], hist['val_loss'],
                ('training loss', 'val loss'), 'model loss', 'loss', 'upper right')


def _save_curve(path, train_values, val_values, series_labels, title, ylabel, legend_loc):
    """Plot one training/validation series pair against epochs and save it to ``path``."""
    # NOTE(review): this runs on the training thread started by the Django
    # view; interactive pyplot backends (e.g. TkAgg) can fail off the main
    # thread -- confirm the server uses a non-GUI backend such as Agg.
    fig = plt.figure()
    try:
        ax = fig.add_subplot(1, 1, 1)
        ax.plot(train_values, label=series_labels[0])
        ax.plot(val_values, label=series_labels[1])
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('epoch')
        ax.legend(loc=legend_loc)
        fig.savefig(path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each training run leaks two figures in the server.
        plt.close(fig)
def seq_model(X_train, X_val, Y_train, Y_val, img_width, img_height, batch_size, dir, modelName, echos):
    """Train a small sequential CNN binary classifier and export it as a frozen graph.

    Progress text (model summary, metric histories, input/output tensor names)
    is appended to the module-level ``Msg`` for the status poller; when the run
    is complete ``Msg`` is set to ``'Finish'``.

    Files written:
        * ``ModelPath + 'model_wieghts.h5'`` and ``ModelPath + 'model_keras.h5'``
        * ``modelName + ".pb"`` (frozen binary graph) in ``dir + "\\TrainModel"``
        * ``accuracy.png`` / ``loss.png`` in the same folder (via ``showCurve``)

    Args:
        X_train, X_val: training / validation images; each collection is passed
            through ``np.array`` before being fed to ``ImageDataGenerator.flow``.
        Y_train, Y_val: matching labels (binary, given the sigmoid +
            binary_crossentropy head).
        img_width, img_height: image size used for the model's input layer.
        batch_size: mini-batch size for training and validation.
        dir: project directory; exports go to its ``TrainModel`` sub-folder.
        modelName: base file name of the exported ``.pb`` graph.
        echos: number of training epochs.
    """
    global Msg
    model = _build_seq_model(img_width, img_height)
    summary_lines = []
    model.summary(print_fn=summary_lines.append)
    Msg += "\n".join(summary_lines) + "\r\n"

    train_datagen = tf1.keras.preprocessing.image.ImageDataGenerator(
        rescale=1. / 255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
    # Validation images are only rescaled: random shear/zoom/flip here would
    # make the val_* metrics noisy and unrepresentative of inference inputs.
    val_datagen = tf1.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)
    train_generator = train_datagen.flow(np.array(X_train), Y_train, batch_size=batch_size)
    validation_generator = val_datagen.flow(np.array(X_val), Y_val, batch_size=batch_size)
    history = model.fit_generator(
        train_generator,
        # Floor division yields 0 steps when a split has fewer samples than
        # batch_size, so the epoch would process no data; run at least one step.
        steps_per_epoch=max(1, len(X_train) // batch_size),
        epochs=echos,
        validation_data=validation_generator,
        validation_steps=max(1, len(X_val) // batch_size),
    )
    Msg += _history_report(history)

    model.save_weights(ModelPath + 'model_wieghts.h5')
    model.save(ModelPath + 'model_keras.h5')
    Msg += 'input is :' + model.input.name + "\r\n" + 'output is:' + model.output.name + "\r\n"

    _export_frozen_graph(dir, modelName)
    showCurve(dir, history)
    Msg = 'Finish'


def _build_seq_model(img_width, img_height):
    """Build and compile the 3x(Conv-ReLU-MaxPool) + Dense(64) + sigmoid classifier."""
    # Layers are created in a fixed order; their auto-generated names determine
    # the input/output tensor names reported to the client and used in the .pb.
    layers = tf1.keras.layers
    model = tf1.keras.models.Sequential()
    # NOTE(review): Keras expects input_shape=(rows, cols, channels), i.e.
    # (height, width, 3); this passes (width, height, 3). Only equivalent for
    # square images -- confirm against how get_batch lays out the arrays.
    model.add(layers.Conv2D(32, (3, 3), input_shape=(img_width, img_height, 3)))
    model.add(layers.Activation('relu'))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(layers.Conv2D(32, (3, 3)))
    model.add(layers.Activation('relu'))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(layers.Conv2D(64, (3, 3)))
    model.add(layers.Activation('relu'))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(layers.Flatten())
    model.add(layers.Dense(64))
    model.add(layers.Activation('relu'))
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(1))
    model.add(layers.Activation('sigmoid'))
    model.compile(loss='binary_crossentropy',
                  optimizer='rmsprop',
                  metrics=['accuracy'])
    return model


def _history_report(history):
    """Format loss/accuracy histories as the CRLF-terminated lines appended to ``Msg``."""
    hist = history.history
    # Depending on the Keras version the metric key is 'accuracy' or 'acc'.
    acc_key = 'accuracy' if 'accuracy' in hist else 'acc'
    return ("loss" + str(hist['loss']) + "\r\n"
            + "val_loss" + str(hist['val_loss']) + "\r\n"
            + "accuracy" + str(hist[acc_key]) + "\r\n"
            + "val_accuracy" + str(hist['val_' + acc_key]) + "\r\n")


def _export_frozen_graph(project_dir, model_name):
    """Reload the saved .h5 model in TF1 graph mode and write it as a frozen binary .pb."""
    tf1.reset_default_graph()
    # Learning phase 0 = inference; it must be set before the model is loaded.
    tf1.keras.backend.set_learning_phase(0)
    tf1.disable_v2_behavior()  # disable TensorFlow 2.0 behavior for session-based freezing
    # NOTE(review): set_learning_phase and disable_v2_behavior are process-wide
    # switches and stay in effect for any later training run in the same server
    # process -- confirm that retraining without a restart behaves as intended.
    network = tf1.keras.models.load_model(ModelPath + 'model_keras.h5')
    frozen_graph = freeze_session(tf1.keras.backend.get_session(),
                                  output_names=[out.op.name for out in network.outputs])
    tf1.train.write_graph(frozen_graph, project_dir + "\\TrainModel", model_name + ".pb", as_text=False)
创建一个PostMan类,用于向后台发送请求并获取返回信息。
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Net;
using System.IO;
using Newtonsoft.Json;
namespace LT.Device
{
    /// <summary>
    /// Minimal HTTP client used to talk to the Django training back end.
    /// Both methods return the response body on success and the exception
    /// message on failure (callers cannot tell the two apart by type).
    /// </summary>
    class PostMan
    {
        // Request timeout in milliseconds.
        private const int TimeoutMs = 5000;

        /// <summary>
        /// POSTs <paramref name="paramData"/> (encoded with <paramref name="dataEncode"/>)
        /// as the request body and returns the response body.
        /// </summary>
        public string PostWebRequest(string postUrl, string paramData, Encoding dataEncode)
        {
            try
            {
                byte[] byteArray = dataEncode.GetBytes(paramData);
                HttpWebRequest webReq = (HttpWebRequest)WebRequest.Create(new Uri(postUrl));
                webReq.Method = "POST";
                webReq.ContentType = "application/x-www-form-urlencoded";
                webReq.Timeout = TimeoutMs;
                webReq.ContentLength = byteArray.Length;
                // using: the request stream is closed even if Write throws.
                using (Stream newStream = webReq.GetRequestStream())
                {
                    newStream.Write(byteArray, 0, byteArray.Length);
                }
                return ReadResponse(webReq);
            }
            catch (Exception ex)
            {
                return ex.Message;
            }
        }

        /// <summary>
        /// Sends a GET request and returns the response body.
        /// A GET cannot carry a request body (HttpWebRequest throws
        /// ProtocolViolationException), so <paramref name="paramData"/> is appended
        /// to the URL as the query string. It must already be URL-encoded, e.g.
        /// "data=" + Uri.EscapeDataString(json) for the Django train view.
        /// <paramref name="dataEncode"/> is unused and kept for signature compatibility.
        /// </summary>
        public string GetWebRequest(string postUrl, string paramData, Encoding dataEncode)
        {
            try
            {
                string url = postUrl;
                if (!string.IsNullOrEmpty(paramData))
                    url += (postUrl.Contains("?") ? "&" : "?") + paramData;
                HttpWebRequest webReq = (HttpWebRequest)WebRequest.Create(new Uri(url));
                webReq.Method = "GET";
                webReq.Timeout = TimeoutMs;
                return ReadResponse(webReq);
            }
            catch (Exception ex)
            {
                return ex.Message;
            }
        }

        /// <summary>Reads the whole response body, releasing the response and reader.</summary>
        private static string ReadResponse(HttpWebRequest webReq)
        {
            // The Django back end encodes responses as UTF-8 by default; Encoding.Default
            // is the ANSI code page (e.g. GBK on Chinese Windows) and would garble text.
            using (HttpWebResponse response = (HttpWebResponse)webReq.GetResponse())
            using (StreamReader sr = new StreamReader(response.GetResponseStream(), Encoding.UTF8))
            {
                return sr.ReadToEnd();
            }
        }
    }
}
通过一个线程调用上面编写的类来发送请求:
#region Method StepsCode
/// <summary>
/// One tick of the training workflow state machine; numberActiveStep selects the step.
/// 0: set RootDir for the current project and clear the message view.
/// 1: call CopyFolder(RootDir, FTPPath); continue when it returns 1, else go to step 9.
/// 2: POST a TRAIN event with project, labels and mode to the back end.
/// 3: POST a CHECK event and show any progress text; "Finish" goes to step 10.
/// 4-7: placeholders; 8: back to step 3 to poll again.
/// 9 / 10: set Enable = false (9 is reached on errors).
/// </summary>
protected override void StepsCode(ref DataBaseStruct dataDirector, ref ChainsElement chains)
{
    // The same endpoint serves both the TRAIN and the CHECK events.
    const string trainUrl = "http://127.0.0.1:8000/ai/train/";
    PostMan client = new PostMan();
    PostParamStruct request = new PostParamStruct();
    switch (numberActiveStep)
    {
        case 0: // initial
            RootDir = BasePath + dataDirector.CurrProject.iID.ToString();
            viewModeLocal.UpdateMessage("", true);
            GoToNextStep(0.2);
            break;

        case 1: // check model exists
            if (1 == CopyFolder(RootDir, Properties.Settings.Default.FTPPath))
                GoToNextStep(0.1);
            else
                GoToStep(9, 0.1);
            break;

        case 2: // send the TRAIN request
            request.Event = "TRAIN";
            request.project = dataDirector.CurrProject;
            request.Labels = dataDirector.LabelCatetorys;
            request.Mode = dataDirector.CurrMode;
            strPram = JsonConvert.SerializeObject(request);
            try
            {
                string reply = client.PostWebRequest(trainUrl, strPram, Encoding.UTF8);
                // Parsed only as a check: PostMan returns the exception text on
                // failure, which is not JSON and makes Parse throw.
                JObject.Parse(reply);
                viewModeLocal.UpdateMessage(strPram);
                GoToNextStep(0.5);
            }
            catch (Exception e)
            {
                viewModeLocal.UpdateMessage(e.Message.ToString());
                GoToStep(9, 0.5);
            }
            break;

        case 3: // poll training progress
            request.Event = "CHECK";
            strPram = JsonConvert.SerializeObject(request);
            try
            {
                string reply = client.PostWebRequest(trainUrl, strPram, Encoding.UTF8);
                JObject answer = JObject.Parse(reply);
                string status = answer["return"].ToString();
                if (status == "Finish")
                {
                    GoToStep(10, 0.5);
                }
                else
                {
                    if (status != "")
                        viewModeLocal.UpdateMessage(status);
                    GoToNextStep(0.5);
                }
            }
            catch (Exception e)
            {
                viewModeLocal.UpdateMessage(e.Message.ToString());
                GoToStep(9, 0.5);
            }
            break;

        case 4: // ...
        case 5: // ...
            GoToNextStep(0.1);
            break;

        case 6: // ...
        case 7: // check message
            GoToNextStep(0.5);
            break;

        case 8: // ... poll again
            GoToStep(3, 0.5);
            break;

        case 9:  // ... (error exit)
        case 10: // finished
            this.Enable = false;
            break;
    }
}
#endregion