MLC--机器学习编译的课程笔记

提示:文章有不对的地方,麻烦大佬们指出下,还有为什么复制图会带水印啊

文章目录

  • 前言
  • 一、什么是机器学习编译?
  • 二、张量抽象
    • 1.什么是算子(元张量函数)
    • 2.张量实践
  • 三、端到端模型整合
  • 四、自动优化
  • 五、与机器学习框架的整合
  • 持续更新中--------


前言

提示:本文是个人上课的精简笔记,尽可能简要的表达一些概念,需要更全面的可以看陈天奇老师课程

课程:https://www.bilibili.com/video/BV15v4y1g7EU
课程主页: https://mlc.ai/summer22-zh
课程笔记:https://mlc.ai/zh/

个人认为课程的内容大部分是API的使用讲解,关注怎么做,如果需要学为什么这么做,比如说编译时的内存怎么一步步优化,可能要额外资料学习。

下面课程中代码涉及的流程图,MLC中程序的每个模块表达成IRModule类,之后我们关注的主要是transform那一步做优化,文中表达元张量函数=元算子
MLC--机器学习编译的课程笔记_第1张图片


一、什么是机器学习编译?

机器学习编译,个人简单理解就是,把机器学习模型经过一定转换,使其能以更接近底层API的形式部署在各种不同的平台上。机器学习编译的三个目标分别为集成和最小化依赖硬件加速通用优化集成和最小化依赖 是指机器学习模型依赖的包可能只占pytorch的一小部分,比如只用到Conv2d, ReLU之类,从节省资源考虑,在部署时只打包必要的依赖。部署时也分为开发环境和部署环境,但二者很多时候是相同的。硬件加速,简单来说是搞cuda编程, 利用硬件cuda的加速库。通用优化,最通用的就是内存优化,比如改变计算时数组元素访问顺序来提高访存效率等。
MLC--机器学习编译的课程笔记_第2张图片
MLC--机器学习编译的课程笔记_第3张图片


二、张量抽象

本章节就是介绍使用tvm优化程序的基本过程,抽象举个例子说,抽象可以把linear和relu操作合并表示为循环操作的矩阵元素操作(或者说元算子的合并)

1.什么是算子(元张量函数)

张量算子函数指的是张量计算过程中的最基本操作,比如加法add和减法等,正常的加法在编译过后可能是逐元素的加法操作,但加法是可并行,可以根据底层支持优化成并行的向量加法。
MLC--机器学习编译的课程笔记_第4张图片

# 注意运行要去掉注释,要不编译过不了
# pip install mlc-ai-nightly -f https://mlc.ai/wheels
# 张量算子计算示例
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
import numpy as np

