PaddleNLP使用Vicuna

LLaMA 模型

LLaMa 是一个大型语言模型,由 Meta 开源。

它的全称是 Large Language Model Meta AI,参数量从 70 亿到 650 亿不等。

例如,130 亿参数的 LLaMA 模型在大多数基准上可以胜过参数量达 1750 亿的 GPT-3,而且可以在单块 V100 GPU 上运行。

而最大的 650 亿参数的 LLaMA 模型可以媲美谷歌的 Chinchilla-70B 和 PaLM-540B。

Vicuna 模型

Vicuna 是一个由 UC 伯克利、CMU、斯坦福等机构的学者联手发布的最新开源大模型。

基于 Meta 开源的 LLaMA 大模型,使用 ShareGPT 平台上的用户共享对话数据微调而来。

包含 7B 和 13B 两个型号的开源预训练模型。

PaddleNLP使用Vicuna_第1张图片

下载模型

# 下载 Vicuna 7B
# !git lfs clone http://git.aistudio.baidu.com/180581/vicuna-7b-v1.1.git

# 下载 Vicuna 13B
!git lfs clone http://git.aistudio.baidu.com/180581/vicuna-13b-v1.1.git

开发环境

!pip install --pre --upgrade paddlenlp -f https://www.paddlepaddle.org.cn/whl/paddlenlp.html --user
!pip install paddlepaddle-gpu==0.0.0.post112 -f https://www.paddlepaddle.org.cn/whl/linux/gpu/develop.html --user

代码

import os
import glob
import paddle

from tqdm import tqdm
from paddlenlp.transformers import LlamaForCausalLM, LlamaConfig, LlamaTokenizer

pattern = 'paddle-model-?????-of-?????.pdparams'

# Vicuna 7B
# ckpt_dir = 'vicuna-7b-v1.1'
# config_dict =  {
#     "hidden_size": 4096,
#     "initializer_range": 0.02,
#     "intermediate_size": 11008,
#     "max_position_embeddings": 2048,
#     "model_type": "llama",
#     "num_attention_heads": 32,
#     "num_hidden_layers": 32,
#     "rms_norm_eps": 1e-06,
#     "vocab_size": 32000,
#     "bos_token_id": 1,
#     "eos_token_id": 2,
#     "pad_token_id": 0,
#     "use_cache": True,
#     "use_recompute": False,
#     "use_flash_attention": False,
# }

# Vicuna 13B
ckpt_dir = 'vicuna-13b-v1.1'
config_dict =  {
    "hidden_size": 5120,
    "initializer_range": 0.02,
    "intermediate_size": 13824,
    "max_position_embeddings": 2048,
    "model_type": "llama",
    "num_attention_heads": 40,
    "num_hidden_layers": 40,
    "rms_norm_eps": 1e-06,
    "vocab_size": 32000,
    "bos_token_id": 1,
    "eos_token_id": 2,
    "pad_token_id": 0,
    "use_cache": True,
    "use_recompute": False,
    "use_flash_attention": False,
}

paddle.set_default_dtype('float16')

tokenizer = LlamaTokenizer.from_pretrained(ckpt_dir)

config = LlamaConfig(**config_dict)

model = LlamaForCausalLM(config)
model.eval()

for name, layer in model.named_sublayers():
    if 'rotary_emb' in name:
        layer.inv_freq = layer.inv_freq.cast(paddle.float32)

paddle.device.cuda.empty_cache()


for file_path in tqdm(glob.glob(os.path.join(ckpt_dir, pattern))):
    params = paddle.load(file_path)
    assert model.set_dict(params)[1] == [], 'Load error.'
    del params
    paddle.device.cuda.empty_cache()

input_text = input('USER: ')
prompt = f'''USER: {input_text}\n\nASSISTANT: '''
with paddle.no_grad():
    with paddle.amp.auto_cast(False, level='O2', dtype='float16'):
        while True:
            if input_text == 'exit':
                break
            inputs = tokenizer(
                prompt, 
                return_tensors="pd", 
                return_attention_mask=True,
                return_position_ids=True
            )
            outputs = model.generate(
                input_ids=inputs.input_ids, 
                attention_mask=inputs.attention_mask, 
                position_ids=inputs.position_ids, 
                max_length=2048-inputs.input_ids.shape[1], 
                min_length=0, 
                decode_strategy="sampling",
                temperature=0.8, 
                top_k=40, 
                top_p=0.95, 
                repetition_penalty=1.1,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True, 
                use_fast=True, 
                use_fp16_decoding=True)
            response = tokenizer.decode(outputs[0][0], skip_special_tokens=True)
            print('ASSISTANT: ' + response)
            input_text = input('USER: ')
            prompt += f'''{response}\n\nUSER: {input_text}\n\nASSISTANT: '''
            del inputs
            del outputs
            del response
            paddle.device.cuda.empty_cache()

你可能感兴趣的:(自然语言处理)