记录一下使用 TensorFlow Serving 部署图像分割模型(U-Net)的过程
changeH5tosavedModel.py(将训练得到的 .h5 权重导出为 TF Serving 使用的 SavedModel 格式):
import tensorflow as tf
from nets.unet import Unet as unet
if __name__ == '__main__':
    # Trained weights to convert, and the export location. TF Serving expects
    # numbered version sub-directories under a model's base directory, so
    # "test/1" is version 1 of the model whose base directory is "test".
    weights_file = 'EP100-loss0.196-valoss0.284.h5'
    export_dir = "test/1"

    # Rebuild the network, then load the trained weights into it.
    # NOTE(review): the arguments presumably are (input shape, number of
    # classes, backbone) and must match the training configuration -- confirm
    # against nets/unet.py.
    model = unet((512, 512, 3), 2, 'vgg')
    model.load_weights(weights_file)

    # Write the SavedModel that the serving container will load.
    tf.saved_model.save(model, export_dir)
# CPU serving on a Windows host (PowerShell; in a POSIX shell the backslashes
# in the path would need quoting). Publishes the REST port 8501 and bind-mounts
# the export base directory "test" (the folder that contains the version
# sub-directory "1") to /models/unetV1; MODEL_NAME must equal that last path
# component.
docker run -p 8501:8501 --mount type=bind,source=E:\projectFiles\standard\unetV1/test,target=/models/unetV1 -e MODEL_NAME=unetV1 -t tensorflow/serving
GPU 版本(目前只在 Linux 下测试过,因为 Win10 上似乎无法安装 nvidia-docker):
首先安装 nvidia-docker(NVIDIA Container Toolkit),然后用下面的命令确认容器内能访问 GPU(能正常打印出 nvidia-smi 的显卡信息即可):
# Sanity check: if the GPU container runtime works, this prints the host's
# nvidia-smi table from inside a throw-away (--rm) CUDA container.
# NOTE(review): older un-suffixed nvidia/cuda tags such as 11.0-base may have
# been removed from Docker Hub; if the pull fails, substitute a currently
# published nvidia/cuda tag.
docker run --rm --gpus all nvidia/cuda:11.0-base nvidia-smi
然后拉取tensorflow-serving gpu镜像:
# GPU build of the TensorFlow Serving image (the CPU command uses tensorflow/serving).
docker pull tensorflow/serving:latest-gpu
最后开启模型服务
# Same as the CPU command, plus --gpus all and the -gpu image; the source
# is the Linux path of the export base directory "test" (contains version "1").
docker run --gpus all -p 8501:8501 --mount type=bind,source=/home/hbli/pythonFiles/unetV1/test,target=/models/unetV1 -e MODEL_NAME=unetV1 -t tensorflow/serving:latest-gpu
MODEL_NAME 可以自己定,但 target 路径最后一级的名字(这里是 unetV1)必须和 MODEL_NAME 一致;source 是被部署模型所在的文件夹,即导出时的 test 目录(其下包含版本号子目录 1,TF Serving 会从 /models/unetV1/<版本号>/ 加载模型)。其余参数照抄即可。
httpClient.py(通过 REST API 请求部署好的模型,并把分割结果叠加显示在原图上):
""" 图像分割的serving """
import cv2
import numpy as np
import requests
import json
import time
from PIL import Image
import colorsys
import matplotlib.pyplot as plt
import os
def resize_image(image, size):
    """Letterbox *image* into a canvas of *size* without distorting it.

    The image is scaled (bicubic) by the largest factor that still fits inside
    *size*, then pasted centred on a grey (128, 128, 128) RGB canvas.

    Args:
        image: PIL image to resize.
        size: (width, height) of the output canvas.

    Returns:
        (canvas, nw, nh): the letterboxed RGB canvas and the width/height the
        source image was scaled to inside it.
    """
    src_w, src_h = image.size
    dst_w, dst_h = size
    ratio = min(dst_w / src_w, dst_h / src_h)
    scaled_w, scaled_h = int(src_w * ratio), int(src_h * ratio)

    canvas = Image.new('RGB', size, (128, 128, 128))
    # Centre the scaled image; the remaining border stays grey.
    offset = ((dst_w - scaled_w) // 2, (dst_h - scaled_h) // 2)
    canvas.paste(image.resize((scaled_w, scaled_h), Image.BICUBIC), offset)
    return canvas, scaled_w, scaled_h
def preprocess_input(image):
    """Rescale pixel values from [0, 255] to [-1, 1] (i.e. x / 127.5 - 1).

    Works element-wise on numpy arrays as well as on plain numbers.
    """
    halved = image / 127.5
    return halved - 1
# Network input size as (height, width); must match the size used in training.
input_shape = (512, 512)
# Number of output classes: object classes + 1 (presumably the background class).
num_classes = 2
def preProcessing(filepath):
    """Load an image and build the batch tensor expected by the served model.

    Args:
        filepath: path of the image to segment.

    Returns:
        tuple: ``(old_img, (h, w), (nw, nh), image_data)`` where
        ``old_img`` is the original image converted to RGB, ``(h, w)`` its
        height/width, ``(nw, nh)`` the size of the scaled image inside the
        letterboxed canvas, and ``image_data`` a float32 array of shape
        ``(1, input_shape[0], input_shape[1], 3)`` scaled to [-1, 1].
    """
    # Decode once, with PIL only. The image used to be decoded a second time
    # with cv2.imread just to read h, w. cv2 applies EXIF orientation while
    # PIL does not, so for rotated JPEGs the sizes disagreed and the final
    # Image.blend failed. cv2.imread also returns None (instead of raising)
    # for paths it cannot read.
    with Image.open(filepath) as img:
        # convert() forces a full decode, so the file can be closed here. It
        # also normalises grayscale/RGBA/palette inputs to RGB, the mode that
        # blending with the RGB mask requires.
        old_img = img.convert('RGB')
    w, h = old_img.size  # PIL reports (width, height)

    # Letterbox to the network input size, then scale to [-1, 1] and add a
    # batch axis.
    image_data, nw, nh = resize_image(old_img, (input_shape[1], input_shape[0]))
    image_data = np.expand_dims(preprocess_input(np.array(image_data, np.float32)), 0)
    return old_img, (h, w), (nw, nh), image_data
def mainProcess(batch=None, url='http://localhost:8501/v1/models/unetV1:predict', timeout=60):
    """Send a preprocessed batch to TensorFlow Serving and return the first prediction.

    Args:
        batch: numpy array to send as the model input. Defaults to the
            module-level ``image_data`` set by the main loop (the original
            zero-argument behaviour).
        url: TF Serving REST predict endpoint
            (``/v1/models/<MODEL_NAME>:predict``).
        timeout: seconds to wait for the server. ``requests`` never times out
            unless a timeout is given, so a dead server would hang the client.

    Returns:
        numpy.ndarray: ``outputs[0]`` of the response, i.e. the prediction for
        the first (only) image of the batch.

    Raises:
        RuntimeError: the server answered with a non-2xx status (e.g. wrong
            model name or input shape); the server's error body is included
            instead of the opaque ``KeyError: 'outputs'`` seen before.
        requests.RequestException: connection failure or timeout.
    """
    if batch is None:
        batch = image_data
    start = time.time()

    # TF Serving "columnar" request format: {"inputs": <tensor as nested lists>}.
    payload = json.dumps({'inputs': batch.tolist()})
    response = requests.post(url, data=payload, timeout=timeout)
    if not response.ok:
        raise RuntimeError(f'TF Serving returned HTTP {response.status_code}: {response.text}')
    result = response.json()
    output_array = np.array(result['outputs'][0])  # nested lists -> ndarray

    print(f'花费时间:{time.time()-start:.2f}s')
    return output_array
def postProcessing():
    """Turn the raw network output into a colour overlay on the original image.

    Reads module-level globals set by the main loop:
      * ``output_array``: per-pixel class scores indexed as (row, col, class),
        assumed to cover the letterboxed ``input_shape`` canvas;
      * ``old_img``: the original image;
      * ``nw``, ``nh``: size of the scaled image inside that canvas;
    plus the ``input_shape`` and ``num_classes`` constants.

    Returns:
        PIL.Image: the original image blended 30/70 with the colourised mask.
    """
    # Crop away the grey letterbox padding added by resize_image() before
    # scaling back. Otherwise the padding is stretched into the mask and, for
    # non-square inputs, the overlay is shifted/squashed relative to the photo.
    # The scale factors make this work even if the output resolution differs
    # from input_shape (they are 1.0 when both match).
    out_h, out_w = output_array.shape[:2]
    sy = out_h / input_shape[0]
    sx = out_w / input_shape[1]
    top = int((input_shape[0] - nh) // 2 * sy)
    left = int((input_shape[1] - nw) // 2 * sx)
    pr = np.ascontiguousarray(output_array[top:top + int(nh * sy), left:left + int(nw * sx)])

    # Resize to the exact size of the image we blend with. PIL's .size is
    # (width, height), which is also cv2.resize's dsize order.
    pr = cv2.resize(pr, old_img.size, interpolation=cv2.INTER_LINEAR)
    pr = pr.argmax(axis=-1)  # class index of every pixel

    if num_classes <= 21:
        # PASCAL VOC palette.
        colors = [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128),
                  (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),
                  (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
                  (128, 64, 128)]
    else:
        # Evenly spaced hues, one per class.
        hsv_tuples = [(x / num_classes, 1., 1.) for x in range(num_classes)]
        colors = [colorsys.hsv_to_rgb(*x) for x in hsv_tuples]
        colors = [(int(r * 255), int(g * 255), int(b * 255)) for r, g, b in colors]

    # Paint every pixel with its class colour; pixels of classes outside
    # range(num_classes) stay black, as before.
    seg_img = np.zeros((pr.shape[0], pr.shape[1], 3), dtype=np.uint8)
    for c in range(num_classes):
        seg_img[pr == c] = colors[c]

    resultImage = Image.fromarray(seg_img)
    # Image.blend needs both images in the same mode and size.
    return Image.blend(old_img.convert('RGB'), resultImage, 0.7)
def saveAndShow(image):
    """Save the overlay under ``servingOut/`` and display it with matplotlib.

    The output file name is the input file's base name with its extension
    replaced by ``httpResult.jpg`` (e.g. ``a.jpeg`` -> ``ahttpResult.jpg``).
    Reads the module-level ``filepath`` set by the main loop.

    Args:
        image: PIL image to save and show.
    """
    # splitext handles extensions of any length. The old [:-4] slice assumed
    # a 3-letter extension: "a.jpeg" became "a.httpResult.jpg", and names
    # without an extension lost characters.
    stem = os.path.splitext(os.path.basename(filepath))[0]
    savename = stem + "httpResult.jpg"
    savePath = 'servingOut/'
    # makedirs(exist_ok=True) replaces the exists()/mkdir() check-then-act pair.
    os.makedirs(savePath, exist_ok=True)
    image.save(os.path.join(savePath, savename))

    plt.title(os.path.basename(filepath))
    plt.imshow(image)
    plt.show()
if __name__ == '__main__':
    # Interactive loop: segment one image per prompt until the user types "c".
    # The helper functions read the globals assigned here (filepath, old_img,
    # h, w, nw, nh, image_data, output_array), so these names must not change.
    while True:
        try:
            filepath = input('请输入待预测图像路径(输入c退出): ')
        except EOFError:
            # stdin was closed (e.g. piped input ran out). input() raises on
            # every call from now on. Previously this was swallowed by the
            # generic handler below, and the loop re-prompted forever.
            break
        # Tolerate surrounding whitespace and the quotes added by
        # Windows "Copy as path".
        filepath = filepath.strip().strip('"\'')
        if filepath == 'c':
            break
        try:
            old_img, (h, w), (nw, nh), image_data = preProcessing(filepath=filepath)
            output_array = mainProcess()
            image = postProcessing()
            saveAndShow(image)
        except Exception as e:
            # Top-level boundary of an interactive tool: report the failure
            # (including its type) and keep prompting.
            print(f'{type(e).__name__}: {e}')