Fine-tuning with the peft library (Parameter-Efficient Fine-Tuning), which supports several tuning methods, including the one used here:
PromptTuningConfig
sets up the prompt-tuning configuration. Below, num_virtual_tokens sets the number of tokens in the learned prompt prefix; because initializing these virtual tokens from task-related text works better than random initialization, they are initialized from the text "Classify if the tweet is a complaint or not:".
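As a quick aside (a sketch, not part of the script below): you can check how many tokens the initialization text actually occupies; if the count differs from num_virtual_tokens, peft tiles or truncates the initialization embeddings to fit.
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
init_text = "Classify if the tweet is a complaint or not:"
print(len(tok(init_text)["input_ids"]))  # compare against num_virtual_tokens=8 used below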
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author : andy
@Date : 2023/7/10 20:37
@Contact: [email protected]
@File : prompt_tuning.py
"""
from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator
from peft import get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType
import torch
from datasets import load_dataset
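# use "mps" on Apple Silicon; switch to "cuda" for NVIDIA GPUs, or fall back to "cpu"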
device = "mps"
# device = "cuda"
model_name_or_path = "bigscience/bloomz-560m"
tokenizer_name_or_path = "bigscience/bloomz-560m"
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,              # the underlying model is a causal LM
    prompt_tuning_init=PromptTuningInit.TEXT,  # initialize the virtual tokens from text
    num_virtual_tokens=8,                      # length of the learned prompt prefix
    prompt_tuning_init_text="Classify if the tweet is a complaint or not:",
    tokenizer_name_or_path=tokenizer_name_or_path,
)
dataset_name = "twitter_complaints"
text_column = "Tweet text"
label_column = "text_label"
max_length = 64
learning_rate = 3e-2
num_epochs = 20
batch_size = 8
output_dir = './output'
# 1. load a subset of the RAFT dataset at https://huggingface.co/datasets/ought/raft
dataset = load_dataset("ought/raft", dataset_name)
# get the label's possible values, e.g. "no_complaint" -> "no complaint"
label_values = [name.replace("_", " ") for name in dataset["train"].features["Label"].names]
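# for twitter_complaints this should yield ["Unlabeled", "complaint", "no complaint"]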
# append label value to the dataset to make it more readable
dataset = dataset.map(
    lambda x: {label_column: [label_values[label] for label in x["Label"]]},
    batched=True,
    num_proc=1,
)
# have a look at the data structure
print(dataset["train"][0])
# 2. dataset
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
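# bloomz already ships a pad token, so the fallback below only matters for tokenizers that lack one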
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
def preprocess_fn(examples):
    tweets = examples[text_column]
    # append a pad token to each label so the model learns to stop after emitting it
    labels = [str(x) + tokenizer.pad_token for x in examples[label_column]]
    # concatenate each tweet with its label
    inputs = [f"{text_column} : {tweet}\nLabel :{label}"
              for tweet, label in zip(tweets, labels)]
    # tokenize the inputs
    model_inputs = tokenizer(inputs,
                             padding='max_length',
                             max_length=max_length,
                             truncation=True)
    # tokenize the labels; since -100 is not a valid token id, do the padding manually here
    labels_input_ids = []
    for i in range(len(labels)):
        ids = tokenizer(labels[i])["input_ids"]
        padding = [-100] * (max_length - len(ids))
        labels_input_ids.append(padding + ids)
    model_inputs["labels"] = labels_input_ids
    # turn the model inputs into tensors
    model_inputs["input_ids"] = [torch.tensor(ids) for ids in model_inputs["input_ids"]]
    model_inputs["attention_mask"] = [torch.tensor(ids) for ids in model_inputs["attention_mask"]]
    model_inputs["labels"] = [torch.tensor(ids) for ids in model_inputs["labels"]]
    return model_inputs
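The -100 padding is deliberate: PyTorch's cross-entropy loss ignores targets equal to -100 by default, so the loss is computed only on the real label tokens, never on the padding. A minimal illustration, separate from the script:
import torch
import torch.nn.functional as F
logits = torch.randn(4, 10)                 # 4 positions over a vocabulary of 10
targets = torch.tensor([-100, -100, 3, 7])  # the first two positions are masked out
print(F.cross_entropy(logits, targets))     # averaged over the 2 unmasked targets only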
# have a look at the preprocessing result
# print(preprocess_fn(dataset["train"][:2]))
processed_datasets = dataset.map(
    preprocess_fn,
    batched=True,
    num_proc=1,
    remove_columns=dataset["train"].column_names,  # remove unprocessed columns for training
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)
test_size = round(len(processed_datasets["train"]) * 0.2)
train_val = processed_datasets["train"].train_test_split(
    test_size=test_size, shuffle=True, seed=42)
train_data = train_val["train"]
val_data = train_val["test"]
# 3. model
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # this method prints directly; no outer print() needed
trainable params: 8192 || all params: 559222784 || trainable%: 0.0014648902430985358
As the printout shows, the model has roughly 560 million parameters, but only 8,192 of them, about 0.0015%, are trainable. That count is exactly the prompt embedding: 8 virtual tokens × 1024 (the hidden size of bloomz-560m) = 8,192.
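To see where these parameters live, you can list the trainable tensors (a quick check, not in the original script); the only trainable one should be the prompt embedding with shape (8, 1024):
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, tuple(param.shape))  # expect a single prompt-embedding tensor of shape (8, 1024)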
# 4. trainer
from transformers import Trainer, TrainingArguments
trainer = Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=default_data_collator,
    args=TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=batch_size,
        num_train_epochs=num_epochs,
        learning_rate=learning_rate,
        load_best_model_at_end=True,
        logging_strategy='steps',
        logging_steps=10,
        evaluation_strategy='steps',
        eval_steps=10,
        save_strategy='steps',
        save_steps=10,
    )
)
trainer.train()
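One step the script otherwise leaves implicit: save the trained adapter so the inference code below has something to load (a small addition; it writes only the prompt embedding plus an adapter config).
# save the learned prompt-tuning adapter for reuse at inference time
model.save_pretrained(output_dir)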
# 5. inference
def inference():
    def generate(inputs, infer_model):
        with torch.no_grad():
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = infer_model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=20,
                eos_token_id=3,  # 3 is BLOOM's pad token id; labels end in a pad token, so stop there
            )
            print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])
    # (1) base model inference, zero-shot
    base_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    base_model.to(device)
    inputs = tokenizer(
        f'{text_column} : {"@denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?"}\nLabel :',
        return_tensors="pt",  # return PyTorch torch.Tensor objects
    )
    generate(inputs, base_model)
    print("----------------------------------------")
    # base model inference, two-shot: prepend two labeled examples to the query
    shot1 = f'{text_column} : {"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?"}\nLabel :complaint\n'
    shot2 = f'{text_column} : {"@HMRCcustomers No this is my first job"}\nLabel :no complaint\n'
    query = f'{text_column} : {"@denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?"}\nLabel :'
    inputs_few_shot = tokenizer(
        shot1 + shot2 + query,
        return_tensors="pt",
    )
    generate(inputs_few_shot, base_model)
    # (2) prompt-tuned model inference: load the base model, then attach the saved adapter
    from peft import PeftModel, PeftConfig
    path = output_dir  # directory the adapter was saved to after training
    config = PeftConfig.from_pretrained(path)
    pretrained_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
    prompt_tuned_model = PeftModel.from_pretrained(pretrained_model, path)
    prompt_tuned_model.to(device)
    inputs = tokenizer(
        f'{text_column} : {"@denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?"}\nLabel :',
        return_tensors="pt",
    )
    generate(inputs, prompt_tuned_model)
inference()
Tweet text : @denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?
Label : @denny the grocery
Tweet text : @denny the grocery price is soaring, even milk is becoming unaffordable, could you do something?
Label :complaint
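Note the contrast: the zero-shot base model merely continues the tweet instead of emitting a label, while the prompt-tuned model classifies the same tweet correctly as a complaint.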