代码:
# load checkpoints,如何上线
model = Sequence2Sequence(len(src_word2idx), len(trg_word2idx))
model.load_state_dict(torch.load(f"./best.ckpt", weights_only=True,map_location="cpu"))
class Translator:
def __init__(self, model, src_tokenizer, trg_tokenizer):
self.model = model
self.model.eval() # 切换到验证模式
self.src_tokenizer = src_tokenizer
self.trg_tokenizer = trg_tokenizer
def draw_attention_map(self, scores, src_words_list, trg_words_list):
"""绘制注意力热力图
Args:
- scores (numpy.ndarray): shape = [source sequence length, target sequence length]
"""
plt.matshow(scores.T, cmap='viridis') # 注意力矩阵,显示注意力分数值
# 获取当前的轴
ax = plt.gca()
# 设置热图中每个单元格的分数的文本
for i in range(scores.shape[0]): #输入
for j in range(scores.shape[1]): #输出
ax.text(j, i, f'{scores[i, j]:.2f}', # 格式化数字显示
ha='center', va='center', color='k')
plt.xticks(range(scores.shape[0]), src_words_list)
plt.yticks(range(scores.shape[1]), trg_words_list)
plt.show()
def __call__(self, sentence):
sentence = preprocess_sentence(sentence) # 预处理句子,标点符号处理等
encoder_input, attn_mask = self.src_tokenizer.encode(
[sentence.split()],
padding_first=True,
add_bos=True,
add_eos=True,
return_mask=True,
) # 对输入进行编码,并返回encode_piadding_mask
encoder_input = torch.Tensor(encoder_input).to(dtype=torch.int64) # 转换成tensor
preds, scores = model.infer(encoder_input=encoder_input, attn_mask=attn_mask) #预测
trg_sentence = self.trg_tokenizer.decode([preds], split=True, remove_eos=False)[0] #通过tokenizer转换成文字
src_decoded = self.src_tokenizer.decode(
encoder_input.tolist(),
split=True,
remove_bos=False,
remove_eos=False
)[0] #对输入编码id进行解码,转换成文字,为了画图
self.draw_attention_map(
scores.squeeze(0).numpy(),
src_decoded, # 注意力图的源句子
trg_sentence # 注意力图的目标句子
)
return " ".join(trg_sentence[:-1])
这段代码实现了一个基于序列到序列(Sequence-to-Sequence)模型的翻译器
首先加载模型,从文件best.ckpt加载预训练模型的参数
weights_only=True
确保只加载模型权重,避免潜在的安全风险。map_location="cpu"
强制将模型加载到 CPU(即使训练时用 GPU)。没有GPU的可以使用这个然后是Translator类,初始化翻译器,把模式变成eval模式,绑定源语言和目标语言的分词器。
这里先解释call方法,首先预处理句子,然后编码输入,把文本转化为tokenID,左端填充Padding,添加BOS和EOS标记。
然后执行推理生成语言的TokenID(preds)和注意力分数(scores)。
然后是解码输出把TokenID变为文本,最后可视化注意力。