tensorflow框架下,多进程model.predict(x)无响应/暂停/无输出

尝试多进程在tensorflow下运行函数,函数功能中包括model.predict(x),但是在windows环境中运行正常,linux中无法运行,会出现程序无响应,调试后发现在子进程函数中model.predict(x)执行时程序出现问题。

参考代码如下,注意到此时model声明在__main__父进程中,并作为参数传给子进程。另一种情况是在子进程中使用全局变量model。或者可能存在从其他模块中引入的有关model变量。

from keras.preprocessing import image
from keras.models import load_model


import numpy as np
import os,gc,multiprocessing

def load_image(img_path):
    img = image.load_img(img_path, target_size=(28, 28), color_mode="grayscale")
    input_img_data = image.img_to_array(img)
    input_img_data = input_img_data.reshape(1, 28, 28, 1)

    input_img_data = input_img_data.astype('float32')
    input_img_data /= 255
    
    return input_img_data


def run(pict_path,model):
    #global model
    preprocessed_input = load_image(pict_path)
    
    predictions=model.predict(preprocessed_input)
    top_1 = (np.argmax(predictions),np.max(predictions))
    print('Predicted class:')
    print('%s with probability %.2f' % (top_1[0], top_1[1]))


if __name__=='__main__':
    picts=os.listdir("./mnist")
    model = load_model('Model3.h5')# 注意此时model声明在子进程外部(父进程声明后作为参数传入)
    for whichone,p in enumerate(picts):#循环读取图片预测
        pict_path="./mnist/"+picts[0]
        run_p=multiprocessing.Process(target=run,args=[pict_path,model])
        run_p.start()
        run_p.join()
        gc.collect()

该代码在linux(ubuntu)中应该无法运行,但是在windows上应该能运行。

解决方法如下

将model声明写在子进程的函数中,才能正常运行。

代码如下,将model声明在子进程内部。现在只测试过load_model()函数,如果直接调用变量尚未尝试。

# -*- coding: utf-8 -*-  
from keras.preprocessing import image
from keras.models import load_model


import numpy as np
import os,gc

def load_image(img_path):
    img = image.load_img(img_path, target_size=(28, 28), color_mode="grayscale")
    input_img_data = image.img_to_array(img)
    input_img_data = input_img_data.reshape(1, 28, 28, 1)

    input_img_data = input_img_data.astype('float32')
    input_img_data /= 255
    # input_img_data = preprocess_input(input_img_data)  # final input shape = (1,224,224,3)
    return input_img_data

import multiprocessing,gc

def run(pict_path):
    model = load_model('Model3.h5')#model声明在子进程内部
    
    preprocessed_input = load_image(pict_path)#"./minst/729_4.png")
    
    predictions=model.predict(preprocessed_input)
    top_1 = (np.argmax(predictions),np.max(predictions))
    print('Predicted class:')
    print('%s with probability %.2f' % (top_1[0], top_1[1]))


if __name__=='__main__':
    picts=os.listdir("./mnist")
    
    for whichone,p in enumerate(picts):#for p in x4test:
        #print(whichone)
        pict_path="./mnist/"+picts[0]
        run_p=multiprocessing.Process(target=run,args=[pict_path])
        run_p.start()
        run_p.join()
        gc.collect()

当然,主要问题的解决思路参考https://stackoverflow.com/questions/70496446/parallelizing-keras-model-predict-using-multiprocessing

BTW,讲述干货和知识的收费能理解,怎么有人的连解决bug的文章怎么还要收费的啊?离谱奥

成功解决算法模型在预测的时候model.predict(X_test)其预测功能戛然而止且代码无bug的无提示的无法继续向下运行代码而在当前直接退出

你可能感兴趣的:(深度学习,人工智能,python,tensorflow,keras)