在代码中用如下 from_pretrained() 函数下载bert等预训练模型时下载巨慢:
from transformers import BertTokenizer,BertModel,BertConfig
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = BertConfig.from_pretrained('bert-base-uncased')
config.update({'output_hidden_states':True})
model = BertModel.from_pretrained("bert-base-uncased",config=config)
这里以 ‘bert-base-uncased’ 预训练模型为例(下其它预训练模型也一样的步骤),自己手动下载:
hugging face官网:https://huggingface.co/models
【建议用谷歌浏览器打开,我用qq浏览器打开会有奇奇怪怪的问题】
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = BertConfig.from_pretrained('bert-base-uncased')
config.update({'output_hidden_states':True})
model = BertModel.from_pretrained("bert-base-uncased",config=config)
ids = tokenizer(['hello world fuck','I love you'],return_tensors='pt')
outs = model(**ids)
print(len(outs))
print(outs[0].shape)
print(outs[1].shape)
print(outs[2][0].shape)
输出如下,成功调用预训练模型:
3
torch.Size([2, 5, 768])
torch.Size([2, 768])
torch.Size([2, 5, 768])
清华源镜像了 Hugging Face Model Hub,为国内用户下载预训练模型数据提供便利。
在from_pretrained()中加入参数mirror='tuna’即可加速下载,绝了。
【!!!⭐不过清华源上可能有些模型没有,且可能没hugging face官网更新的及时。所以有时还是得手动下载】
model = BertModel.from_pretrained('bert-base-uncased', mirror='tuna')