以 Wagon 为例, Golang 解析 wasm

wasm 原理

wasm 指令的解析,其实都是 入栈,出栈的操作, 它是一个基于栈的虚拟机,比如
get_local 0, 它就是获取函数的第一个参数,并把它放到栈里.
i32.const 42 就是把一个 42(int32)放入栈中.
i32.add 就是从栈中取出两个数,相加后再放回栈里。

下面看一个具体的例子

cpp 如下

extern "C" {
  int large(int num) {
    if (num > 10) {
      num = num + 12;
    } else {
      num = num + 100;
    }
    return num;
  }
}

指定 Optimization Level -o3 优化后 编译后的 wast 如下

 (table 0 anyfunc)
 (memory $0 1)
 (export "memory" (memory $0))
 (export "large" (func $large))
 (func $large (; 0 ;) (param $0 i32) (result i32)
  (i32.add  // 8
   (select  // 6
    (i32.const 12) // 1
    (i32.const 100) // 2
    (i32.gt_s // 5
     (get_local $0) // 3
     (i32.const 10) // 4
    )
   ) 
   (get_local $0) // 7
  )
 )
)

指令解析 $0 为函数输入参数

  • (i32.const 12) 将 12 pushstack;
    • stack=>[12]
  • (i32.const 100) 将 100 pushstack;
    • stack=>[100, 12]
  • (get_local $0) 将 $0 (参数) 从local中读取,并pushstack;
    • stack=>[$0, 100, 12]
  • (i32.const 10) 将 10 pushstack;
    • stack=>[10, $0, 100, 12]
  • (i32.gt_s ()())stackpopv1, v2,并比较大小,v1 > v2push 1 到stack, 反之 push 0;
    • stack=>[1, 100, 12]
  • (select ()()() )stackpop 1 => v3,从 stackpop 100 => v2, 从 stack 中pop 12 => v1,if v3 为 1(true), 将 v2 pushstack, 反之 将 v1 pushstack;
    • stack=>[100]
  • (get_local $0) 将 $0 (参数) 从local中读取,并pushstack;
    • stack=>[$0, 100]
  • (i32.add ()()) 从stack 中pop 两个数,相加后push 到stack;
    • stack=>[108]
  • 返回结果 108

做一个 webassembly 的虚拟机主要分两块, compileInterpreter. 我们先看 compile 模块.

Compile

  • 编译主要是对 wasm 结构进行解析, 首先看看 module 对象,这是个核心类, wasm 就是解析到这个对象
type Module struct {
    Version  uint32         // wasm 的版本
    Sections []Section      // wasm 中所有的section 数组, 一个 wasm 文件, 就是由version 和多个 section 组成
    Types    *SectionTypes  // wasm 中所有的函数描述
    Import   *SectionImports // wasm 中导入的函数
    Function *SectionFunctions // wasm 中声明的函数,每个函数对应一个index 指向 Types内的函数类型
    Table    *SectionTables 
    Memory   *SectionMemories
    Global   *SectionGlobals
    Export   *SectionExports    // wasm 中导出的函数描述
    Start    *SectionStartFunction  // 需要立刻执行的函数
    Elements *SectionElements   // 定义在 table 中的元素
    Code     *SectionCode   // 该 module 的所有函数信息数据
    Data     *SectionData  // 数据区, 比如一些字符串等数据, 会放在Data里, 用 offset 标记
    Customs  []*SectionCustom 

    // The function index space of the module
    FunctionIndexSpace []Function // wasm 中所有的函数包括 SectionImports 和 SectionFunctions,函数中的 type 指向 Types 中的类型
    GlobalIndexSpace   []GlobalEntry

    // function indices into the global function space
    // the limit of each table is its capacity (cap)
    TableIndexSpace        [][]uint32
    LinearMemoryIndexSpace [][]byte // 线性内存, Data 数据会存放在这里

    imports struct {
        Funcs    []uint32  // 导入的函数
        Globals  int 
        Tables   int
        Memories int
    }
}
  • 读取 wasm 文件
