华为云AI应用在线服务部署custom_service写法

目录

  • 前言
  • 例程
  • 分析

前言

在参加华为云AI大赛的过程中,经常遇到的就是要编写custom_service.py, 今天来总结一下编写这个文件的一些注意事项。下面给出了一个例程进行分析。

例程

# !/usr/bin/python
# -*- coding: UTF-8 -*-

import json
try:
    from model_service.pytorch_model_service import PTServingBaseService
except:
    PTServingBaseService = object
    
import torch
import os
import numpy as np
from model import MyModel
import pathlib


class RadioMapService(PTServingBaseService):
    def __init__(self, model_name, model_path):
        print('--------------------init--------------------')
        self.model_name = model_name
        self.model_path = model_path
        print(f"model_name:{model_name}")
        print(f"model_path:{model_path}")
		
		# 识别设备
		device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        
        # 创建模型
        self.model = MyModel()
        
        # 加载模型参数
        params = torch.load(model_path, map_location=device)
		self.model.load_state_dict(params)
		
		# 设置推理状态
		self.model.eval()
        return 

    def _preprocess(self, data):
    	# 预处理步骤,一般比赛会提供这个函数的写法
        print('--------------------preprocess--------------------')
        preprocessed_data = {}
        for file_name, file_content in data['all_data'].items():
        	# 如果是图像文件,那么file_content一般是可读对象,可以用PIL直接读取
            print(f"file_name={file_name}, file_content={file_content}")
            self.file_name = file_name
            data_record = []
            lines = file_content.read().decode()
            lines = lines.split('\n')
            for line in lines:  # read all instance in the .txt
                if len(line) > 1:
                    data_record.append(json.loads(line))
            preprocessed_data[file_name] = data_record
        return preprocessed_data

    def _inference(self, data):
        # print('--------------------inference----------------------')
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        data_tmp = data[self.file_name]
        data_tmp = data_tmp[0]
        X = torch.Tensor(data_tmp['pos'])
        with torch.no_grad():
        	result = self.model(X)
        # 处理输出
        result = result.cpu().numpy()
        results_fin = {'pred': result[0], 'bbox': result[1]}

        ##------------------------------------END-----------------------------------------##

        print(f'result_fin={results_fin}')
        return results_fin

    def _postprocess(self, data):
        print('--------------------postprocess--------------------')
        
        return data


if __name__ == "__main__":
    pass 
    # test the _inference api

分析

在上面的例程中,其实最需要注意的就是初始化参数 model_path。这个参数非常重要。
在obs桶中存储一个文件名为model.pth的文件,然后应用会自动检测,把xxx/xxx/mode.pth 当作model_path 传入 init(model_name, model_path)。
所以无论你的模型的名称是不是model.pth, 你都需要这么一个文件来保证model_path可以传入service对象。如果你还需要加载其他文件,那么可以根据model_path对应的路径修改。model_name 参数几乎用不到,所以不用关注。

你可能感兴趣的:(AI算法常用技术,人工智能,华为云,python)