class MyModule:
    @T.prim_func
    def main(A: T.Buffer[128, 'float32'], #参数大小类型指定,相当于显式开内存
             B: T.Buffer[128, 'float32'],
             C: T.Buffer[128, 'float32']):
        T.func_attr({'global_symbol':"main", "tir.noalian":True})
        for i in range(128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                C[vi] = A[vi] + B[vi]

sch = tvm.tir.Schedule(MyModule) #辅助类
block_c = sch.get_block("C") # 拿到上面C部分的代码
i, = sch.get_loops(block_c)
i0, i1, i2 = sch.split(i, factors=[None, 4, 4]) #一个迭代循环分解长成3个
print(sch.mod.script())
'''
输出结果:
@tvm.script.ir_module
class Module:
    @tir.prim_func
    def func(A: tir.Buffer[128, "float32"], B: tir.Buffer[128, "float32"], C: tir.Buffer[128, "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "main", "tir.noalian": True})
        # body
        # with tir.block("root")
        for i_0, i_1, i_2 in tir.grid(8, 4, 4): #和上面的4对应
            with tir.block("C"):
                vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)
                tir.reads(A[vi], B[vi])
                tir.writes(C[vi])
                C[vi] = A[vi] + B[vi]
    
'''

2.张量实践

代码的Block可以理解为计算单元,矩阵乘法和Relu对应代码如下(示例):

from tvm.ir.module import IRModule
from tvm.script import tir as T
import numpy as np
# https://tvm.apache.org/docs tvm官方文档,但课程用的好像不一样

dtype = 'float32'
a_np = np.random.rand(128, 128).astype(dtype)
b_np = np.random.rand(128, 128).astype(dtype)
cmm_relu = np.maximum(a_np @ b_np, 0) #矩阵乘法+ReLU的numpy实现

# 矩阵乘法的low-level numpy写法,尽量贴近C
def lnumpy_mm_relu(A:np.ndarray,
                   B:np.ndarray,
                   C:np.ndarray):
    # https://blog.csdn.net/artorias123/article/details/86527456
    # 矩阵乘法按行计算,内存优化
    Y = np.empty((128,128),dtype="float32")
    for i in range(128):
        for j in range(128):
            for k in range(128):
                if k == 0:
                    Y[i, j] = 0
                Y[i, j] = Y[i, j] + A[i,k] * B[k,j]
    
    for i in range(128):
        for j in range(128):
            C[i,j] = max(Y[i,j], 0)

c_np = np.empty((128,128), dtype=dtype)
lnumpy_mm_relu(a_np, b_np, c_np)
# 检测两个矩阵的差值是否满足一定范围,不满足就报错
np.testing.assert_allclose(cmm_relu, c_np, rtol=1e-5)
# tvm版本,注意运行要去掉注释,要不编译过不了
@tvm.script.ir_module
class MyModule:
    @T.prim_func #元张量函数
    def main(A: T.Buffer[(128, 128), 'float32'],
             B: T.Buffer[(128, 128), 'float32'],
             C: T.Buffer[(128, 128), 'float32']):
        T.func_attr({'global_symbol':"mm_relu",  # 编译后的函数名
                     "tir.noalian":True}) # 内存指针不重复? 
        Y = T.alloc_buffer((128, 128), dtype='float32')  #模拟申请内存 
        '''
        mm 部分
        '''
        for i, j, k in T.grid(128, 128, 128): # 等价上面三重循环
            with T.block("Y"): #执行单元
                vi = T.axis.spatial(128, i) #每次循环固定一个值, 出现在输出结果的某个维度上
                vj = T.axis.spatial(128, j) #每次循环固定一个值
                vk = T.axis.reduce(128, k) #每次调用block, 0..127
                #vi,vj, vk = T.axis.remap("SSR", [i, j, k]) 课程说的简化写法, SSR是spatial, spatial, reduction
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk,vj]
        '''
        relu部分 可以与上面分成两个函数
        '''
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
def lnumpy_mm_relu_v2(A:np.ndarray,
                   B:np.ndarray,
                   C:np.ndarray):
    # https://blog.csdn.net/artorias123/article/details/86527456
    # 矩阵乘法按行计算,内存优化
    Y = np.empty((128,128),dtype="float32")
    for i in range(128):
        for j0 in range(32):
            for k in range(128):
                for j1 in range(4):
                    j = j0 * 4 + j1
                    if k == 0:
                        Y[i, j] = 0
                    Y[i, j] = Y[i, j] + A[i,j] * B[i,k]
    
    for i in range(128):
        for j in range(128):
            C[i,j] = max(Y[i,j], 0)

