Building a local web service for BELLE with Tornado


This post wraps BELLE as a web backend service using the Tornado framework: a server script loads the (optionally GPTQ-quantized) BLOOM-based BELLE model and serves generation over HTTP, and a small test client exercises it.
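
The service contract is simple: POST a form field named `status` containing the prompt, and the response is a JSON object with the generated answer under the key `belle`. Assuming the server is already running, a minimal call looks like this (host and port are placeholders; the full test client appears at the end of the post):

import requests

# Hypothetical endpoint; adjust host/port to your deployment.
resp = requests.post('http://127.0.0.1:8081', data={'status': 'hello'})
print(resp.json())  # -> {'belle': '<generated reply>'}

The server script: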
import torch
import torch.nn as nn

# gptq, modelutils, quant (and datautils below) come from the GPTQ-for-BLOOM
# codebase this script is adapted from.
from gptq import *
from modelutils import *
from quant import *

import transformers  # needed for transformers.modeling_utils._init_weights below
from transformers import AutoTokenizer
import sys
import json
import logging
import tornado.ioloop
import tornado.web
import traceback

DEV = torch.device('cuda:0')

def get_bloom(model):
    # Replace the weight-init functions with no-ops: the weights are loaded
    # from the checkpoint anyway, so random initialization is wasted work.
    def skip(*args, **kwargs):
        pass
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    from transformers import BloomForCausalLM
    model = BloomForCausalLM.from_pretrained(model, torch_dtype='auto')
    model.seqlen = 2048
    return model

def load_quant(model, checkpoint, wbits, groupsize):
    from transformers import BloomConfig, BloomForCausalLM
    config = BloomConfig.from_pretrained(model)
    def noop(*args, **kwargs):
        pass
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    # Build the model skeleton in fp16 without initializing weights; the real
    # (quantized) weights are loaded from the checkpoint below.
    torch.set_default_dtype(torch.half)
    transformers.modeling_utils._init_weights = False
    model = BloomForCausalLM(config)
    torch.set_default_dtype(torch.float)
    model = model.eval()
    # Swap all quantizable linear layers (except lm_head) for quantized ones.
    layers = find_layers(model)
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    make_quant(model, layers, wbits, groupsize)

    print('Loading model ...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        model.load_state_dict(safe_load(checkpoint))
    else:
        model.load_state_dict(torch.load(checkpoint, map_location=torch.device('cuda')))
    model.seqlen = 2048
    print('Done.')

    return model


import argparse
from datautils import *

parser = argparse.ArgumentParser()

parser.add_argument(
    'model', type=str,
    help='BELLE (BLOOM) model to load'
)
parser.add_argument(
    '--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16],
    help='#bits to use for quantization; use 16 for evaluating base model.'
)
parser.add_argument(
    '--groupsize', type=int, default=-1,
    help='Groupsize to use for quantization; default uses full row.'
)
parser.add_argument(
    '--load', type=str, default='',
    help='Load quantized model.'
)

parser.add_argument(
    '--text', type=str,
    help='input text (unused by the web service; kept from the original CLI script)'
)

parser.add_argument(
    '--min_length', type=int, default=10,
    help='The minimum length of the sequence to be generated.'
)

parser.add_argument(
    '--max_length', type=int, default=1024,
    help='The maximum length of the sequence to be generated.'
)

parser.add_argument(
    '--top_p', type=float , default=0.95,
    help='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.'
)

parser.add_argument(
    '--temperature', type=float, default=0.8,
    help='The value used to module the next token probabilities.'
)

args = parser.parse_args()

if type(args.load) is not str:
    args.load = args.load.as_posix()

if args.load:
    model = load_quant(args.model, args.load, args.wbits, args.groupsize)
else:
    model = get_bloom(args.model)
    model.eval()
    
model.to(DEV)
tokenizer = AutoTokenizer.from_pretrained(args.model)
print("Human:")

inputs = 'Human: ' +'hello' + '\n\nAssistant:'
input_ids = tokenizer.encode(inputs, return_tensors="pt").to(DEV)
"""
with torch.no_grad():
    generated_ids = model.generate(
        input_ids,
        do_sample=True,
        min_length=args.min_length,
        max_length=args.max_length,
        top_p=args.top_p,
        temperature=args.temperature,
    )
print("Assistant:\n") 
print(tokenizer.decode([el.item() for el in generated_ids[0]])[len(inputs):]) # generated_ids开头加上了bos_token,需要将inpu的内容截断,只输出Assistant 
print("\n-------------------------------\n")

"""
# Launch example:
# python bloom_inference.py BELLE_BLOOM_GPTQ_4BIT --temperature 1.2 --wbits 4 --groupsize 128 --load BELLE_BLOOM_GPTQ_4BIT/bloom7b-2m-4bit-128g.pt
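
# HTTP API: POST / with a form-encoded field `status` holding the user prompt;
# the response body is JSON of the form {"belle": "<generated reply>"}.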
class GateAPIHandler(tornado.web.RequestHandler):
    def initialize(self):
        # Responses are JSON ("application/text" is not a valid MIME type).
        self.set_header("Content-Type", "application/json")
        self.set_header("Access-Control-Allow-Origin", "*")


    async def post(self):
        post_args = self.request.body_arguments
        if 'status' not in post_args:
            # HTTPError must be raised, not returned, for Tornado to send a 400.
            raise tornado.web.HTTPError(400)
        try:
            # body_arguments maps each field to a list of bytes values.
            prompt = post_args.get("status")[0].decode('utf-8')
            inputs = 'Human: ' + prompt + '\n\nAssistant:'
            input_ids = tokenizer.encode(inputs, return_tensors="pt").to(DEV)

            with torch.no_grad():
                generated_ids = model.generate(
                    input_ids,
                    do_sample=True,
                    min_length=args.min_length,
                    max_length=args.max_length,
                    top_p=args.top_p,
                    temperature=args.temperature,
                )
            # generate() prepends bos_token to generated_ids, so slice off the
            # input prompt and return only the Assistant reply.
            answer = tokenizer.decode([el.item() for el in generated_ids[0]])[len(inputs):]
            result = {'belle': answer}
            self.write(json.dumps(result))
        except Exception as e:
            logging.error("Error: {0}.".format(e))
            traceback.print_exc()
            raise tornado.web.HTTPError(500)

    def get(self):
        # Only POST is supported.
        raise tornado.web.HTTPError(405)


import tornado.options
from tornado.httpserver import HTTPServer

if __name__ == "__main__":
    tornado.options.define("port", default=8081, type=int, help="port to listen on")
    # argparse already consumed the model/quantization flags, so hand tornado
    # only the program name; otherwise it rejects options it does not know.
    tornado.options.parse_command_line([sys.argv[0]])
    app = tornado.web.Application([
        (r"/", GateAPIHandler),
    ])
    apiport = tornado.options.options.port

    # An explicit HTTPServer (app.listen(apiport) would also work). The number
    # of processes must stay 1: the model lives on a single GPU and cannot be
    # shared across forked workers.
    server = HTTPServer(app)
    server.bind(apiport)
    server.start(1)
    logging.info("Start Gate API server on port {0}.".format(apiport))
    tornado.ioloop.IOLoop.instance().start()
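
One caveat: although `post` is declared `async`, `model.generate` runs synchronously and blocks Tornado's event loop for the whole generation, so concurrent requests simply queue up. A minimal sketch of how the blocking call could be pushed onto a worker thread (the `run_generate` helper is hypothetical, not part of the script above):

import asyncio
from concurrent.futures import ThreadPoolExecutor

# A single worker keeps GPU access serialized while freeing the event loop.
executor = ThreadPoolExecutor(max_workers=1)

def run_generate(input_ids):
    # The same blocking call as in GateAPIHandler.post, factored out.
    with torch.no_grad():
        return model.generate(
            input_ids,
            do_sample=True,
            min_length=args.min_length,
            max_length=args.max_length,
            top_p=args.top_p,
            temperature=args.temperature,
        )

# Inside the async handler, generation would then be awaited as:
#     generated_ids = await asyncio.get_running_loop().run_in_executor(
#         executor, run_generate, input_ids)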
                                             

Test client: send the prompt as the `status` form field and print the JSON reply.

import json
import time
import requests

URL = 'http://192.168.3.9:8081'

data = {
    # "How should one understand Hegel's law of quantitative change leading to
    # qualitative change, and the law of the negation of the negation?"
    'status': ' 如何理解黑格尔的 量变引起质变规律和否定之否定规律',
}


t0 = time.time()
r = requests.post(URL, data=data)
t1 = time.time()
r.encoding = 'utf-8'

result = json.loads(r.text)
print(result)
print('time:', t1 - t0, 's')
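
On success this prints the dict built by the server, e.g. {'belle': '<generated answer>'}, followed by the round-trip time.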

