import torch
from torch import nn
import torch.nn.functional as F
import math
class AttentionHead(nn.Module):
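    """A single attention head: linear projections of query/key/value down to
    ``head_dim`` followed by scaled dot-product attention with dropout on the weights."""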
def __init__(self, embed_dim, head_dim):
super().__init__()
self.q = nn.Linear(embed_dim, head_dim)
self.k = nn.Linear(embed_dim, head_dim)
self.v = nn.Linear(embed_dim, head_dim)
self.dropout = nn.Dropout(0.2)
def forward(self, query, key, value, mask=None):
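        # query/key/value: (batch, seq_len, embed_dim); mask is either a (batch, key_len)
        # padding mask or a full (batch, query_len, key_len) attention mask.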
query, key, value = self.q(query), self.k(key), self.v(value)
scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(query.size(-1))
        if mask is not None:
            if mask.dim() == 2:
                # Expand a (batch, key_len) padding mask to (batch, query_len, key_len).
                mask = mask.unsqueeze(1).repeat(1, scores.size(1), 1)
            assert scores.size() == mask.size()
            scores = scores.masked_fill(mask == 0, -float("inf"))
weights = self.dropout(F.softmax(scores, dim=-1))
return torch.bmm(weights, value)
class SequentialMultiHeadAttention(nn.Module):
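    """Multi-head attention that runs its heads one after another in a Python loop,
    concatenates their outputs, and applies an output projection with dropout."""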
def __init__(self, hidden_size, num_heads):
super().__init__()
head_dim = hidden_size // num_heads
self.heads = nn.ModuleList(
[AttentionHead(hidden_size, head_dim) for _ in range(num_heads)]
)
self.output_linear = nn.Linear(hidden_size, hidden_size)
self.dropout = nn.Dropout(0.2)
def forward(self, query, key, value, attn_mask=None, query_mask=None, key_mask=None):
        if query_mask is not None and key_mask is not None:
            # Outer product of the two padding masks gives a (batch, query_len, key_len)
            # mask; cast to float so the batched matmul also accepts integer masks.
            attn_mask = torch.bmm(query_mask.unsqueeze(-1).float(), key_mask.unsqueeze(1).float())
x = torch.cat([h(query, key, value, attn_mask) for h in self.heads], dim=-1)
x = self.dropout(self.output_linear(x))
return x
class ScaledDotProductAttention(nn.Module):
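    """Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V, returning both the
    context vectors and the attention weights."""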
    def __init__(self, d_k, dropout=0.1):
        super().__init__()
self.scale_factor = math.sqrt(d_k)
self.dropout = nn.Dropout(dropout)
def forward(self, q, k, v, attn_mask=None):
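        # q, k, v: (batch, n_heads, seq_len, d_k or d_v); scores: (batch, n_heads, q_len, k_len).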
scores = torch.matmul(q, k.transpose(-1, -2)) / self.scale_factor
if attn_mask is not None:
assert attn_mask.size() == scores.size()
scores.masked_fill_(attn_mask == 0, -1e9)
attn = self.dropout(F.softmax(scores, dim=-1))
context = torch.matmul(attn, v)
return context, attn
class _MultiHeadAttention(nn.Module):
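    """Batched multi-head attention: a single linear projection per q/k/v produces all
    heads at once, reshaped to (batch, n_heads, seq_len, d_k) before attention."""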
def __init__(self, d_k, d_v, d_model, n_heads, dropout):
        super().__init__()
self.d_k = d_k
self.d_v = d_v
self.d_model = d_model
self.n_heads = n_heads
self.w_q = nn.Linear(d_model, d_k * n_heads)
self.w_k = nn.Linear(d_model, d_k * n_heads)
self.w_v = nn.Linear(d_model, d_v * n_heads)
self.attention = ScaledDotProductAttention(d_k, dropout)
def forward(self, q, k, v, attn_mask):
b_size = q.size(0)
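        # Project q/k/v and split into heads: (batch, n_heads, seq_len, d_k or d_v).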
q_s = self.w_q(q).view(b_size, -1, self.n_heads, self.d_k).transpose(1, 2)
k_s = self.w_k(k).view(b_size, -1, self.n_heads, self.d_k).transpose(1, 2)
v_s = self.w_v(v).view(b_size, -1, self.n_heads, self.d_v).transpose(1, 2)
        if attn_mask is not None:
            # Broadcast the (batch, seq_len) padding mask to (batch, n_heads, seq_len, seq_len).
            attn_mask = attn_mask[:, None, None, :].repeat(1, self.n_heads, attn_mask.size(1), 1)
context, attn = self.attention(q_s, k_s, v_s, attn_mask=attn_mask)
context = context.transpose(1, 2).contiguous().view(b_size, -1, self.n_heads * self.d_v)
return context, attn
class MultiHeadAttention(nn.Module):
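    """Multi-head attention with a final output projection and dropout."""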
def __init__(self, d_k, d_v, d_model, n_heads, dropout):
        super().__init__()
self.n_heads = n_heads
self.multihead_attn = _MultiHeadAttention(d_k, d_v, d_model, n_heads, dropout)
self.proj = nn.Linear(n_heads * d_v, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, q, k, v, attn_mask):
context, attn = self.multihead_attn(q, k, v, attn_mask=attn_mask)
output = self.dropout(self.proj(context))
return output
class FeedForward(nn.Module):
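    """Position-wise feed-forward network with a GELU activation and dropout."""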
def __init__(self, hidden_size, intermediate_size, dropout):
super().__init__()
self.linear_1 = nn.Linear(hidden_size, intermediate_size)
self.linear_2 = nn.Linear(intermediate_size, hidden_size)
self.gelu = nn.GELU()
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.linear_1(x)
x = self.gelu(x)
x = self.linear_2(x)
x = self.dropout(x)
return x
class TransformerEncoderLayer(nn.Module):
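    """One encoder block: multi-head self-attention followed by a position-wise
    feed-forward network, each with a residual connection and LayerNorm."""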
def __init__(self, config):
super().__init__()
self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
d_k = config.hidden_size // config.num_heads
self.attention = MultiHeadAttention(
d_k=d_k,
d_v=d_k,
d_model=config.hidden_size,
n_heads=config.num_heads,
dropout=config.dropout
)
self.feed_forward = FeedForward(
config.hidden_size,
config.intermediate_size,
config.dropout
)
def forward(self, x, mask=None):
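        # Post-layer-norm ordering: the residual is added before each LayerNorm.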
x = x + self.attention(x, x, x, attn_mask=mask)
x = self.layer_norm_1(x)
x = x + self.feed_forward(x)
x = self.layer_norm_2(x)
return x
class Embeddings(nn.Module):
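    """Token embeddings plus learned position embeddings, followed by LayerNorm and dropout."""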
def __init__(self, config):
super().__init__()
self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(0.2)
def forward(self, input_ids):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device).unsqueeze(0)
token_embeddings = self.token_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = token_embeddings + position_embeddings
embeddings = self.layer_norm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class TransformerEncoder(nn.Module):
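    """A stack of ``config.num_hidden_layers`` encoder layers."""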
def __init__(self, config):
super().__init__()
self.layers = nn.ModuleList(
[TransformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]
)
def forward(self, x, mask=None):
for layer in self.layers:
x = layer(x, mask)
return x
class TransModel(nn.Module):
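    """Sentence classifier built on the encoder above: attention-pools the token states
    and, when ``config.use_know`` is set, also encodes and pools external knowledge
    sequences before the MLP classifier head."""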
def __init__(self, config):
        super().__init__()
self.encoder = TransformerEncoder(config)
self.embedding = Embeddings(config)
self.tanh1 = nn.Tanh()
self.w1 = nn.Parameter(torch.randn(config.hidden_size))
self.tanh2 = nn.Tanh()
self.w2 = nn.Parameter(torch.randn(config.hidden_size))
if config.use_know:
self.classifier = nn.Sequential(
nn.Linear(config.hidden_size * 2, config.hidden_size2),
nn.ReLU(),
nn.Linear(config.hidden_size2, config.num_classes)
)
else:
self.classifier = nn.Sequential(
nn.Linear(config.hidden_size , config.hidden_size2),
nn.ReLU(),
nn.Linear(config.hidden_size2, config.num_classes)
)
self.loss_fn = nn.CrossEntropyLoss()
self.config = config
def get_features(self, input_ids, masks):
        # Knowledge inputs may arrive as (batch, num_know, seq_len); flatten them so the
        # encoder sees ordinary 2-D batches, and keep the masks in step.
        if input_ids.dim() == 3:
            input_ids = input_ids.view(-1, input_ids.size(-1))
        if masks is not None and masks.dim() == 3:
            masks = masks.view(-1, masks.size(-1))
        emb = self.embedding(input_ids)
        hidden_vec = self.encoder(emb, masks)
        hidden_vec = self.tanh1(hidden_vec)
        # Attention pooling: score each token state against w1 and take the weighted sum.
        alpha = F.softmax(torch.matmul(hidden_vec, self.w1), dim=-1).unsqueeze(-1)
        out = hidden_vec * alpha
        out = torch.sum(out, dim=-2)
        return out
def forward(self, input_ids, masks, know_input_ids, know_masks, labels=None):
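        # input_ids/masks: (batch, seq_len); know_input_ids/know_masks may stack several
        # knowledge sequences per example as (batch, num_know, seq_len).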
sent_feature = self.get_features(input_ids, masks)
if self.config.use_know:
know_feature = self.get_features(know_input_ids, know_masks)
know_feature = know_feature.view(input_ids.size(0), -1, know_feature.size(-1))
alpha = F.softmax(self.tanh2(torch.matmul(know_feature, self.w2)), dim=-1).unsqueeze(dim=-1)
know_feature = (know_feature * alpha).sum(dim=-2)
out = self.classifier(torch.cat([sent_feature, know_feature], dim=-1))
else:
out = self.classifier(sent_feature)
if labels is not None:
loss = self.loss_fn(out, labels)
return_tuple = (loss, out)
else:
return_tuple = (out, )
return return_tuple
if __name__ == '__main__':
    from transformers import AutoConfig, AutoTokenizer

    model_ckpt = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
    config = AutoConfig.from_pretrained(model_ckpt)
    # The BERT config uses different attribute names than the modules above,
    # so alias them before building the encoder.
    config.num_heads = config.num_attention_heads
    config.dropout = config.hidden_dropout_prob

    text = "time flies like an arrow"
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)

    # The encoder expects hidden states, so run the token ids through Embeddings first.
    embeddings = Embeddings(config)
    encoder = TransformerEncoder(config)
    hidden_states = encoder(embeddings(inputs.input_ids))
    print(hidden_states.size())  # (batch_size, seq_len, hidden_size)
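
    # A minimal, hypothetical sketch of driving TransModel directly. hidden_size2,
    # num_classes, and use_know are not part of the BERT config; the values below
    # are assumed purely for illustration.
    config.hidden_size2 = 256
    config.num_classes = 2
    config.use_know = False
    model = TransModel(config)
    logits, = model(inputs.input_ids, inputs.attention_mask, None, None)
    print(logits.size())  # (batch_size, num_classes)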