c_np = np.empty((128,128), dtype=dtype)
lnumpy_mm_relu_v2(a_np, b_np, c_np)
# 检测两个矩阵的差值是否满足一定范围,不满足就报错
np.testing.assert_allclose(cmm_relu, c_np, rtol=1e-5)
def lnumpy_mm_relu_v3(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((128, 128), dtype="float32")
    for i in range(128):
        for j0 in range(32):
            # Y_init
            for j1 in range(4):
                j = j0 * 4 + j1 
                Y[i, j] = 0
            # Y_update
            for k in range(128):
                for j1 in range(4):
                    j = j0 * 4 + j1 
                    Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
            # C
            for j1 in range(4):
                j = j0 * 4 + j1 
                C[i, j] = max(Y[i, j], 0)

c_np = np.empty((128, 128), dtype=dtype)
lnumpy_mm_relu_v3(a_np, b_np, c_np)
np.testing.assert_allclose(c_mm_relu, c_np, rtol=1e-5)

将上述numpy操作的改进用TVM的API实现

# mm_relu_v2版本的对应操作
sch = tvm.tir.Schedule(MyModule)
block_Y = sch.get_block("Y", func_name="mm_relu")
i, j, k = sch.get_loops(block_Y)
j0, j1 = sch.split(j, factors=[None, 4]) #j循环拆解成两个循环
sch.reorder(j0, k, j1) #循环重排序,不懂为什么这么做
# mm_relu_v3版本的对应操作
# 算子融合,把block C的循环移动到block Y里
block_C = sch.get_block("C", "mm_relu") 
sch.reverse_compute_at(block_C, j0)
block_Y = sch.get_block("Y", "mm_relu") #
sch.decompose_reduction(block_Y, k)
print(sch.mod.script()) 

# 比较计算内存优化前后的运行时间
rt_lib = tvm.build(MyModule, target='llvm')
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(b_np)
c_nd = tvm.nd.empty((128, 128), dtype='float32')

rt_lib_after = tvm.build(sch.mod, target='llvm')
f_time_before = rt_lib.time_evaluator("mm_relu", tvm.cpu())
print(f"Time cost of Mymodule:{f_time_before(a_nd, b_nd, c_nd).mean}")
f_time_after =  rt_lib_after.time_evaluator("mm_relu", tvm.cpu())
print(f"Time cost after transform:{f_time_after(a_nd, b_nd, c_nd).mean}")

上述优化的意思是,由于内存读写的特性,RAM和CPU缓存的速率差很多,CPU计算时的数据是以一行行的形式从RAM放进CPU的各级(L1~L3)Cache的,也就是A[1,3]元素在缓存时,A[1]行的元素或者其左右的元素都在缓存,所以在计算的时候尽量把连续的元素放一起算。

图的意思是,矩阵乘法计算是Y = A行 x B列,正常三重循环遍历k次算一个, Y[i, j] = A[i, k] * B[k, j], 这里B要行指针K需移动多次,但B[k, j]读进缓存时B[K]行也在了,所以可以多算几个同一行的Y[i, j],提高效率。
MLC--机器学习编译的课程笔记_第5张图片


三、端到端模型整合

本章节是讲如何把神经网络的计算过程,抽象成元函数的组合形式,重点是call_tir()的使用如何表示计算图
MLC--机器学习编译的课程笔记_第6张图片

# This is needed for deferring annotation parsing in TVMScript, 比如后面的Tensor
from __future__ import annotations  # 放后面的话,我这个会报错
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T, relax as R
import numpy as np
# 章节重点,基于relax构建计算图和编译,但relax具体是什么不清楚,只知道是很高级的特性
from tvm import relax
'''
lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype="float32")
'''
# call_tir的numpy实现,就是把输出和输入都当参数,在底层显式把所需内存开出来
# 课程里说这种指针式返回out,而不是函数返回值的形式,是许多底层库的通用写法
def lnumpy_call_tir(prim_func, inputs, shape, dtype):
    res = np.empty(shape, dtype=dtype)
    prim_func(*inputs, res)
    return res
  1. 定义上图神经的numpy实现和tvm script实现(课程把里面的叫TensorIR函数)
# 上图简单神经网络的numpy表示
def numpy_mlp(data, w0, b0, w1, b1):
    lv0 = data @ w0.T + b0
    lv1 = np.maximum(lv0, 0)
    lv2 = lv1 @ w1.T + b1
    return lv2
# 上图简单神经网络的实现,编译执行要去掉注释
@tvm.script.ir_module
class MyModule: 
    @T.prim_func
    def relu0(X: T.Buffer[(1, 128), "float32"], 
              Y: T.Buffer[(1, 128), "float32"]):
        # function attr dict
        T.func_attr({"global_symbol": "relu0", "tir.noalias": True})
        for i, j in T.grid(1, 128):
            with T.block("Y"):
                vi, vj = T.axis.remap("SS", [i, j])
                Y[vi, vj] = T.max(X[vi, vj], T.float32(0))

    @T.prim_func
    def linear0(X: T.Buffer[(1, 784), "float32"], 
                W: T.Buffer[(128, 784), "float32"], 
                B: T.Buffer[(128,), "float32"], 
                Z: T.Buffer[(1, 128), "float32"]):
        T.func_attr({"global_symbol": "linear0", "tir.noalias": True})
        Y = T.alloc_buffer((1, 128), "float32")
        for i, j, k in T.grid(1, 128, 784):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk]
    
        for i, j in T.grid(1, 128):
            with T.block("Z"):
                vi, vj = T.axis.remap("SS", [i, j])
                Z[vi, vj] =  Y[vi, vj] + B[vj]

    @T.prim_func
    def linear1(X: T.Buffer[(1, 128), "float32"], 
                W: T.Buffer[(10, 128), "float32"], 
                B: T.Buffer[(10,), "float32"], 
                Z: T.Buffer[(1, 10), "float32"]):
        T.func_attr({"global_symbol": "linear1", "tir.noalias": True})
        Y = T.alloc_buffer((1, 10), "float32")
        for i, j, k in T.grid(1, 10, 128):
            with T.block("Y"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk]
    
        for i, j in T.grid(1, 10):
            with T.block("Z"):
                vi, vj = T.axis.remap("SS", [i, j])
                Z[vi, vj] = Y[vi, vj] + B[vj]

    @R.function #比元张量函数更高级的抽象表示,理解为元操作的组合,方便优化
    def main(x: Tensor((1, 784), "float32"), 
             w0: Tensor((128, 784), "float32"), 
             b0: Tensor((128,), "float32"), 
             w1: Tensor((10, 128), "float32"), 
             b1: Tensor((10,), "float32")):
        '''
        构建计算图作用域的声明范式
        计算图的每个操作都应该是side-effect free的
        side-effect free: 一个函数只从其输入中读取并通过其输出返回结果,它不会改变程序的其他部分(例如递增全局计数器),是pure的
        '''     
        with R.dataflow():
            lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype="float32")
            lv1 = R.call_tir(relu0, (lv0,), (1, 128), dtype="float32")
            out = R.call_tir(linear1, (lv1, w1, b1), (1, 10), dtype="float32")
            R.output(out)
        return out
  1. 加载数据–>加载预训练权重–>编译执行
