LoRA fine-tuning

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.nn as nn
import bitsandbytes as bnb
import transformers
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
"""
opt-6.7b模型,它以float16的精度存储,大小大约为13GB!如果我们使用bitsandbytes库以8位加载它们,我们需要大约7GB的显存
"""
# load_in_8bit=True tells transformers to use bitsandbytes to quantize the weights to 8-bit on load
model = AutoModelForCausalLM.from_pretrained("facebook/opt-6.7b", load_in_8bit=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-6.7b")
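
# Optional sanity check (added here, not in the original post): get_memory_footprint()
# returns the model's parameter/buffer memory in bytes, which should land near the
# ~7 GB figure quoted above for the 8-bit weights.
print(f"memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")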

# Preprocess the non-int8 modules (cast layer norms and the lm_head to fp32, freeze the base
# weights) to stabilize training; newer PEFT versions rename this helper to prepare_model_for_kbit_training
from peft import prepare_model_for_int8_training
model = prepare_model_for_int8_training(model)

# Configure LoRA: rank-16 adapters injected into the attention query/value projections
from peft import LoraConfig, get_peft_model
config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)

def print_trainable_parameters(model):
    """Prints the number of trainable parameters in the model."""
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    # Print once after the loop, not on every parameter
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")


# Load the training data (the Abirate/english_quotes dataset of famous quotes)
from datasets import load_dataset
data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
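
# Quick peek at one tokenized record (an added check, not in the original post): each example
# should now carry "input_ids" and "attention_mask" alongside the raw quote fields.
print(data["train"][0].keys())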

# Training
trainer = transformers.Trainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        warmup_steps=100,
        max_steps=200,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
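
# Save only the LoRA adapter weights (a few tens of MB rather than the full model) so they
# can be reloaded later; the "outputs/lora-adapter" path is an arbitrary choice added here.
model.save_pretrained("outputs/lora-adapter")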

# Inference
batch = tokenizer("Two things are infinite: ", return_tensors="pt")
with torch.cuda.amp.autocast():
    output_tokens = model.generate(**batch, max_new_tokens=50)
    print("\n\n", tokenizer.decode(output_tokens[0], skip_special_tokens=True))

References:

1. "A 2023 Beginner's Guide to Deep Learning (12): PEFT and LoRA", CSDN blog

2. https://github.com/huggingface/notebooks/blob/main/peft/Fine_tune_BLIP2_on_an_image_captioning_dataset_PEFT.ipynb