// 从本地读取一个 wasm 文件,并返回 module
func ReadModule(r io.Reader, resolvePath ResolveFunc) (*Module, error) {
    // 通过解析 二进制 wasm 文件,将数据解析道对应的 section 中去
    m, err := DecodeModule(r)

    ...
    if m.Import != nil && resolvePath != nil {
        if m.Code == nil {
            m.Code = &SectionCode{}
        }
        // 解析 导入 的 module
        err := m.resolveImports(resolvePath)
    }
    for _, fn := range []func() error{
        m.populateGlobals,
        // 将内部函数转化为 Function 对象,并将 导入的函数也 一并添加到 FunctionIndexSpace 中
        m.populateFunctions,
        m.populateTables,
        // 将 m.Data 放到线性内存中
        m.populateLinearMemory,
    } {
        if err := fn(); err != nil {
            return nil, err
        }
    }
    return m, nil
}

func DecodeModule(r io.Reader) (*Module, error) {
    reader := &readpos.ReadPos{
        R:      r,
        CurPos: 0,
    }
    m := &Module{}
    ...

    err = newSectionsReader(m).readSections(reader)
    return m, nil
}
    
  • DecodeModule 新建 sectionReader, 并调用 readSections
func (s *sectionsReader) readSections(r *readpos.ReadPos) error {
    for {
        // 循环读取section,知道读完
        done, err := s.readSection(r)
        switch {
        case err != nil:
            return err
        case done:
            return nil
        }
    }
}

// 从reader 中读取一个有效的 section. The first return value is true if and only if
// the module has been completely read.
func (sr *sectionsReader) readSection(r *readpos.ReadPos) (bool, error) {
    m := sr.m

    logger.Println("Reading section ID")
    // 从 reader 中读取一个字节
    id, err := r.ReadByte()
    ...

    s := RawSection{ID: SectionID(id)}

    logger.Println("Reading payload length")
    // 读取实际 数据
    payloadDataLen, err := leb128.ReadVarUint32(r)
    if err != nil {
        return false, err
    }

    logger.Printf("Section payload length: %d", payloadDataLen)

    s.Start = r.CurPos

    sectionBytes := new(bytes.Buffer)

    sectionBytes.Grow(int(getInitialCap(payloadDataLen)))
    sectionReader := io.LimitReader(io.TeeReader(r, sectionBytes), int64(payloadDataLen))
    
    // 判断section 的类型,并将该类型空的 struct 赋值给 module 对应的属性
    var sec Section
    switch s.ID {
    case SectionIDCustom:
        logger.Println("section custom")
        cs := &SectionCustom{}
        m.Customs = append(m.Customs, cs)
        sec = cs
    case SectionIDType:
        logger.Println("section type")
        m.Types = &SectionTypes{}
        sec = m.Types
    case SectionIDImport:
        logger.Println("section import")
        m.Import = &SectionImports{}
        sec = m.Import
    case SectionIDFunction:
        logger.Println("section function")
        m.Function = &SectionFunctions{}
        sec = m.Function
    case SectionIDTable:
        logger.Println("section table")
        m.Table = &SectionTables{}
        sec = m.Table
    case SectionIDMemory:
        logger.Println("section memory")
        m.Memory = &SectionMemories{}
        sec = m.Memory
    case SectionIDGlobal:
        logger.Println("section global")
        m.Global = &SectionGlobals{}
        sec = m.Global
    case SectionIDExport:
        logger.Println("section export")
        m.Export = &SectionExports{}
        sec = m.Export
    case SectionIDStart:
        logger.Println("section start")
        m.Start = &SectionStartFunction{}
        sec = m.Start
    case SectionIDElement:
        logger.Println("section element")
        m.Elements = &SectionElements{}
        sec = m.Elements
    case SectionIDCode:
        logger.Println("section code")
        m.Code = &SectionCode{}
        sec = m.Code
    case SectionIDData:
        logger.Println("section data")
        m.Data = &SectionData{}
        sec = m.Data
    default:
        return false, InvalidSectionIDError(s.ID)
    }
    // 从reader 中读取数据,存入 section (对应到 module 的某个变量中)
    err = sec.ReadPayload(sectionReader)
    if err != nil {
        logger.Println(err)
        return false, err
    }
    s.End = r.CurPos
    s.Bytes = sectionBytes.Bytes()
    // 将 raw s 保存到 对应的 xxxSection 中
    *sec.GetRawSection() = s
    
    ...
    
    // 保存 section 
    m.Sections = append(m.Sections, sec)
    return false, nil
}
  • 将文件读取到 module 中后,还需要加载 import 的模块