import torchvision
import torch
# 加载数据集
test_data = torchvision.datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor()
)
img, label = test_data[0]
img = img.reshape(1, 28, 28).numpy()

# 模型参数下载
# wget https://github.com/mlc-ai/web-data/raw/main/models/fasionmnist_mlp_params.pkl
import pickle as pkl
# 加载模型参数,用numpy测试下
mlp_params = pkl.load(open("fasionmnist_mlp_params.pkl", "rb"))
res = numpy_mlp(img.reshape(1, 784),
                mlp_params["w0"],
                mlp_params["b0"],
                mlp_params["w1"],
                mlp_params["b1"])

# relax编译模型,和之前tir一样,多了虚拟机操作,不知道和java虚拟机概念是否一致
ex = relax.vm.build(MyModule, target="llvm") 
# type(ex) 可以在relax虚拟机执行的函数
vm = relax.VirtualMachine(ex, tvm.cpu()) #生成一个虚拟机,tvm.cpu运行设备

# 执行预测
data_nd = tvm.nd.array(img.reshape(1, 784))
nd_params = {k: tvm.nd.array(v) for k, v in mlp_params.items()}
nd_res = vm["main"](data_nd,
                    nd_params["w0"],
                    nd_params["b0"],
                    nd_params["w1"],
                    nd_params["b1"])
                    
pred_kind = np.argmax(nd_res.numpy(), axis=1)
print("预测的类别:", class_names[pred_kind[0]])

四、自动优化

本章节是调用API对程序的可能优化步骤(各种循环换顺序)进行随机搜索,关键就是两点,分别是构建随机搜索空间应用自动优化

from __future__ import annotations
import numpy as np
import tvm
from tvm import relax
from tvm.ir.module import IRModule
from tvm.script import relax as R
from tvm.script import tir as T

