增加最优传输过程中遇到的问题

最近,正在使用最优传输,在使用的过程中遇到了一下问题,简单记录一下这些问题。该文章仅用于记录学习,不做其他用途,参考的文章均声明。

修改代码过程中,遇到的错误总结

Question 1. 描述:

linear(): argument ‘input’ (position 1) must be Tensor, not DataFrame

翻译:使用 torch.nn.Linear() 时,必须要求数据是 tensor 类型
解决方案:torch.nn.linear() 使用时要求数据必须是 tensor 类型的。
增加最优传输过程中遇到的问题_第1张图片

Question 2. 描述

RuntimeError:expected scalar type Float but found Double

翻译:RuntimeError:期望标量类型为 Float,但发现为 Double
解决方案:首先看了很多博主写的博客,然后自己一步一步调试,发现代码出现的问题在使用线性变换时,我用于处理的类型是 float64,而 torch.nn.linear() 生成的 weightbias 两个参数是 float32 类型的,通过查阅资料发现,Tensorfloat64float32 不是一个类型,参考了,具体解决方法如下所示:
增加最优传输过程中遇到的问题_第2张图片

Question 3. 描述

‘function’ object has no attribute ‘parameters’

翻译:函数没有 parameters 这个属性
解决方案:经过了解 torch.nn.linear() 这个函数的用法和 Pytorch.optim之optim.Adam() ,参考了 Python-torch.optim优化算法理解之optim.Adam(),在 ly 的帮助下,了解了不能直接用 function ,于是在 linear_1() 函数中返回 linear1 。具体解决方法如下所示:
增加最优传输过程中遇到的问题_第3张图片

Question 4. 描述

mat1 and mat2 shapes cannot be multiplied (1140x8 and 2x2)

翻译:mat1mat2 大小不能相乘( 11408 和 22 )
解决方案:将 mat2 的大小换为 8*8,可用于和 mat1 进行计算的大小。将 Mapping(2) 中的 2 更换为 8 。就解决了该问题。
图片4

Question 5. 描述

Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat1 in method wrapper_addmm)

翻译:预计所有张量都在同一个设备上,但发现至少有两个设备,cuda:0和cpu!(当检查方法 wrapper_addmm 中的参数 mat1 的参数时)
解决方案:经过查阅解决方案,参考了文章 解决RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cp,逐步找到两个 mat1mat2 对应的设备,并对其中不一样的那个 mat 进行修改。

Question 6. 描述

Can’t call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

翻译:不能在需要梯度下降的 Tensor 上直接使用 numpy() ,用 tensor.detach().numpy() 代替。
解决方案:参考了 RuntimeError: Can‘t call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.该博客的主要内容。该博客阐述了出现这种情况的原因:待转化类型的Tensor 变量带有梯度,直接将其转化为 numpy 数据将破坏计算图,因此 numpy 拒绝进行数据转换,实际上这是对开发者的一种提醒。如果自己在转换数据时不需要保留梯度信息,可以在变量转换之前添加 detach() 调用。

X_total = X_total.detach.numpy()

若数据部署在 GPU 上时,则修改为

X_total = X_total.cpu().detach.numpy()

你可能感兴趣的:(实验过程中遇到的问题总结,pytorch,python,深度学习)