// 解析import 的函数
func (module *Module) resolveImports(resolve ResolveFunc) error {
    if module.Import == nil {
        return nil
    }
    modules := make(map[string]*Module)

    var funcs uint32
    // 遍历 module.Import 下的 ”入口“
    for _, importEntry := range module.Import.Entries {
        importedModule, ok := modules[importEntry.ModuleName]
        if !ok {
            var err error
            // 如果不存在,就调用外部注入的 resolver 函数解析,并返回 module 对象
            importedModule, err = resolve(importEntry.ModuleName)
            if err != nil {
                return err
            }
            // 将导入的 module 保存起来
            modules[importEntry.ModuleName] = importedModule
        }

        if importedModule.Export == nil {
            return ErrNoExportsInImportedModule
        }
        // 判断 导入的module 中是否暴露了 importEntry.FieldName(本module 需要调用的方法)
        exportEntry, ok := importedModule.Export.Entries[importEntry.FieldName]
        if !ok {
            return ExportNotFoundError{importEntry.ModuleName, importEntry.FieldName}
        }
        // 判断 待导入函数类型, 与被导入模块的函数类型 是否一致
        if exportEntry.Kind != importEntry.Type.Kind() {
            return KindMismatchError{
                FieldName:  importEntry.FieldName,
                ModuleName: importEntry.ModuleName,
                Import:     importEntry.Type.Kind(),
                Export:     exportEntry.Kind,
            }
        }

        index := exportEntry.Index
        switch exportEntry.Kind {
        case ExternalFunction:
            // 根据 exportEntry 对应的 functionIndex ,获取对应的 Function 类型
            fn := importedModule.GetFunction(int(index))
            if fn == nil {
                return InvalidFunctionIndexError(index)
            }
            
            importIndex := importEntry.Type.(FuncImport).Type
            // 下面就判断 待带入的function  和 别导入的 function 的类型是否一致
            // 比较参数以及返回值长度
            if len(fn.Sig.ReturnTypes) != len(module.Types.Entries[importIndex].ReturnTypes) || len(fn.Sig.ParamTypes) != len(module.Types.Entries[importIndex].ParamTypes) {
                return InvalidImportError{importEntry.ModuleName, importEntry.FieldName, importIndex}
            }
            // 比较返回值类型
            for i, typ := range fn.Sig.ReturnTypes {
                if typ != module.Types.Entries[importIndex].ReturnTypes[i] {
                    return InvalidImportError{importEntry.ModuleName, importEntry.FieldName, importIndex}
                }
            }
            // 比较参数类型
            for i, typ := range fn.Sig.ParamTypes {
                if typ != module.Types.Entries[importIndex].ParamTypes[i] {
                    return InvalidImportError{importEntry.ModuleName, importEntry.FieldName, importIndex}
                }
            }
            // 将 Function 对象(被导入的函数),添加到 module 的 FunctionIndexSpace 数组中
            module.FunctionIndexSpace = append(module.FunctionIndexSpace, *fn)
            // 保存 Function 的函数体
            module.Code.Bodies = append(module.Code.Bodies, *fn.Body)
            // 将 Function 对象保存到 module 的 import.Funcs 数组中
            module.imports.Funcs = append(module.imports.Funcs, funcs)
            funcs++
        case ExternalGlobal:
            // todo ...
            glb := importedModule.GetGlobal(int(index))
            if glb == nil {
                return InvalidGlobalIndexError(index)
            }
            if glb.Type.Mutable {
                return ErrImportMutGlobal
            }
            module.GlobalIndexSpace = append(module.GlobalIndexSpace, *glb)
            module.imports.Globals++

            // In both cases below, index should be always 0 (according to the MVP)
            // We check it against the length of the index space anyway.
        case ExternalTable:
            if int(index) >= len(importedModule.TableIndexSpace) {
                return InvalidTableIndexError(index)
            }
            module.TableIndexSpace[0] = importedModule.TableIndexSpace[0]
            module.imports.Tables++
        case ExternalMemory:
            if int(index) >= len(importedModule.LinearMemoryIndexSpace) {
                return InvalidLinearMemoryIndexError(index)
            }
            module.LinearMemoryIndexSpace[0] = importedModule.LinearMemoryIndexSpace[0]
            module.imports.Memories++
        default:
            return InvalidExternalError(exportEntry.Kind)
        }
    }
    return nil
}

  • populateFunctions
