MINIST手写字识别Java后端调用python模型

目录

调用过程:

Java服务端代码:

python端训练模型的代码:

python socket传输代码:


一般都会用python写机器学习深度学习模型,实际中用Java写后端多一点,python写后端还是少一点

那么就需要把Java和python串起来

最简单的办法就是通过socket传输,Java和python起两个进程互相通过IP+端口号的模式进行调用

调用过程:

MINIST手写字识别Java后端调用python模型_第1张图片

1.python训练模型

2.Java读取用户输入

3.python识别后将结果返回给Java后端

Java服务端代码:

import java.io.*;
import java.net.Socket;

public class PythonSocket {
    private static final String HOST = "192.168.31.216";
    private static final int PORT = 12345;

    public Object remoteCall(String content){
        // 访问服务进程的套接字
        Socket socket = null;
        try {
            // 初始化套接字,设置访问服务的主机和进程端口号,HOST是访问python进程的主机名称,可以是IP地址或者域名,PORT是python进程绑定的端口号
            socket = new Socket(HOST,PORT);
            // 获取输出流对象
            OutputStream os = socket.getOutputStream();
            PrintStream out = new PrintStream(os);

            // 告诉服务进程,内容发送完毕,可以开始处理
            System.out.println("======发送完毕=====");
            // 获取服务进程的输入流
            InputStream is = socket.getInputStream();
            BufferedReader br = new BufferedReader(new InputStreamReader(is,"utf-8"));
            String tmp = null;
            StringBuilder sb = new StringBuilder();
            // 读取内容
            while((tmp=br.readLine())!=null)
                sb.append(tmp).append('\n');
            // 解析结果
            System.out.println("获取的结果为: "+sb.toString());
            return sb.toString();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            try {if(socket!=null) socket.close();} catch (IOException e) {}
            System.out.println("远程接口调用结束.");
        }
        return null;
    }
}

python端训练模型的代码:

在这里只需要完成模型的训练和预测,将需要返回给Java的返回值封装成函数,一会在python socket中调用返回给Java就可以了

import paddle
import matplotlib.pyplot as plt
import paddle.nn.functional as F
from paddle.vision.transforms import Compose, Normalize
def train():
    transform = Compose([Normalize(mean=[127.5],
                                   std=[127.5],
                                   data_format='CHW')])
    # 使用transform对数据集做归一化
    print('download training data and load training data')
    train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
    test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
    print('load finished')

    train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]
    train_data0 = train_data0.reshape([28,28])
    plt.figure(figsize=(2,2))
    plt.imshow(train_data0, cmap=plt.cm.binary)
    print('train_data0 label is: ' + str(train_label_0))

    class LeNet(paddle.nn.Layer):
        def __init__(self):
            super(LeNet, self).__init__()
            self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
            self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2,  stride=2)
            self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
            self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
            self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
            self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
            self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)

        def forward(self, x):
            x = self.conv1(x)
            x = F.relu(x)
            x = self.max_pool1(x)
            x = F.relu(x)
            x = self.conv2(x)
            x = self.max_pool2(x)
            x = paddle.flatten(x, start_axis=1,stop_axis=-1)
            x = self.linear1(x)
            x = F.relu(x)
            x = self.linear2(x)
            x = F.relu(x)
            x = self.linear3(x)
            return x
    from paddle.metric import Accuracy
    model = paddle.Model(LeNet())   # 用Model封装模型
    optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

    # 配置模型
    model.prepare(
        optim,
        paddle.nn.CrossEntropyLoss(),
        Accuracy()
        )
    # 训练模型
    model.fit(train_dataset,
              epochs=1,
              batch_size=512,
              verbose=1
              )

    predict_acc = model.evaluate(test_dataset, batch_size=64, verbose=1)
    print("predict_acc: ",predict_acc['acc'])
    return predict_acc['acc']
    predict_label = model.predict(test_dataset, batch_size=64)
    print("predict_label :",predict_label[0])



if __name__ == '__main__':
    train()

python socket传输代码:

在这里面调用上面的train()函数就可以了

import socket
import threading
from mnist_train import train

def main():
    # 创建服务器套接字
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # 获取本地主机名称
    host = socket.gethostname()
    # 设置一个端口
    port = 12345
    # 将套接字与本地主机和端口绑定
    serversocket.bind((host, port))
    # 设置监听最大连接数
    serversocket.listen(5)
    # 获取本地服务器的连接信息
    myaddr = serversocket.getsockname()
    print("服务器地址:%s" % str(myaddr))
    # 循环等待接受客户端信息
    while True:
        # 获取一个客户端连接
        clientsocket, addr = serversocket.accept()
        print("连接地址:%s" % str(addr))
        try:
            t = ServerThreading(clientsocket)  # 为每一个请求开启一个处理线程
            t.start()
            pass
        except Exception as identifier:
            print(identifier)
            pass
        pass
    serversocket.close()
    pass


class ServerThreading(threading.Thread):
    # words = text2vec.load_lexicon()
    def __init__(self, clientsocket, recvsize=1024 * 1024, encoding="utf-8"):
        threading.Thread.__init__(self)
        self._socket = clientsocket
        self._recvsize = recvsize
        self._encoding = encoding
        pass

    def run(self):
        print("开启线程.....")
        try:
            # 接受数据
            msg = ''
            # while True:
            #     # 读取recvsize个字节
            #     rec = self._socket.recv(self._recvsize)
            #     # 解码
            #     msg += rec.decode(self._encoding)
            #     # 文本接受是否完毕,因为python socket不能自己判断接收数据是否完毕,
            #     # 所以需要自定义协议标志数据接受完毕
            #     if msg.strip().endswith('over'):
            #         msg = msg[:-4]
            #         break
            # 解析json格式的数据
            # re = json.loads(msg)
            # 调用神经网络模型处理请求
            res = train()
            print("res为: ",res)
            sendmsg = str(res)
            # 发送数据
            self._socket.send(("%s" % sendmsg).encode(self._encoding))
            pass
        except Exception as identifier:
            self._socket.send("500".encode(self._encoding))
            print(identifier)
            pass
        finally:
            self._socket.close()
        print("任务结束.....")

        pass

    def __del__(self):

        pass

if __name__ == '__main__':
    main()

你可能感兴趣的:(socket,java,java,python)