Encoder-Decoer模型共享embedding矩阵,embedding矩阵的参数更新问题

最近做生成式问答,尝试用bert做encoder,transformer-decoder做decoder框架来做。遇到一个问题,就是我想让decoder共享bert的embedding矩阵,但是由于设置了decoder和encoder学习速率不同,因此,我不知道embedding矩阵参数如何更新?会不会收到decoder端的影响,于是做了下面的实验。

import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, ):
        super(Encoder, self).__init__()
        self.embeddings = nn.Embedding(100, 50)
        self.fc = nn.Linear(50, 1)

    def forward(self, input):

        feature = self.embeddings(input)
        feature = self.fc(feature)

        return feature


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.embeddings = None
        self.fc = nn.Linear(50, 1)

    def forward(self, input):
        feature = self.embeddings(input)
        feature = self.fc(feature)

        return feature


class myModel(nn.Module):
    def __init__(self):
        super(myModel, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

        self.decoder.embeddings = self.encoder.embeddings

    def forward(self, enc_input, dec_input):
        enc_ = self.encoder(enc_input)
        dec_ = self.decoder(dec_input)

        return enc_.sum() + dec_.sum()


model = myModel()

enc_param = []
dec_param = []
for n,p in list(model.named_parameters()):
    if n.split('.')[0] == 'encoder':
        enc_param.append((n, p))
    else:
        dec_param.append((n, p))

optimizer_grouped_parameters = [
            # bert other module
            {"params": [p for n, p in enc_param],
             'lr': 0.01},
            {"params": [p for n, p in dec_param],
             'lr': 0.001},
        ]


optim = torch.optim.SGD(optimizer_grouped_parameters)

enc_input = torch.arange(0, 10).unsqueeze(0)
dec_input = torch.arange(5, 15).unsqueeze(0)

loss = model(enc_input, dec_input)

optim.zero_grad()
loss.backward()
optim.step()


print(id(model.encoder.embeddings))
print(id(model.decoder.embeddings))

print([n for (n, p) in dec_param])
print([n for (n, p) in enc_param])

'''输出
140206391178048
140206391178048
['decoder.fc.weight', 'decoder.fc.bias']
['encoder.embeddings.weight', 'encoder.fc.weight', 'encoder.fc.bias']

'''


根据打印结果,发现embedding只在encoder的参数组中,而且decoder的embedding与encoder的embedding在内存中地址一样,说明是共享的,所以我的担心是多虑的。

你可能感兴趣的:(pytorch,pytorch,自然语言处理,python)