// 函数索引空间索引所有导入和内部定义的函数定义
func (m *Module) populateFunctions() error {
    
    ...
    
    // 给内部定义的 func 构造 fn
    // Add the functions from the wasm itself to the function list
    numImports := len(m.FunctionIndexSpace)
    for codeIndex, typeIndex := range m.Function.Types {
        if int(typeIndex) >= len(m.Types.Entries) {
            return InvalidFunctionIndexError(typeIndex)
        }
        // Create the main function structure
        fn := Function{
            Sig:  &m.Types.Entries[typeIndex],
            Body: &m.Code.Bodies[codeIndex],
            Name: names[uint32(codeIndex+numImports)], // Add the name string if we have it
        }
        m.FunctionIndexSpace = append(m.FunctionIndexSpace, fn)
    }

    funcs := make([]uint32, 0, len(m.Function.Types)+len(m.imports.Funcs))
    funcs = append(funcs, m.imports.Funcs...)
    funcs = append(funcs, m.Function.Types...)
    m.Function.Types = funcs
    return nil
}
  • 新建一个 VM

先看 vm 类型

// VM is the execution context for executing WebAssembly bytecode.
type VM struct {
    ctx context // 执行上下文
        type context struct {
            stack   []uint64  // 栈深度
            locals  []uint64      // 局部变量
            code    []byte  // 函数的字节码
            asm     []asmBlock
            pc      int64       // 当前的字节码 index
            curFunc int64       // 当前函数在 funcs 的index
        }

    module  *wasm.Module  
    globals []uint64
    memory  []byte
    funcs   []function // 函数数组 compiledFunction or goFunction

    funcTable [256]func()  // 指令集,对应的解析函数

    // RecoverPanic controls whether the `ExecCode` method
    // recovers from a panic and returns it as an error
    // instead.
    // A panic can occur either when executing an invalid VM
    // or encountering an invalid instruction, e.g. `unreachable`.
    RecoverPanic bool

    abort bool // Flag for host functions to terminate execution

    nativeBackend *nativeCompiler
}

// 通过 module 对象和 options 构造一个 vm
func NewVM(module *wasm.Module, opts ...VMOption) (*VM, error) {
    var vm VM
    var options config
    for _, opt := range opts {
        opt(&options)
    }
    if module.Memory != nil && len(module.Memory.Entries) != 0 {
        if len(module.Memory.Entries) > 1 {
            return nil, ErrMultipleLinearMemories
        }
        vm.memory = make([]byte, uint(module.Memory.Entries[0].Limits.Initial)*wasmPageSize)
        copy(vm.memory, module.LinearMemoryIndexSpace[0])
    }

    vm.funcs = make([]function, len(module.FunctionIndexSpace))
    vm.globals = make([]uint64, len(module.GlobalIndexSpace))
    vm.newFuncTable()
    vm.module = module

    nNatives := 0
    for i, fn := range module.FunctionIndexSpace {
    
        // 如果是 import 的原生 golang 方法,使用 goFunction 处理
        if fn.IsHost() {
            vm.funcs[i] = goFunction{
                typ: fn.Host.Type(),
                val: fn.Host,
            }
            nNatives++
            continue
        }
        // 将function拆卸并封装成新的结构
        disassembly, err := disasm.NewDisassembly(fn, module)
        if err != nil {
            return nil, err
        }

        totalLocalVars := 0
        totalLocalVars += len(fn.Sig.ParamTypes)
        for _, entry := range fn.Body.Locals {
            totalLocalVars += int(entry.Count)
        }
        // 编译 字节码
        code, meta := compile.Compile(disassembly.Code)
        vm.funcs[i] = compiledFunction{
            codeMeta:       meta,
            code:           code,
            branchTables:   meta.BranchTables,
            maxDepth:       disassembly.MaxDepth,
            totalLocalVars: totalLocalVars,
            args:           len(fn.Sig.ParamTypes),
            returns:        len(fn.Sig.ReturnTypes) != 0,
        }
    }

    ...
    
    return &vm, nil
}

