torch.tensor 和numpy的相互转化及float格式

numpy中的ndarray转化成pytorch中的tensor : torch.from_numpy()

pytorch中的tensor转化成numpy中的ndarray : numpy()

需要注意的是 torch.from_numpy() 的默认格式是 torch.float64

为了在pytorch 中工作,我们需要将其转为 torch.float32,

需要使用 torch.from_numpy(x).float()

你可能感兴趣的:(人工智能,python)