from transformers import AutoConfig, AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("hfl/rbt3", force_download=False) # set to True to force a fresh download
Other ways to load a model
Model download
# Download every file in the repo
!git clone "https://huggingface.co/hfl/rbt3"
# Download only specific files, e.g. files ending in .bin
!git lfs clone "https://huggingface.co/hfl/rbt3" --include="*.bin"
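If git / git-lfs is inconvenient, the files can also be fetched programmatically; a minimal sketch using huggingface_hub (assuming it is installed and recent enough to support the local_dir argument):
from huggingface_hub import snapshot_download

# Download only the weight, config and vocab files into a local folder named rbt3
snapshot_download(
    repo_id="hfl/rbt3",
    allow_patterns=["*.bin", "*.json", "*.txt"],
    local_dir="rbt3",
)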
Offline loading
model = AutoModel.from_pretrained("rbt3")  # load from the local directory cloned above
Inspect the configuration
model.config
BertConfig {
"_name_or_path": "rbt3",
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"classifier_dropout": null,
"directionality": "bidi",
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 3,
"output_past": true,
"pad_token_id": 0,
"pooler_fc_size": 768,
"pooler_num_attention_heads": 12,
"pooler_num_fc_layers": 3,
"pooler_size_per_head": 128,
"pooler_type": "first_token_transform",
"position_embedding_type": "absolute",
"transformers_version": "4.35.2",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 21128
}
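Individual fields of the configuration can be read directly as attributes; for example (values taken from the dump above):
model.config.num_hidden_layers  # 3 -- rbt3 keeps only three transformer layers
model.config.vocab_size         # 21128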
Loading the configuration with AutoConfig
config = AutoConfig.from_pretrained("rbt3")
config
BertConfig { ... }   (the same configuration as printed above for model.config)
# Check whether the model is configured to output attention weights
config.output_attentions
# BertConfig can be used to modify the parameters held in model.config
from transformers import BertConfig
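A minimal sketch of the usual pattern: load a BertConfig, override a field, and pass the modified config back to from_pretrained (output_attentions is used here only as an example field):
# Load the config, change a field, then rebuild the model with it
config = BertConfig.from_pretrained("rbt3")
config.output_attentions = True   # ask every layer to also return its attention weights
model = AutoModel.from_pretrained("rbt3", config=config)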
sen = "弱小的我也有大梦想!"  # "Weak as I am, I too have big dreams!"
tokenizer = AutoTokenizer.from_pretrained("rbt3")
inputs = tokenizer(sen, return_tensors="pt") # return PyTorch tensors
inputs
{'input_ids': tensor([[ 101, 2483, 2207, 4638, 2769, 738, 3300, 1920, 3457, 2682, 8013, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
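To sanity-check the encoding, the ids can be mapped back to tokens; the [CLS] and [SEP] markers the tokenizer adds explain why 12 ids come back for the 10-character sentence:
# Map ids back to tokens: [CLS] at the front and [SEP] at the end account for the two extra positions
tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])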
model = AutoModel.from_pretrained("rbt3", output_attentions=True)  # also return per-layer attention weights
output = model(**inputs)
output
BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[ 0.6804,  0.6664,  0.7170,  ..., -0.4102,  0.7839, -0.0262],
         [-0.7378, -0.2748,  0.5034,  ..., -0.1359, -0.4331, -0.5874],
         [-0.0212,  0.5642,  0.1032,  ..., -0.3617,  0.4646, -0.4747],
         ...,
         [ 0.0853,  0.6679, -0.1757,  ..., -0.0942,  0.4664,  0.2925],
         [ 0.3336,  0.3224, -0.3355,  ..., -0.3262,  0.2532, -0.2507],
         [ 0.6761,  0.6688,  0.7154,  ..., -0.4083,  0.7824, -0.0224]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[-1.2646e-01, -9.8619e-01, -1.0000e+00,  ...]]),
    ..., attentions=(tensor([[[[...]]]], grad_fn=<SoftmaxBackward0>), ...), cross_attentions=None)
output.last_hidden_state.size()  # (batch, seq_len, hidden_size)
torch.Size([1, 12, 768])
len(inputs["input_ids"][0])
12
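Because the model was loaded with output_attentions=True, output.attentions holds one tensor per layer; a sketch of the expected shapes, which follow from the 3-layer, 12-head configuration shown earlier:
len(output.attentions)       # 3, one attention tensor per hidden layer
output.attentions[0].size()  # torch.Size([1, 12, 12, 12]): (batch, num_heads, seq_len, seq_len)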
from transformers import AutoModelForSequenceClassification, BertForSequenceClassification
# num_labels controls how many classification labels the head outputs
clz_model = AutoModelForSequenceClassification.from_pretrained("rbt3", num_labels=10)
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /u01/zhanggaoke/project/transformers-code-master/model/rbt3 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
clz_model(**inputs) # feed the tokenized inputs to the classification model
# the logits tensor contains ten values, one per label
SequenceClassifierOutput(loss=None, logits=tensor([[-0.0586, 0.7688, -0.0553, 0.2559, -0.2247, 0.4708, -0.6269, -0.3731,
0.0453, -0.1891]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)
clz_model.config.num_labels # this parameter controls how many label classes are output
10
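To turn the logits into a prediction, a minimal sketch (the result is meaningless until the newly initialized classification head is fine-tuned on a downstream task):
import torch

logits = clz_model(**inputs).logits
pred_id = torch.argmax(logits, dim=-1).item()  # index of the highest-scoring label
pred_id  # an integer in [0, 9]; not meaningful before fine-tuning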