import collections
import math
import torch
from torch import nn
from d2l import torch as d2l
class Seq2SeqEncoder(d2l.Encoder):
    """The RNN encoder for sequence-to-sequence learning."""
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqEncoder, self).__init__(**kwargs)
        # The embedding layer maps token indices to dense vectors
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)

    def forward(self, x, *args):
        # x shape: (batch_size, num_steps) -> (batch_size, num_steps, embed_size)
        x = self.embedding(x)
        # nn.GRU expects the time axis first: (num_steps, batch_size, embed_size)
        x = x.permute(1, 0, 2)
        # output shape: (num_steps, batch_size, num_hiddens)
        # state shape:  (num_layers, batch_size, num_hiddens)
        output, state = self.rnn(x)
        return output, state
# Sanity-check the encoder on a toy minibatch of 4 sequences with 7 time steps
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()
x = torch.zeros((4, 7), dtype=torch.long)
output, state = encoder(x)
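# Illustrative shape check (the exact sizes follow from the toy hyperparameters
# above): the output keeps every time step, while the state holds one hidden
# vector per GRU layer.
print(output.shape)  # torch.Size([7, 4, 16]) -> (num_steps, batch_size, num_hiddens)
print(state.shape)   # torch.Size([2, 4, 16]) -> (num_layers, batch_size, num_hiddens)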
class Seq2SeqDecoder(d2l.Decoder):
    """The RNN decoder for sequence-to-sequence learning."""
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqDecoder, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Each decoder input step is the token embedding concatenated with the
        # context vector, hence the input size embed_size + num_hiddens
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers,
                          dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        # Initialize the decoder state with the encoder's final hidden state
        return enc_outputs[1]

    def forward(self, x, state, *args):
        # x shape after embedding and permute: (num_steps, batch_size, embed_size)
        x = self.embedding(x)
        x = x.permute(1, 0, 2)
        # Broadcast the top-layer encoder state so every decoding step sees it
        context = state[-1].repeat(x.shape[0], 1, 1)
        x_and_context = torch.cat((x, context), 2)
        output, state = self.rnn(x_and_context, state)
        # output shape: (batch_size, num_steps, vocab_size)
        output = self.dense(output).permute(1, 0, 2)
        return output, state
# Sanity-check the decoder, using the encoder's state as its initial state
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()
state = decoder.init_state(encoder(x))
output, state = decoder(x, state)
print(output.shape)  # torch.Size([4, 7, 10]) -> (batch_size, num_steps, vocab_size)
print(state.shape)   # torch.Size([2, 4, 16]) -> (num_layers, batch_size, num_hiddens)
def sequence_mask(x, valid_len, value=0):
    """Mask entries in x that lie beyond each sequence's valid length."""
    maxlen = x.size(1)
    # mask[i, j] is True when position j is within the valid length of row i
    mask = torch.arange((maxlen), dtype=torch.float32,
                        device=x.device)[None, :] < valid_len[:, None]
    x[~mask] = value
    return x
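# A minimal usage sketch of sequence_mask (the tensor values here are just an
# illustration): entries beyond each row's valid length are replaced with 0.
x_demo = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(sequence_mask(x_demo, torch.tensor([1, 2])))
# expected: tensor([[1, 0, 0],
#                   [4, 5, 0]])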