seq2seq推理模块设计

代码:

# load checkpoints,如何上线
model = Sequence2Sequence(len(src_word2idx), len(trg_word2idx))
model.load_state_dict(torch.load(f"./best.ckpt", weights_only=True,map_location="cpu"))

class Translator:
    def __init__(self, model, src_tokenizer, trg_tokenizer):
        self.model = model
        self.model.eval() # 切换到验证模式
        self.src_tokenizer = src_tokenizer
        self.trg_tokenizer = trg_tokenizer

    def draw_attention_map(self, scores, src_words_list, trg_words_list):
        """绘制注意力热力图

        Args:
            - scores (numpy.ndarray): shape = [source sequence length, target sequence length]
        """
        plt.matshow(scores.T, cmap='viridis') # 注意力矩阵,显示注意力分数值
        # 获取当前的轴
        ax = plt.gca()

        # 设置热图中每个单元格的分数的文本
        for i in range(scores.shape[0]): #输入
            for j in range(scores.shape[1]): #输出
                ax.text(j, i, f'{scores[i, j]:.2f}',  # 格式化数字显示
                               ha='center', va='center', color='k')

        plt.xticks(range(scores.shape[0]), src_words_list)
        plt.yticks(range(scores.shape[1]), trg_words_list)
        plt.show()

    def __call__(self, sentence):
        sentence = preprocess_sentence(sentence) # 预处理句子,标点符号处理等
        encoder_input, attn_mask = self.src_tokenizer.encode(
            [sentence.split()],
            padding_first=True,
            add_bos=True,
            add_eos=True,
            return_mask=True,
            ) # 对输入进行编码,并返回encode_piadding_mask
        encoder_input = torch.Tensor(encoder_input).to(dtype=torch.int64) # 转换成tensor

        preds, scores = model.infer(encoder_input=encoder_input, attn_mask=attn_mask) #预测

        trg_sentence = self.trg_tokenizer.decode([preds], split=True, remove_eos=False)[0] #通过tokenizer转换成文字

        src_decoded = self.src_tokenizer.decode(
            encoder_input.tolist(),
            split=True,
            remove_bos=False,
            remove_eos=False
            )[0] #对输入编码id进行解码,转换成文字,为了画图

        self.draw_attention_map(
            scores.squeeze(0).numpy(),
            src_decoded, # 注意力图的源句子
            trg_sentence # 注意力图的目标句子
            )
        return " ".join(trg_sentence[:-1])

这段代码实现了一个基于序列到序列(Sequence-to-Sequence)模型的翻译器

首先加载模型,从文件best.ckpt加载预训练模型的参数 

  • weights_only=True 确保只加载模型权重,避免潜在的安全风险。
  • map_location="cpu" 强制将模型加载到 CPU(即使训练时用 GPU)。没有GPU的可以使用这个

然后是Translator类,初始化翻译器,把模式变成eval模式,绑定源语言和目标语言的分词器。

这里先解释call方法,首先预处理句子,然后编码输入,把文本转化为tokenID,左端填充Padding,添加BOS和EOS标记。

然后执行推理生成语言的TokenID(preds)和注意力分数(scores)。

然后是解码输出把TokenID变为文本,最后可视化注意力。


 

你可能感兴趣的:(PyTorch,深度学习,人工智能)