[转载]PyTorch代码调试利器: 自动print每行代码的Tensor信息

[转载]PyTorch代码调试利器: 自动print每行代码的Tensor信息
https://cloud.tencent.com/developer/article/1449507

GitHub 项目地址: https://github.com/zasdfgbnm/TorchSnooper

大家可能遇到这样子的困扰:比如说运行自己编写的 PyTorch 代码的时候,PyTorch 提示你说数据类型不匹配,需要一个 double 的 tensor 但是你给的却是 float;再或者就是需要一个 CUDA tensor, 你给的却是个 CPU tensor。

这种问题调试起来很麻烦,因为你不知道从哪里开始出问题的。

TorchSnooper 就是一个设计了用来解决这个问题的工具。TorchSnooper 的安装非常简单,只需要执行标准的 Python 包安装指令就好:

pip install torchsnooper

安装完了以后,只需要用 @torchsnooper.snoop() 装饰一下要调试的函数,这个函数在执行的时候,就会自动 print 出来每一行的执行结果的 tensor 的形状、数据类型、设备、是否需要梯度的信息。

你可能感兴趣的:([转载]PyTorch代码调试利器: 自动print每行代码的Tensor信息)