【Hugging Face】transformers 库中 model 的常用方法和属性

Hugging Face transformers 库中 model 的常用方法和属性

transformers 库中,model 代表 预训练的 Transformer 模型,可用于 文本分类、问答、文本生成等任务。不同任务的 model 可能会有不同的方法和属性,但它们共享许多常见功能。


1. model 的常见属性

在加载 AutoModelAutoModelForXXX 后,可以使用以下属性:

from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")
属性 作用 示例
model.config 获取模型配置 print(model.config)
model.num_parameters() 获取模型参数数量 model.num_parameters()
model.device 查看模型在哪个设备上 model.device
model.dtype 获取数据类型(fp32, fp16, bf16) model.dtype
model.embeddings 获取模型的嵌入层 model.embeddings
model.encoder 获取 Transformer 编码器 model.encoder
model.lm_head 语言模型的输出层(如 GPT-2, T5) model.lm_head

2. model 的常用方法

方法 作用
model.forward(input_ids, attention_mask) 前向传播(通常直接调用 model(...)
model.to(device) 将模型移动到 CPU/GPU
model.eval() 进入评估模式
model.train() 进入训练模式
model.generate(input_ids, max_length=50) 生成文本(适用于 GPT-2, T5, BART)
model.save_pretrained(path) 保存模型
model.load_state_dict(torch.load(path)) 加载训练好的参数
model.parameters() 获取所有参数
model.named_parameters() 获取所有参数及名称

3. model 详细用法

3.1. 前向传播

所有 Transformer 模型都支持 forward(),但通常直接调用 model(...)

from transformers import AutoTokenizer, AutoModel
import torch

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

# 处理输入文本
text = "Hugging Face is great!"
inputs = tokenizer(text, return_tensors="pt")

# 前向传播
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)

输出

torch.Size([1, 7, 768])

last_hidden_state.shape 解释:

  • 1:批量大小
  • 7:序列长度(包括 [CLS] 和 [SEP])
  • 768:隐藏层维度

3.2. 进入训练或评估模式

model.train()  # 启用 dropout、BatchNorm
model.eval()   # 关闭 dropout,进入推理模式

3.3. 将模型移动到 GPU

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

3.4. 计算模型参数数量

num_params = model.num_parameters()
print(f"模型参数数量: {num_params / 1e6:.2f}M")

3.5. 生成文本(适用于 GPT-2, BART, T5)

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
input_text = "Hugging Face is"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

output_ids = model.generate(input_ids, max_length=50)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text)

3.6. 保存和加载模型

保存模型
model.save_pretrained("./my_model")
加载模型
from transformers import AutoModel

model = AutoModel.from_pretrained("./my_model")

3.7. 加载自定义权重

如果你有一个训练好的 .bin 权重:

import torch

state_dict = torch.load("pytorch_model.bin")
model.load_state_dict(state_dict)

4. model 在不同任务中的应用

4.1. 文本分类

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
inputs = tokenizer("Hugging Face is great!", return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
predicted_class = logits.argmax().item()
print(predicted_class)

4.2. 问答任务

from transformers import AutoModelForQuestionAnswering

model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad")
question = "Where is Hugging Face based?"
context = "Hugging Face is based in New York City."

inputs = tokenizer(question, context, return_tensors="pt")
outputs = model(**inputs)

start = outputs.start_logits.argmax()
end = outputs.end_logits.argmax() + 1
answer = tokenizer.decode(inputs.input_ids[0][start:end])
print(answer)

4.3. 生成式任务(翻译、摘要)

from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
input_text = "Hugging Face provides NLP tools for AI applications."
inputs = tokenizer(input_text, return_tensors="pt")

output_ids = model.generate(inputs.input_ids, max_length=30)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text)

5. 总结

modeltransformers 中是核心组件,适用于 文本分类、问答、翻译、文本生成等任务

常用方法:

  • model(...) 进行 前向传播
  • model.generate() 进行 文本生成
  • model.to(device) 移动模型到 GPU
  • model.eval() 进入推理模式
  • model.train() 进入训练模式
  • model.save_pretrained(path) 保存模型
  • model.num_parameters() 查看参数数量

不同任务使用不同的 AutoModelForXXX

  • 文本分类AutoModelForSequenceClassification
  • 问答AutoModelForQuestionAnswering
  • 文本生成AutoModelForCausalLM
  • 翻译/摘要AutoModelForSeq2SeqLM

你可能感兴趣的:(Hugging,Face,model,模型的属性和方法,transformers,Hugging,Face,python)