transformers BertModel

API

The bare BERT model, which outputs raw hidden-states without any task-specific head on top.
It is a PyTorch torch.nn.Module subclass.

class transformers.BertModel(config)
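Parameter Description
config (BertConfig) Model configuration object. Initializing the model with a config only sets up the architecture; it does not load any pretrained weights.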
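The following is a minimal sketch of what "initializing with a config does not load weights" means in practice. It assumes torch and transformers are installed. The tiny BertConfig values are hypothetical and were chosen only so the snippet runs quickly and offline. The resulting model has randomly initialized weights. To load a trained checkpoint you would use BertModel.from_pretrained(...) instead.

```python
import torch
from transformers import BertConfig, BertModel

# Hypothetical, deliberately tiny configuration (bert-base-uncased uses
# hidden_size=768, 12 layers and 12 attention heads)
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)

# Building from a config creates the architecture with random weights;
# nothing is downloaded or loaded from a checkpoint
model = BertModel(config)
model.eval()  # disable dropout so later forward passes are deterministic

# For pretrained weights you would instead write (downloads the checkpoint):
# model = BertModel.from_pretrained("bert-base-uncased")

n_params = sum(p.numel() for p in model.parameters())
print(isinstance(model, torch.nn.Module), n_params)
```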
forward(input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None)

The BertModel forward method overrides the __call__() special method.
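In practice you call the model instance, as in model(...), rather than model.forward(...). PyTorch's nn.Module.__call__ runs any registered hooks around forward(). The sketch below assumes torch and transformers are installed and again uses a hypothetical tiny config with random token ids. It shows that a forward hook fires only on the __call__ path. It also shows the shapes of the raw hidden states (last_hidden_state) and of the pooled [CLS] vector (pooler_output).

```python
import torch
from transformers import BertConfig, BertModel

torch.manual_seed(0)
# Tiny, hypothetical config so the sketch runs offline (no checkpoint download)
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertModel(config).eval()  # eval() disables dropout and returns the model

# Record every time a forward hook on the top-level module fires
calls = []
model.register_forward_hook(lambda module, args, output: calls.append(type(module).__name__))

# A batch of 2 sequences with 5 random token ids each: (batch_size, sequence_length)
input_ids = torch.randint(0, config.vocab_size, (2, 5))

with torch.no_grad():
    outputs = model(input_ids=input_ids, return_dict=True)  # __call__: hooks run, then forward()
    model.forward(input_ids=input_ids, return_dict=True)    # direct forward(): hooks are skipped

print(len(calls))                       # 1
print(outputs.last_hidden_state.shape)  # torch.Size([2, 5, 32])
print(outputs.pooler_output.shape)      # torch.Size([2, 32])
```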

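Parameter Description
input_ids (torch.LongTensor of shape (batch_size, sequence_length)) Indices of the input sequence tokens in the vocabulary, typically produced by BertTokenizer.
attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional, defaults to None) Mask that prevents attention to padding tokens: 1 for tokens to attend to, 0 for masked (padding) tokens.
token_type_ids (torch.LongTensor of shape (batch_size, sequence_length), optional, defaults to None) Segment indices for sentence-pair inputs: 0 for tokens of sentence A, 1 for tokens of sentence B.
position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional, defaults to None) Index of each token's position in the position embeddings, in the range [0, config.max_position_embeddings - 1]. When None, positions 0, 1, 2, ... are used.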
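The sketch below illustrates what attention_mask and token_type_ids do. It assumes torch and transformers are installed and uses a hypothetical tiny config with made-up token ids. The padding id is 0, which matches BertConfig's default pad_token_id. The same three tokens are run once on their own and once padded to length 5 with a mask. Because the mask hides the padding, the three real tokens get the same hidden states in both runs.

```python
import torch
from transformers import BertConfig, BertModel

torch.manual_seed(0)
# Tiny, hypothetical config so the sketch runs offline
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertModel(config).eval()

# Hypothetical token ids; 0 is the padding id
short = torch.tensor([[5, 17, 42]])
padded = torch.tensor([[5, 17, 42, 0, 0]])
# 1 = real token (attend to it), 0 = padding (masked out)
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])
# Single-sentence input: every token belongs to segment A (id 0)
token_type_ids = torch.zeros_like(padded)

with torch.no_grad():
    out_short = model(input_ids=short, return_dict=True).last_hidden_state
    out_padded = model(input_ids=padded, attention_mask=attention_mask,
                       token_type_ids=token_type_ids,
                       return_dict=True).last_hidden_state

# With the mask applied, padding does not change the real tokens' hidden states
same = torch.allclose(out_short, out_padded[:, :3], atol=1e-5)
print(same)
```

If you drop attention_mask from the padded call, the real tokens also attend to the padding positions and their hidden states generally differ.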

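Returns

Note: the output fields below, and the example after them, belong to BertForPreTraining, which is BertModel with the masked-language-modeling and next-sentence-prediction heads on top. BertModel itself returns last_hidden_state of shape (batch_size, sequence_length, hidden_size) and pooler_output of shape (batch_size, hidden_size), plus the optional hidden_states and attentions described below.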

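Field Description
loss (optional, returned when labels and next_sentence_label are provided, torch.FloatTensor of shape (1,)) Total loss: the sum of the masked-language-modeling loss and the next-sentence-prediction loss.
prediction_logits (torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)) Prediction scores of the language-modeling head (a score for each vocabulary token, before softmax).
seq_relationship_logits (torch.FloatTensor of shape (batch_size, 2)) Prediction scores of the next-sentence-prediction head ("is next" vs. "is random", before softmax).
hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) One tensor for the embedding output plus one per layer, each of shape (batch_size, sequence_length, hidden_size).
attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) One tensor per layer of shape (batch_size, num_heads, sequence_length, sequence_length), holding the attention weights after softmax.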
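Example:
from transformers import BertTokenizer, BertForPreTraining
import torch

# Load the tokenizer and BERT with both pretraining heads (MLM + NSP)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased', return_dict=True)

# Tokenize one sentence into PyTorch tensors (input_ids, token_type_ids, attention_mask)
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)

# MLM scores: (batch_size, sequence_length, vocab_size)
prediction_logits = outputs.prediction_logits
# NSP scores: (batch_size, 2)
seq_relationship_logits = outputs.seq_relationship_logits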
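The BertForPreTraining return fields can also be checked offline. This sketch assumes torch and transformers are installed and replaces the bert-base-uncased checkpoint with a hypothetical tiny, randomly initialized config. The input ids are random. For the MLM labels it simply reuses the input ids. It passes next_sentence_label=0, which means "sentence B follows sentence A". Both labels are supplied so that loss is returned, and output_hidden_states=True is set so that hidden_states is populated.

```python
import torch
from transformers import BertConfig, BertForPreTraining

torch.manual_seed(0)
# Tiny, hypothetical config standing in for a real checkpoint
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertForPreTraining(config).eval()

input_ids = torch.randint(1, config.vocab_size, (1, 6))  # hypothetical token ids
with torch.no_grad():
    outputs = model(
        input_ids=input_ids,
        labels=input_ids,                       # MLM targets (here: reconstruct the input)
        next_sentence_label=torch.tensor([0]),  # 0 = sentence B continues sentence A
        output_hidden_states=True,
        return_dict=True,
    )

print(outputs.loss.item())                    # a single number (MLM loss + NSP loss)
print(outputs.prediction_logits.shape)        # torch.Size([1, 6, 100])
print(outputs.seq_relationship_logits.shape)  # torch.Size([1, 2])
print(len(outputs.hidden_states))             # 3 = embedding output + 2 layers
```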
