ChatGLM-6B 微调后加载模型并进行问答的代码

import os
import platform
import signal
from transformers import AutoTokenizer, AutoModel

# model_dir=""
# Progress marker: tokenizer loading is about to start.
print('load tokenizer')
# Directory passed to from_pretrained() for the tokenizer below; the same
# value is reused as model_args.model_name_or_path for the config and model.
model_dir='/xxx/home/work/chatglm-6b'

import torch
# trust_remote_code=True allows transformers to execute the custom Python
# code that ships inside model_dir, so only point this at a trusted checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
print('load tokenizer end ')
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    set_seed,
)

# Prefix length copied into config.pre_seq_len further down.
# NOTE(review): presumably has to match the value used when the P-Tuning
# checkpoint was trained (the checkpoint path contains "pt-128") - confirm.
PRE_SEQ_LEN = 128


class model_args:
    """Static settings namespace read by the config/model loading code.

    The attributes are plain class attributes; the class is never
    instantiated in this script.
    """

    # Base model location handed to AutoConfig / AutoModel.from_pretrained.
    model_name_or_path = model_dir

    # Directory expected to contain "pytorch_model.bin" with the prefix
    # encoder weights. Set to None to load the base model without them.
    ptuning_checkpoint = '/xxx/trainOut/job_and_school-chatglm-6b-pt-128-2e-2/checkpoint-3000'

    # Values written onto the loaded config object.
    pre_seq_len = PRE_SEQ_LEN
    prefix_projection = False

# Load the base model's configuration (only the config is loaded here, not
# the weights) and copy the prefix settings from model_args onto it.
config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
# These two attributes are set unconditionally, but `config` is only passed to
# AutoModel.from_pretrained when a ptuning checkpoint is configured.
config.pre_seq_len = model_args.pre_seq_len
config.prefix_projection = model_args.prefix_projection

# Traceback (most recent call last):
#   File "/xxx/home/work/chat-glm-6-b-2/cli_demo.py", line 34, in 
#     model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
#   File "/xxx/condaEnvs/mossChat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
#     raise AttributeError("'{}' object has no attribute '{}'".format(
# AttributeError: 'ChatGLMModel' object has no attribute 'prefix_encoder'
if model_args.ptuning_checkpoint is not None:
    # Evaluation with a fine-tuned prefix: load the base model, then overlay
    # the prefix encoder weights stored in the P-Tuning checkpoint.
    #
    # `config` (with pre_seq_len set) must be passed here. The traceback kept
    # in the comment above this block shows the AttributeError on
    # `prefix_encoder` that was presumably hit when the model was built
    # without it.
    model = AutoModel.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        trust_remote_code=True,
    ).half().cuda()

    # NOTE: torch.load unpickles the file, so only load checkpoints from a
    # trusted source. map_location="cpu" keeps loading independent of the
    # device the checkpoint was saved from; load_state_dict() then copies the
    # tensors into the model's existing parameters.
    prefix_state_dict = torch.load(
        os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"),
        map_location="cpu",
    )

    # Keep only the prefix encoder entries and strip their module path so the
    # keys match prefix_encoder's own state dict. Keys without the prefix are
    # skipped instead of being blindly sliced into garbage names.
    prefix_key = "transformer.prefix_encoder."
    new_prefix_state_dict = {
        k[len(prefix_key):]: v
        for k, v in prefix_state_dict.items()
        if k.startswith(prefix_key)
    }
    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
else:
    # No fine-tuned prefix: load the plain base model. `config` is
    # deliberately not passed, since it carries pre_seq_len and there are no
    # trained prefix weights to load in this branch.
    model = AutoModel.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
    ).half().cuda()
    
# Earlier variant kept for reference: load the base model directly from
# model_dir without a custom config or prefix weights.
# model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).half().cuda()
print('load model end ')

print('model eval  ')

# Put the model into evaluation mode before it is used for chatting.
model = model.eval()
print('model eval  end ')
# Pick the shell command that wipes the terminal on the current platform.
os_name = platform.system()
if os_name == 'Windows':
    clear_command = 'cls'
else:
    clear_command = 'clear'

# Flag flipped to True by signal_handler to ask main() to stop streaming.
stop_stream = False


def build_prompt(history):
    """Render the welcome banner followed by the whole conversation so far.

    Args:
        history: Iterable of ``(query, response)`` pairs, oldest first.

    Returns:
        A single string: the banner, then for every pair a user line and a
        ChatGLM-6B line, each preceded by a blank line.
    """
    banner = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
    rounds = "".join(
        f"\n\n用户:{question}\n\nChatGLM-6B:{answer}"
        for question, answer in history
    )
    return banner + rounds


def signal_handler(signal, frame):
    """Signal callback: request that the current streamed reply be stopped.

    Both arguments are supplied by the interpreter when the signal fires and
    are ignored; the handler only raises the module-level ``stop_stream``
    flag, which main() checks between streamed chunks.
    """
    # Same effect as a `global stop_stream` assignment: write the flag into
    # this module's namespace.
    globals()["stop_stream"] = True


def main():
    """Run the interactive command-line chat loop.

    Reads queries from stdin until the user types ``stop``. Typing ``clear``
    drops the conversation history and clears the screen. Every other input
    is sent to ``model.stream_chat``; the screen is redrawn every 8 streamed
    chunks and once more when the reply is complete.

    Ctrl-C while a reply is streaming stops that reply only. The SIGINT
    handler is installed just for the duration of the generation and the
    previous handler is restored afterwards, so Ctrl-C at the input prompt
    keeps its normal behaviour instead of silently leaving ``stop_stream``
    set and cutting the next reply short.
    """
    global stop_stream
    welcome = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
    history = []
    print(welcome)
    while True:
        query = input("\n用户:")
        command = query.strip()
        if command == "stop":
            break
        if command == "clear":
            history = []
            os.system(clear_command)
            print(welcome)
            continue

        # Start every generation with a clean flag so an earlier interrupt
        # cannot abort this reply.
        stop_stream = False
        # Install the handler before the first chunk (not after 8 of them) so
        # an early Ctrl-C stops the stream instead of raising
        # KeyboardInterrupt.
        previous_handler = signal.signal(signal.SIGINT, signal_handler)
        try:
            count = 0
            for response, history in model.stream_chat(tokenizer, query, history=history):
                if stop_stream:
                    stop_stream = False
                    break
                count += 1
                # Redraw only every 8th chunk to limit screen flicker.
                if count % 8 == 0:
                    os.system(clear_command)
                    print(build_prompt(history), flush=True)
        finally:
            # signal.signal() returns None when the old handler was not
            # installed from Python; None cannot be passed back in.
            if previous_handler is not None:
                signal.signal(signal.SIGINT, previous_handler)

        # Final redraw with the complete (or interrupted) reply.
        os.system(clear_command)
        print(build_prompt(history), flush=True)

# deepspeed


# Start the interactive chat loop only when executed as a script. Note that
# the tokenizer and model loading above are top-level statements and run on
# import regardless of this guard.
if __name__ == "__main__":
    main()

你可能感兴趣的:(python,深度学习,pytorch,人工智能)