TVM Compiler中文教程:TVM中Compute和Reduction如何使用元组输入

Compute和Reduction使用元组输入

我们通常希望在单个循环内计算具有相同维度的多个输出,或者,执行涉及argmax等多个值的缩减。

在这篇教程,我们将介绍在TVM中元组输入。

from __future__ import absolute_import, print_function

import tvm
import numpy as np

batch计算

对于具有相同维度的运算符,如果我们希望它们在下一个调度程序中一起调度,我们可以将它们放在一起作为tvm.compute的输入。

n = tvm.var("n")
m = tvm.var("m")
A0 = tvm.placeholder((m,n),name='A0')
A1 = tvm.placeholder((m,n),name='A1')
B0, B1 = tvm.compute((m,n),lambda i,j:(A0[i,j]+2,A1[i,j]*3),name='B')
#生成IR中间表示代码
s = tvm.create_schedule(B0.op)
print(tvm.lower(s, [A0,A1,B0,B1],simple_mode=True))

Reduction使用协同输入

有时,我们需要多个输入来表示一些Reduction算子,输入将协同工作,例如argmax。在Reduction过程中,argmax需要比较操作数的值,也需要去保存操作数的索引。这能使用comm_reducer来表示:

# xy是Reduction的操作数,他们是索引和值的元组
def fcombine(x,y):
    lhs = tvm.expr.Select((x[1]>=y[1]),x[0],y[0])
    rhs = tvm.expr.Select((x[1]>=y[1]),x[1],y[1])
    return lhs,rhs
# 标识元素也需要是一个元组,所以‘fidentity’接受两种类型数据作为输入
def fidentity(t0,t1):#t0,t1为类型dtype
    return tvm.const(-1,t0), tvm.min_value(t1)
argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')

#描述Reduction计算
m = tvm.var('m')
n = tvm.var('n')
idx = tvm.placeholder((m,n),name='idx',dtype='int32')
val = tvm.placeholder((m,n),name='val',dtype='int32')
k = tvm.reduce_axis((0,n), 'k')
T0, T1 = tvm.compute((m, ), lambda i: argmax(idx[i,k], val(i,k), axis=k),name='T')

#生成IR代码
s = tvm.create_schedule(T0.op)
print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True))

总结

教程介绍元组输入算子的使用:

  • 描述标准batch计算
  • 描述Reduction算子使用元组输入
  • 注意,只能使用操作而不是张量来调度计算

你可能感兴趣的:(TVM深度学习编译器)