Pytorch小技巧:布尔类型(True/False)转换为浮点类型(1.0/0.0)的成长史

背景

  • 有些时候,我们会有这么一个小需求:把布尔类型的tensor转换为浮点类型/整数类型的tensor(下文代码以浮点类型示范)。
  • 身处不同的编程时期,为了实现同一功能,我们往往会写出不一样的代码,下面三个不同代码是我身处不同阶段的缩影。

代码

编程时期1——学会了列表生成式

import torch
import time
bool_tensor = torch.tensor([True, False, True])
start_time = time.time()
# 统计10w次,比较三种代码的时间复杂度
for _ in range(100000): 
	# 列表生成式	
    float_tensor = torch.tensor([1.0 if value else 0.0 for value in bool_tensor])
end_time = time.time()
print(float_tensor, "cost_time: ", end_time - start_time)

# tensor([1., 0., 1.]) cost_time:  0.7398970127105713
  • 小结:使用列表生成式处理10w次需求,需要大概0.74s

编程时期2——学会了torch.where()

import torch
import time
bool_tensor = torch.tensor([True, False, True])
start_time = time.time()
# 统计10w次,比较三种代码的时间复杂度
for _ in range(100000): 
	# torch.where	
    float_tensor = torch.where(bool_tensor, 1.0, 0.0)
end_time = time.time()
print(float_tensor, "cost_time: ", end_time - start_time)
# tensor([1., 0., 1.]) cost_time:  0.632805347442627
  • 小结:使用torch.where这个api处理10w次需求,需要大概0.63s

编程时期3——突然发现新大陆:原来bool类型和浮点类型也能进行类型转换

import torch
import time
bool_tensor = torch.tensor([True, False, True])
start_time = time.time()
# 统计10w次,比较三种代码的时间复杂度
for _ in range(100000): 
	# 类型转换	
    float_tensor = bool_tensor.float()
end_time = time.time()
print(float_tensor, "cost_time: ", end_time - start_time)
# tensor([1., 0., 1.]) cost_time:  0.1671144962310791
  • 小结:使用强制类型转换处理10w次需求,需要大概0.17s

总结:强制类型转换真香!!!

你可能感兴趣的:(PyTorch,python,PyTorch实战,pytorch,深度学习,人工智能)