Interpreter

下面执行代码的过程,即是翻译代码的过程

// fnIndex 函数的index, args 是该函数的参数
func (vm *VM) ExecCode(fnIndex int64, args ...uint64) (rtrn interface{}, err error) {
    ...
    
    if int(fnIndex) > len(vm.funcs) {
        return nil, InvalidFunctionIndexError(fnIndex)
    }
    if len(vm.module.GetFunction(int(fnIndex)).Sig.ParamTypes) != len(args) {
        return nil, ErrInvalidArgumentCount
    }
    compiled, ok := vm.funcs[fnIndex].(compiledFunction)
    if !ok {
        panic(fmt.Sprintf("exec: function at index %d is not a compiled function", fnIndex))
    }

    depth := compiled.maxDepth + 1
    // 初始化执行栈
    if cap(vm.ctx.stack) < depth {
        vm.ctx.stack = make([]uint64, 0, depth)
    } else {
        vm.ctx.stack = vm.ctx.stack[:0]
    }

    vm.ctx.locals = make([]uint64, compiled.totalLocalVars)
    vm.ctx.pc = 0
    vm.ctx.code = compiled.code
    vm.ctx.asm = compiled.asm
    vm.ctx.curFunc = fnIndex
    // 给函数的参数赋值
    for i, arg := range args {
        vm.ctx.locals[i] = arg
    }

    res := vm.execCode(compiled)
    if compiled.returns {
        rtrnType := vm.module.GetFunction(int(fnIndex)).Sig.ReturnTypes[0]
        switch rtrnType {
        case wasm.ValueTypeI32:
            rtrn = uint32(res)
        case wasm.ValueTypeI64:
            rtrn = uint64(res)
        case wasm.ValueTypeF32:
            rtrn = math.Float32frombits(uint32(res))
        case wasm.ValueTypeF64:
            rtrn = math.Float64frombits(res)
        default:
            return nil, InvalidReturnTypeError(rtrnType)
        }
    }

    return rtrn, nil
}


func (vm *VM) execCode(compiled compiledFunction) uint64 {
outer:
    for int(vm.ctx.pc) < len(vm.ctx.code) && !vm.abort {
        op := vm.ctx.code[vm.ctx.pc]
        vm.ctx.pc++
        switch op {
        // 解析到 return 指令的时候,退出循环
        case ops.Return:
            break outer
            
        // 省略一些不常用的case
        ...
        
        default:
            // 大部分会走这个case
            vm.funcTable[op]()
        }
    }

    if compiled.returns {
        //如果有返回值,从栈中取出返回
        return vm.ctx.stack[len(vm.ctx.stack)-1]
    }
    return 0
}

funcTable [256]func() 的初始化 ,一个指令(Op)对应一个解析方法, 看看 Op 的结构

// Op describes a WASM operator.
type Op struct {
    Code byte   // The single-byte opcode
    Name string // 该操作的名称

    // Whether this operator is polymorphic.
    // A polymorphic operator has a variable arity. call, call_indirect, and
    // drop are examples of polymorphic operators.
    Polymorphic bool // 是否是动态的, true:比如一些逻辑控制语句, 还有 get/setlocal 等
    Args        []wasm.ValueType // 该指令需要的参数类型(数量)(会从栈中pop出来)
    Returns     wasm.ValueType   // 返回的参数类型
}

