pytorch炼丹工具必备——自动查看每个Tensor的shape

说明

在构建深度学习神经网络中,每个变量的数据类型以及形状大小一行一行print太麻烦了。
介绍一个PyTorch代码调试利器TorchSnooper,用于自动print每行代码的Tensor信息
GitHub 项目地址: https://github.com/zasdfgbnm/TorchSnooper
参考文章:
https://mp.weixin.qq.com/s/PVIWWIqbZuEe4lwVMDYA5Q

使用方法:
直接查看一个函数内的所有变量:

import torch
import torchsnooper


@torchsnooper.snoop()
def myfunc(mask, x):
    y = torch.zeros(6)
    y.masked_scatter_(mask, x)
    return y


mask = torch.tensor([0, 1, 0, 1, 1, 0], device='cuda')
source = torch.tensor([1.0, 2.0, 3.0], device='cuda')
y = myfunc(mask, source)

直接查看循环中的变量

import torch
import torchsnooper


model = torch.nn.Linear(2, 1)


x = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
y = torch.tensor([3.0, 5.0, 4.0, 6.0])


optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
循环
with torchsnooper.snoop():
    for _ in range(10):
        optimizer.zero_grad()
        pred = model(x)
        squared_diff = (y - pred) ** 2
        loss = squared_diff.mean()
        print(loss.item())
        loss.backward()
        optimizer.step()

你可能感兴趣的:(炼丹神器,pytorch,深度学习,python)