import IPython # code2html在notebook里用,高亮代码,print也行
def code2html(code):
    """Helper function to use pygments to turn the code string into highlighted html."""
    import pygments
    from pygments.formatters import HtmlFormatter
    from pygments.lexers import Python3Lexer
    formatter = HtmlFormatter()
    html = pygments.highlight(code, Python3Lexer(), formatter)
    return "%s\n" % (formatter.get_style_defs(".highlight"), html)
  1. 首先定义一个元张量函数(矩阵乘法)
# SSR 的S是spatial, R是reduction,具体解释在第二章张量实践
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(
        A: T.Buffer[(128, 128), "float32"],
        B: T.Buffer[(128, 128), "float32"],
        C: T.Buffer[(128, 128), "float32"],
    ):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i, j, k in T.grid(128, 128, 128):
            with T.block("C"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]\
  1. 定义一个正常变换和随机变换过程(=定义搜索空间)
# 正常循环拆分和重排
def schedule_mm(sch: tvm.tir.Schedule, jfactor=4):
    block_C = sch.get_block("C", "main")
    i, j, k = sch.get_loops(block=block_C)
    j_0, j_1 = sch.split(loop=j, factors=[None, jfactor]) # 比如128的循环分成jfactor x None
    sch.reorder(i, j_0, k, j_1) # 根据随机变量做变换
    sch.decompose_reduction(block_C, k)
    return sch

# 随机调度变换 (Stochastic Schedule Transformation)
def stochastic_schedule_mm(sch: tvm.tir.Schedule):
    block_C = sch.get_block("C", "main")
    i, j, k = sch.get_loops(block=block_C)
    # j_factor 执行一次才随机采样一次,每次执行值不同,不是采样个列表出来,
    j_factors = sch.sample_perfect_tile(loop=j, n=2) # 对j随机采样分开成两个循环, 搜索空间为[8, 16]、[32, 4].......
    j_0, j_1 = sch.split(loop=j, factors=j_factors)
    sch.reorder(i, j_0, k, j_1)
    sch.decompose_reduction(block_C, k)
    return sch
  1. 执行随机变换,重点使用 sch.traceAPI获取变换的历史记录,便于之后优化
sch = tvm.tir.Schedule(MyModule)
# sch = schedule_mm(sch)
sch = stochastic_schedule_mm(sch)
# IPython.display.HTML(code2html(sch.mod.script()))
print(sch.trace) # 可以看到记录与定义的过程一致
'''
.....
v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[128, 1])
# decision 是随机采样的结果
......
'''
  1. 应用自动优化,随机变换过程已经指定了程序的搜索空间,这时使用 tune_tir API 在搜索空间内搜索并找到最优的调度变换(也可以使用自动生成的搜索空间)。
#!pip install xgboost 搜索要调这个
from tvm import meta_schedule as ms # meta_schedule 是支持搜索可能变换空间的命名空间的API
# 个人重点:API是基于历史轨迹进行遗传搜索 (evolutionary search),而不是每次都随机采样
sch_tuned = ms.tune_tir(
    mod=MyModule,
    target="llvm --num-cores=1",
    config=ms.TuneConfig(
      max_trials_global=64,
      num_trials_per_iter=64,
    ),
    # 注释掉后是自动搜索空间
    # space=ms.space_generator.ScheduleFn(stochastic_schedule_mm), 人为定义的搜索空间随机搜索 
    work_dir="./tune_tmp",
    task_name="main"
)
# IPython.display.HTML(code2html(sch_tuned.mod.script())) 打印结果
  1. 与之前的端到端模型结合,对linear0进行优化
