pytorch | 使用vmap对自定义函数进行并行化/ 向量化的执行

0. 参考

  1. pytorch官方文档:https://pytorch.org/docs/stable/generated/torch.func.vmap.html#torch-func-vmap
  2. 关于if语句如何执行:https://github.com/pytorch/functorch/issues/257

1. 问题背景

  1. 笔者现在需要执行如下的功能:
    root_ls = [func(x,b) for x in input]
    因此突然想到pytorch或许存在对于自定义的函数的向量化执行的支持

  2. 一顿搜索发现了from functorch import vmap这种好东西,虽然还在开发中,但是很多功能已经够用了

2. 具体例子

  1. 这里只介绍笔者需要的一个方面,vmap的其他支持还请参阅pytorch官方文档
  2. 自定义函数及其输入:
# 自定义函数
def func_2(t,b):
    return torch.where((t>5.),
                        t*b,
                        -t)
# 输入

t = torch.tensor([1.,2.,3.,4.,5.,6.,7.,8.])
b = torch.tensor([1.],requires_grad=True)

  • 注意1:自定义函数不要出现if,用torch.where替代。至于为什么参阅这个issue,大概的原因是“if isn’t a differentiability requirement;”,强行使用会报错error of Data-dependent control flow
  1. 然后对于b,我们需要扩张到和t同样的大小:
    b_extend = torch.expand_copy(b,size=t.shape) # 必须把b扩张到和t同一个size否则报错

  2. 利用vmap,它返回一个新的函数func_vec ,具有向量化执行的支持,也可以利用autograd求导

# Use vmap() to construct a new function.  
func_vec = vmap(func_2)  				# [N, D], [N, D] -> [N]
ans = func_vec(t,b_extend)
ans.sum().backward()   # 等价于: ans.backward(torch.ones(b_extend.shape))
b_extend.grad          # 可以预见:b的导数是t:在t>5.时导数是t,在t<=5.时导数是0

  1. 全部代码:
import torch
from functorch import vmap

# if分支isn't a differentiability requirement;
def func(t,b):
    tmp = t*b
    if tmp > 5:     # error: Data-dependent control flow
        root = t*b
    else:
        root = -t
    return root

def func_2(t,b):
    return torch.where((t>5.),
                        t*b,
                        -t)

t = torch.tensor([1.,2.,3.,4.,5.,6.,7.,8.])
b = torch.tensor([1.],requires_grad=True)
b_extend = torch.expand_copy(b,size=t.shape)    # 必须把b扩张到和t同一个size否则报错
b_extend.retain_grad()

print(f"shape of t:{t.shape}, shape of b_extend:{b_extend.shape}")
# shape of t:torch.Size([8]), shape of b_extend:torch.Size([8])


# Use vmap() to construct a new function.  # [D], [D] -> []
func_vec = vmap(func_2)  # [N, D], [N, D] -> [N]
ans = func_vec(t,b_extend)
ans.sum().backward()   # 等价于: ans.backward(torch.ones(b_extend.shape))

b_extend.grad          # 可以预见:b的导数是t:在t>5.时导数是t,在t<=5.时导数是0
# tensor([0., 0., 0., 0., 0., 6., 7., 8.])

  1. 问题在于,它真的比root_ls = [func(x,b) for x in input]这种快吗?在笔者的设计中确实是使用vmap更快一些,但是不见得总是好用,只是在pytorch中写大量的for实在是太愚蠢了QAQ

感谢阅读,欢迎交流

你可能感兴趣的:(pytorch,神经网络,python,pytorch,python,深度学习)