Understanding torch.nn.GRU()

References

  • Code examples

With a single sequence:

>>> import torch
>>> import torch.nn as nn
>>> gru = nn.GRU(input_size=50, hidden_size=50, batch_first=True)
>>> embed = nn.Embedding(3, 50)
>>> x = torch.LongTensor([[0, 1, 2]])
>>> x_embed = embed(x)
>>> x.size()
torch.Size([1, 3])
>>> x_embed.size()
torch.Size([1, 3, 50])
>>> out, hidden = gru(x_embed)
>>> out.size()
torch.Size([1, 3, 50])
>>> hidden.size()
torch.Size([1, 1, 50])
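
The shapes above hint at how `out` and `hidden` relate: for a single-layer, unidirectional GRU, `hidden` should be the output at the last time step. Below is a minimal sketch to check this, reusing the sizes from the example; the seed and the names `last_step` and `same` are illustrative only and not part of the original.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only so repeated runs match

gru = nn.GRU(input_size=50, hidden_size=50, batch_first=True)
embed = nn.Embedding(3, 50)

x = torch.LongTensor([[0, 1, 2]])   # (batch=1, seq_len=3)
out, hidden = gru(embed(x))         # out: (1, 3, 50), hidden: (1, 1, 50)

# With batch_first=True, out[:, -1, :] is the last time step of each sequence.
last_step = out[:, -1, :]           # (1, 50)

# hidden[0] is the final state of the only layer, shape (1, 50).
same = torch.allclose(last_step, hidden[0])
print(same)
```

This only holds in the single-layer, unidirectional case; with more layers or `bidirectional=True`, `hidden` stacks several final states and no longer matches a single slice of `out`.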

With two sequences (batch size 2):

>>> x = torch.LongTensor([[0, 1, 2], [0, 1, 2]])
>>> x_embed = embed(x)
>>> x_embed.size()
torch.Size([2, 3, 50])
>>> out, hidden = gru(x_embed)
>>> out.size()
torch.Size([2, 3, 50])
>>> hidden.size()
torch.Size([1, 2, 50])
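
Note that `hidden` keeps the batch in its second dimension even though `batch_first=True`: that flag affects only the input and `out`. The first dimension of `hidden` is `num_layers * num_directions`, which is 1 in the examples above. The sketch below makes this visible by assuming `num_layers=2` and `bidirectional=True`; these settings and the random input are chosen purely for illustration and do not appear in the original.

```python
import torch
import torch.nn as nn

batch, seq_len, feat, hid = 2, 3, 50, 50

# Random features stand in for the embedded tokens used above.
x = torch.randn(batch, seq_len, feat)

gru = nn.GRU(
    input_size=feat,
    hidden_size=hid,
    num_layers=2,        # assumption for illustration
    bidirectional=True,  # assumption for illustration
    batch_first=True,
)
out, hidden = gru(x)

# out concatenates both directions: (batch, seq_len, 2 * hid)
out_shape = tuple(out.shape)
# hidden stacks layers and directions: (num_layers * 2, batch, hid)
hidden_shape = tuple(hidden.shape)
print(out_shape, hidden_shape)
```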

Passing in the previous hidden state:

>>> x = torch.LongTensor([[0, 1, 2], [0, 1, 2]])
>>> x_embed = embed(x)
>>> out1, hidden = gru(x_embed)
>>> out1.size()
torch.Size([2, 3, 50])
>>> hidden.size()
torch.Size([1, 2, 50])
>>> out2, hidden = gru(x_embed, hidden)
>>> out1.size()
torch.Size([2, 3, 50])
>>> out2.size()
torch.Size([2, 3, 50])
>>> hidden.size()
torch.Size([1, 2, 50])
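
One common reason to pass `hidden` back in is to process a long sequence in chunks. As a rough sketch, assuming a sequence of length 6 split into two halves (the length, the split point, and the random input are illustrative choices, not from the original), carrying the hidden state across chunks should reproduce the result of a single pass up to floating-point tolerance:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only so repeated runs match

gru = nn.GRU(input_size=50, hidden_size=50, batch_first=True)
x = torch.randn(2, 6, 50)  # (batch=2, seq_len=6, features=50)

with torch.no_grad():
    # One pass over the whole sequence.
    full_out, full_hidden = gru(x)

    # Two passes over the halves, carrying the hidden state across.
    out_a, h = gru(x[:, :3, :].contiguous())
    out_b, h = gru(x[:, 3:, :].contiguous(), h)

chunked_out = torch.cat([out_a, out_b], dim=1)  # (2, 6, 50)

outputs_match = torch.allclose(full_out, chunked_out, atol=1e-5)
hidden_match = torch.allclose(full_hidden, h, atol=1e-5)
print(outputs_match, hidden_match)
```

If the second call omitted `h`, the GRU would start from a zero state and the second half of the output would generally differ.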
