介绍 LoRA 与 QLoRA
1. LoRA (Low-Rank Adaptation)
LoRA 是一种用于大规模语言模型 (LLM) 的参数高效微调技术,旨在减少微调大模型所需的计算资源和存储空间。LoRA 的核心思想是将全量参数更新分解为低秩矩阵的形式,从而显著减少参数数量和计算开销。
核心思想:
低秩分解:将大模型的权重矩阵表示为两个低秩矩阵的乘积。这种分解方法不仅保留了原始模型的表示能力,还显著减少了微调过程中需要更新的参数数量。
参数高效:通过这种方式,只需微调少量参数(即低秩矩阵的参数),而非整个模型的参数,从而大大降低了存储和计算成本。
优点:
存储节省:减少了需要存储和更新的参数数量。
计算效率:降低了微调过程中所需的计算资源。
可扩展性:适用于各种大规模预训练模型,包括 NLP 和 CV 等领域。
应用场景:
自然语言处理 (NLP):如机器翻译、文本生成等任务。
计算机视觉 (CV):如图像分类、对象检测等任务。
import torch
import torch.nn as nn
class LoRALinear(nn.Module):
def __init__(self, in_features, out_features, rank=4):
super(LoRALinear, self).__init__()
self.rank = rank
self.weight_A = nn.Parameter(torch.randn(in_features, rank))
self.weight_B = nn.Parameter(torch.randn(rank, out_features))
self.bias = nn.Parameter(torch.zeros(out_features))
def forward(self, x):
return x @ self.weight_A @ self.weight_B + self.bias
# 示例用法
lora_layer = LoRALinear(512, 1024)
input_data = torch.randn(32, 512)
output_data = lora_layer(input_data)
2. QLoRA (Quantized Low-Rank Adaptation)
QLoRA (Quantized Low-Rank Adaptation) 是一种结合了模型量化和低秩适配的技术,旨在减少大规模预训练模型微调和部署的计算和存储成本。QLoRA 的主要思路是:
量化预训练模型参数:将已有的大规模预训练模型参数进行量化处理,以减少存储需求和计算负担。
低秩适配 (LoRA):在量化后的模型上应用低秩适配技术,仅微调少量附加的低秩矩阵,从而保持微调的高效性。
核心思想
模型量化:对预训练模型的权重进行量化处理,如使用 8-bit 或更低精度表示模型参数,以减少存储空间和计算资源。
低秩适配:在量化后的模型基础上进行低秩适配,通过引入低秩矩阵来进行微调,从而节省参数数量和计算开销。
优点
存储节省:通过量化大幅减少了模型权重的存储需求。
计算加速:量化后的计算通常更快,特别是在专门的硬件(如 GPU、TPU)上。
高效微调:结合 LoRA 技术,只需微调少量参数即可实现模型的适应,提高了微调效率。
import torch
import torch.nn as nn
# 定义一个简单的预训练线性层
class PretrainedLinear(nn.Module):
def __init__(self, in_features, out_features):
super(PretrainedLinear, self).__init__()
self.weight = nn.Parameter(torch.randn(out_features, in_features))
self.bias = nn.Parameter(torch.zeros(out_features))
def forward(self, x):
return torch.matmul(x, self.weight.t()) + self.bias
# 量化函数
def quantize_tensor(tensor, num_bits=8):
scale = tensor.abs().max() / (2**(num_bits - 1) - 1)
quantized = (tensor / scale).round()
return quantized, scale
# 反量化函数
def dequantize_tensor(quantized, scale):
return quantized * scale
# 定义量化后的 LoRA 适配层
class QuantizedLoRALinear(nn.Module):
def __init__(self, pretrained_layer, rank=4, num_bits=8):
super(QuantizedLoRALinear, self).__init__()
self.pretrained_layer = pretrained_layer
# 对预训练层的权重进行量化 一般对预训练的大模型如LLAMA3 的 attention 层和fc层 等线性层进行量化
self.quantized_weight, self.scale = quantize_tensor(pretrained_layer.weight.data, num_bits)
# 定义低秩适配的权重
self.weight_A = nn.Parameter(torch.randn(pretrained_layer.weight.shape[1], rank))
self.weight_B = nn.Parameter(torch.randn(rank, pretrained_layer.weight.shape[0]))
def forward(self, x):
# 反量化预训练的权重
dequantized_weight = dequantize_tensor(self.quantized_weight, self.scale)
# 计算 LoRA 适配部分
lora_adjustment = torch.matmul(torch.matmul(x, self.weight_A), self.weight_B.t())
# 总输出为预训练层输出加上 LoRA 适配部分
return torch.matmul(x, dequantized_weight.t()) + lora_adjustment + self.pretrained_layer.bias
# 示例用法
pretrained_layer = PretrainedLinear(512, 1024)
quantized_lora_layer = QuantizedLoRALinear(pretrained_layer, rank=4, num_bits=8)
input_data = torch.randn(32, 512)
output_data = quantized_lora_layer(input_data)
print(output_data)
例子:使用 LLaMA 3 作为预训练模型,并结合 QLoRA 技术进行微调,可以显著减少存储和计算成本,同时保持模型性能。下面将介绍如何将 QLoRA 应用于 LLaMA 3 模型,并提供一个具体的示例代码。
主要步骤
加载预训练的 LLaMA 3 模型和分词器。
对注意力层和全连接层进行量化。
应用 LoRA 技术进行低秩适配。
准备数据为alpaca数据格式
微调模型。
测试模型
import torch
import torch.nn as nn
from transformers import LLaMAForCausalLM, LLaMATokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset
# 定义量化函数
def quantize_tensor(tensor, num_bits=8):
scale = tensor.abs().max() / (2**(num_bits - 1) - 1)
quantized = (tensor / scale).round().int()
return quantized, scale
# 定义反量化函数
def dequantize_tensor(quantized, scale):
return quantized.float() * scale
# 定义量化后的 LoRA 适配层
class QLoRALayer(nn.Module):
def __init__(self, pretrained_layer, rank=4, num_bits=8):
super(QLoRALayer, self).__init__()
self.pretrained_layer = pretrained_layer
# 对预训练层的权重进行量化
self.quantized_weight, self.scale = quantize_tensor(pretrained_layer.weight.data, num_bits)
# 定义低秩适配的权重
self.weight_A = nn.Parameter(torch.randn(pretrained_layer.weight.shape[1], rank))
self.weight_B = nn.Parameter(torch.randn(rank, pretrained_layer.weight.shape[0]))
def forward(self, x):
# 反量化预训练的权重
dequantized_weight = dequantize_tensor(self.quantized_weight, self.scale)
# 计算 LoRA 适配部分
lora_adjustment = torch.matmul(torch.matmul(x, self.weight_A), self.weight_B.t())
# 总输出为预训练层输出加上 LoRA 适配部分
return torch.matmul(x, dequantized_weight.t()) + lora_adjustment
# 加载预训练的 LLaMA 3 模型和分词器
model_name = "facebook/llama-3b"
model = LLaMAForCausalLM.from_pretrained(model_name)
tokenizer = LLaMATokenizer.from_pretrained(model_name)
# 遍历模型中的注意力层和全连接层进行量化和 LoRA 适配
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
quantized_lora_layer = QLoRALayer(module)
# 替换原始层
setattr(model, name, quantized_lora_layer)
# 准备 Alpaca 格式的数据集
dataset = load_dataset("alpaca", split="train")
# 数据预处理函数
def preprocess_function(examples):
inputs = [ex["instruction"] + " " + ex["input"] for ex in examples]
model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")
labels = tokenizer(examples["output"], max_length=512, truncation=True, padding="max_length")
model_inputs["labels"] = labels["input_ids"]
return model_inputs
# 处理数据集
tokenized_dataset = dataset.map(preprocess_function, batched=True)
# 数据整理器
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
# 微调模型的训练参数
training_args = TrainingArguments(
output_dir="./results",
overwrite_output_dir=True,
num_train_epochs=3,
per_device_train_batch_size=8,
save_steps=10_000,
save_total_limit=2,
logging_dir="./logs",
)
# 定义 Trainer
trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=tokenized_dataset,
)
# 开始微调
trainer.train()
# 测试模型
model.eval()
instruction = "Translate English to French:"
input_text = "The quick brown fox jumps over the lazy dog."
test_input = tokenizer(instruction + " " + input_text, return_tensors="pt")
with torch.no_grad():
generated_ids = model.generate(test_input["input_ids"], max_length=50)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f"Generated Text: {generated_text}")