tensor2tensor项目中机器翻译中的bug

在利用transformer模型训练中英互译模型时,自己实现了一个translate的problem,重新实现了generate_encoded_sample,并提供自己预处理后的vocab.en和vocab.zh。

tensor2tensor提供了一个非常好的功能,只要在translate后面加上_rev,就能实现源语言和目标语言反转。也就是只要实现英文-->中文的训练过程,datagen只执行一次,中文-->英文的模型训练时基本不用动,只需要在对应的problem后面加_rev,然后训练就可以了,非常方便。

但是代码中有个bug。加上rev之后,训练没有问题。导出后,运行TensorFlow_model_server,运行query.py之后会报错:requested more than 0 entries, but params is empty。也就是这个issue中提到的问题,有人提出了方法解决该问题。

我尝试着彻底修复这个bug,在翻了几遍translate相关的代码之后,发现在构建example和执行request_fn过程中都没有问题。但是通过grpc访问模型时,就出错了。原因在提供serving服务时,重新构建一个metagraph,而该模型的输入就是problem.py中函数serving_input_fn,它其实是个placeholder,在最终通过模型进行predict之前,会进行一系列的处理,其中一条就是 dataset = dataset.map(self.maybe_reverse_and_copy),这个处理过程是不需要的,因为translate_rev中已经把inputs和targets相关的参数、feature、vocabulary等等已经对调过了。这里多这一步相当于重复了。所以在上面提到的issue里有人给出了一个trick来绕过这个bug,他的方法是可行的。原因就在这里。

正确的解决方法就是删掉problem.py中,serving_input_fn函数中的dataset.map(self.maybe_reverse_and_copy)一行,就OK了。

你可能感兴趣的:(自然语言处理)