vscode PyTorch debug 代码

在学习PyTorch的时候,有些函数的代码臃肿,理解起来很烦,这里介绍一种快速理解代码意思的方法。

例如我们要理解下面这段代码

# 已知predict_y和test_label都是tensor类型数据
accuracy = (predict_y == test_label).sum().item() / test_label.size(0)

我们新建一个脚本,创建两个tensor依次拆解步骤测试,尝试理解:

# sum()测试
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 4])
c = (a == b).sum()
print(c)

# size(0)测试
d = torch.randn(2, 3)
print(d.size(0))
e = d.size(0)
print(e)

先把简单的理解了,后面复杂的代码就理解了,这个方法挺好使。

你可能感兴趣的:(人工智能,pytorch,vscode,python)