import torch
from torch import nn
from typing import Optional


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with the TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # When token_type_ids is not passed, fall back to the buffer registered in the constructor, which is all
        # zeros. This usually occurs when token_type_ids are auto-generated; the registered buffer helps users when
        # tracing the model without passing token_type_ids and solves issue #5664.
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


# Configuration class
class BertConfig:
    def __init__(self):
        self.vocab_size = 30522               # vocabulary size of the base BERT model
        self.hidden_size = 768                # hidden dimension
        self.pad_token_id = 0                 # ID of the padding token
        self.max_position_embeddings = 512    # maximum position-embedding length
        self.type_vocab_size = 2              # number of token types (usually 2: sentence A and sentence B)
        self.layer_norm_eps = 1e-12           # epsilon for LayerNorm
        self.hidden_dropout_prob = 0.1        # dropout probability


# Create a configuration instance
config = BertConfig()

# Instantiate BertEmbeddings
embeddings = BertEmbeddings(config)

# Example 1: basic input (using input_ids)
input_ids = torch.tensor([
    [101, 2054, 2003, 102],  # [CLS] <token> <token> [SEP]
    [101, 2023, 4248, 102],  # [CLS] <token> <token> [SEP]
])  # shape (batch_size=2, seq_length=4)

# Forward pass
output = embeddings(
    input_ids=input_ids,
    token_type_ids=None,       # auto-generated as all zeros
    position_ids=None,         # taken automatically from the position_ids buffer
    inputs_embeds=None,        # embeddings are looked up from input_ids
    past_key_values_length=0,  # no previously cached tokens
)
print(f"输出形状: {output.shape}") # 应为 torch.Size([2, 4, 768]) # 示例2:使用预计算的inputs_embeds inputs_embeds = torch.rand(2, 4, config.hidden_size) # 随机初始化嵌入 output = embeddings( input_ids=None, # 使用inputs_embeds inputs_embeds=inputs_embeds ) print(f"输出形状: {output.shape}") # 应为 torch.Size([2, 4, 768]) # 示例3:自定义token_type_ids(句子对任务) token_type_ids = torch.tensor([ [0, 0, 0, 1], # 前3个token属于句子A,最后1个属于句子B [0, 0, 1, 1] # 前2个token属于句子A,后2个属于句子B ]) output = embeddings( input_ids=input_ids, token_type_ids=token_type_ids ) # 示例4:生成任务中使用past_key_values_length # 假设已生成3个token,当前输入长度为1 output = embeddings( input_ids=torch.tensor([[2054]]), # 当前token: "Hello" past_key_values_length=3 # 已生成3个token ) print(f"输出形状: {output.shape}") # 应为 torch.Size([1, 1, 768])