func (vm *VM) newFuncTable() {
    vm.funcTable[ops.I32Clz] = vm.i32Clz
    vm.funcTable[ops.I32Ctz] = vm.i32Ctz
    vm.funcTable[ops.I32Popcnt] = vm.i32Popcnt
    
    vm.funcTable[ops.I32Add] = vm.i32Add
    vm.funcTable[ops.I32Sub] = vm.i32Sub
    vm.funcTable[ops.I32Mul] = vm.i32Mul

    ....
    ....

    vm.funcTable[ops.Drop] = vm.drop
    vm.funcTable[ops.Select] = vm.selectOp

    vm.funcTable[ops.GetLocal] = vm.getLocal
    vm.funcTable[ops.SetLocal] = vm.setLocal
    vm.funcTable[ops.TeeLocal] = vm.teeLocal
    vm.funcTable[ops.GetGlobal] = vm.getGlobal
    vm.funcTable[ops.SetGlobal] = vm.setGlobal

    vm.funcTable[ops.Unreachable] = vm.unreachable
    vm.funcTable[ops.Nop] = vm.nop

    vm.funcTable[ops.Call] = vm.call
    vm.funcTable[ops.CallIndirect] = vm.callIndirect
}


例如

// 从栈中pop 两个uint32 出来,相加后在push 到栈中
func (vm *VM) i32Add() {
    vm.pushUint32(vm.popUint32() + vm.popUint32())
}

这里再讲一下 vm.call

func (vm *VM) call() {
    index := vm.fetchUint32()
    // 这里会从 funcs 数组里取出 Function(or goFunction) 对象,调用call
    vm.funcs[index].call(vm, int64(index))
}

// goFunction 利用反射机制,执行函数
func (fn goFunction) call(vm *VM, index int64) {
    // numIn = # of call inputs + vm, as the function expects
    // an additional *VM argument
    numIn := fn.typ.NumIn()
    args := make([]reflect.Value, numIn)
    proc := NewProcess(vm)

    // 第一个参数必须是 *Process 类型
    if reflect.ValueOf(proc).Kind() != fn.typ.In(0).Kind() {
        panic(fmt.Sprintf("exec: the first argument of a host function was %s, expected %s", fn.typ.In(0).Kind(), reflect.ValueOf(vm).Kind()))
    }
    args[0] = reflect.ValueOf(proc)
    // 给函数的参数赋值
    for i := numIn - 1; i >= 1; i-- {
        val := reflect.New(fn.typ.In(i)).Elem()
        raw := vm.popUint64()
        kind := fn.typ.In(i).Kind()

        switch kind {
        case reflect.Float64, reflect.Float32:
            val.SetFloat(math.Float64frombits(raw))
        case reflect.Uint32, reflect.Uint64:
            val.SetUint(raw)
        case reflect.Int32, reflect.Int64:
            val.SetInt(int64(raw))
        default:
            panic(fmt.Sprintf("exec: args %d invalid kind=%v", i, kind))
        }

        args[i] = val
    }
    // 执行函数
    rtrns := fn.val.Call(args)
    // 将返回值 push 到栈中
    for i, out := range rtrns {
        kind := out.Kind()
        switch kind {
        case reflect.Float64, reflect.Float32:
            vm.pushFloat64(out.Float())
        case reflect.Uint32, reflect.Uint64:
            vm.pushUint64(out.Uint())
        case reflect.Int32, reflect.Int64:
            vm.pushInt64(out.Int())
        default:
            panic(fmt.Sprintf("exec: return value %d invalid kind=%v", i, kind))
        }
    }
}

// 非原生函数实现
func (compiled compiledFunction) call(vm *VM, index int64) {

    newStack := make([]uint64, 0, compiled.maxDepth+1)
    locals := make([]uint64, compiled.totalLocalVars)
    // 给参数赋值
    for i := compiled.args - 1; i >= 0; i-- {
        locals[i] = vm.popUint64()
    }

    //保存执行上下文
    prevCtxt := vm.ctx
    // 新建执行上下文
    vm.ctx = context{
        stack:   newStack,
        locals:  locals,
        code:    compiled.code,
        asm:     compiled.asm,
        pc:      0,
        curFunc: index,
    }
    
    rtrn := vm.execCode(compiled)

    //被调用函数执行完了,恢复上下文
    vm.ctx = prevCtxt
    
    if compiled.returns {
        // 把返回值push到栈中
        vm.pushUint64(rtrn)
    }
}

你可能感兴趣的:(以 Wagon 为例, Golang 解析 wasm)