# relax编译第三章的神经网络模型,不是这章的MyModule
ex = relax.vm.build(MyModule, target="llvm") 
# type(ex) 可以在relax虚拟机执行的函数
vm = relax.VirtualMachine(ex, tvm.cpu()) #生成一个虚拟机,tvm.cpu运行设备
'''
目前,调优 API 只接受一个带有一个 main 函数的 IRModule,所以先将 linear0 取出到另一个模块的 main 函数中并将其传递给 tune_tir。
问:之前模型的IRModule也只有一个main,为啥多取一次?
'''
mod_linear = tvm.IRModule.from_expr(MyModule["linear0"].with_attr("global_symbol", "main"))
#
sch_tuned_linear = ms.tune_tir(
    mod=mod_linear,
    target="llvm --num-cores=1",
    config=ms.TuneConfig(
      max_trials_global=64,
      num_trials_per_iter=64,
    ),
    work_dir="./tune_tmp",
    task_name="main",
)
'''
调优后用新函数替换原来的 linear0, 首先获得一个 global_var(一个指向 IRModule 中函数的
pointer引用),然后调用 update_func 来用新的函数替换原本的函数
'''
MyModuleWithParams = relax.transform.BindParams("main", nd_params)(MyModule)
new_func = sch_tuned_linear.mod["main"].with_attr("global_symbol", "linear0")
gv = MyModuleWithParams.get_global_var("linear0")
MyModuleWithParams.update_func(gv, new_func)
IPython.display.HTML(code2html(MyModuleWithParams2.script()))

五、与机器学习框架的整合

本章节使用API直接将pytorch的模型代码转换成IRMoudle,之前的章节都是用TVMScript编写模型的计算过程再转换。

  1. 使用TorchFX构建Pytorch模型的计算图
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.weight = nn.Parameter(torch.randn(128, 128))

    def forward(self, x):
        x = torch.matmul(x, self.weight)
        x = torch.relu(x)
        return x

model = MyModel()
fx_module = fx.symbolic_trace(model)
type(fx_module)
fx_module.graph.print_tabular()
  1. 构建计算图到IRMMoudle的映射
def map_param(param: nn.Parameter):
    ndim = len(param.data.shape)
    return relax.const(
        param.data.cpu().numpy(), relax.DynTensorType(ndim, "float32")
    )

def fetch_attr(fx_mod, target: str):
    """Helper function to fetch an attr"""
    target_atoms = target.split('.')
    attr_itr = fx_mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
        attr_itr = getattr(attr_itr, atom)
    return attr_itr

def from_fx(fx_mod, input_shapes, call_function_map, call_module_map):
    input_index = 0
    node_map = {}
    named_modules = dict(fx_mod.named_modules())

    bb = relax.BlockBuilder()

    fn_inputs = []
    fn_output = None
    with bb.function("main"):
        with bb.dataflow():
            for node in fx_mod.graph.nodes:
                if node.op == "placeholder": # 输入变量映射Var
                    # create input placeholder
                    shape = input_shapes[input_index]
                    input_index += 1
                    input_var = relax.Var(
                        node.target, shape, relax.DynTensorType(len(shape), "float32")
                    )
                    fn_inputs.append(input_var)
                    node_map[node] = input_var #输入节点
                elif node.op == "get_attr": # 权重参数映射
                    node_map[node] = map_param(fetch_attr(fx_mod, node.target))  #node.target = fx.mod.weight
                elif node.op == "call_function": # relu之类,从node_map取出要操作的node
                    node_map[node] = call_function_map[node.target](bb, node_map, node)
                elif node.op == "call_module": # module里面有weight,也要注意转换, 但没看出来
                    named_module = named_modules[node.target]
                    node_map[node] = call_module_map[type(named_module)](bb, node_map, node, named_module)
                elif node.op == "output":
                    output = node_map[node.args[0]]
                    assert fn_output is None
                    fn_output = bb.emit_output(output)
        # output and finalize the function
        bb.emit_func_output(output, fn_inputs)
    return bb.get()
  1. 执行映射转换
def map_matmul(bb, node_map, node: fx.Node):
    A = node_map[node.args[0]]
    B = node_map[node.args[1]]
    return bb.emit_te(te_matmul, A, B) # => x @ w.T

def map_relu(bb, node_map, node: fx.Node):
    A = node_map[node.args[0]]
    return bb.emit_te(te_relu, A)

MyModule = from_fx(
    fx_module,
    input_shapes = [(1, 128)],
    call_function_map = {
      torch.matmul: map_matmul,
      torch.relu: map_relu,
    },
    call_module_map={},
)

MyModule.show()

持续更新中--------

你可能感兴趣的:(笔记,机器学习,人工智能,python)