wasm 原理
wasm
指令的解析,其实都是 入栈,出栈的操作, 它是一个基于栈的虚拟机,比如
get_local 0
, 它就是获取函数的第一个参数,并把它放到栈里.
i32.const 42
就是把一个 42(int32)
放入栈中.
i32.add
就是从栈中取出两个数,相加后再放回栈里。
下面看一个具体的例子
cpp
如下extern "C" { int large(int num) { if (num > 10) { num = num + 12; } else { num = num + 100; } return num; } }
指定
Optimization Level
-o3 优化后 编译后的wast
如下(table 0 anyfunc) (memory $0 1) (export "memory" (memory $0)) (export "large" (func $large)) (func $large (; 0 ;) (param $0 i32) (result i32) (i32.add // 8 (select // 6 (i32.const 12) // 1 (i32.const 100) // 2 (i32.gt_s // 5 (get_local $0) // 3 (i32.const 10) // 4 ) ) (get_local $0) // 7 ) ) )
指令解析 $0 为函数输入参数
-
(i32.const 12)
将 12push
到stack
;stack=>[12]
-
(i32.const 100)
将 100push
到stack
;stack=>[100, 12]
-
(get_local $0)
将 $0 (参数) 从local
中读取,并push
到stack
;stack=>[$0, 100, 12]
-
(i32.const 10)
将 10push
到stack
;stack=>[10, $0, 100, 12]
-
(i32.gt_s ()())
从stack
中pop
至v1, v2
,并比较大小,v1 > v2
则push
1 到stack
, 反之push 0
;stack=>[1, 100, 12]
-
(select ()()() )
从stack
中pop 1 => v3
,从stack
中pop 100 => v2
, 从 stack 中pop 12 => v1
,if v3 为 1(true), 将v2 push
到stack
, 反之 将v1 push
到stack
;stack=>[100]
-
(get_local $0)
将 $0 (参数) 从local
中读取,并push
到stack
;stack=>[$0, 100]
-
(i32.add ()())
从stack 中pop 两个数,相加后push 到stack;stack=>[108]
- 返回结果 108
做一个 webassembly
的虚拟机主要分两块, compile
和 Interpreter
. 我们先看 compile
模块.
Compile
- 编译主要是对 wasm 结构进行解析, 首先看看 module 对象,这是个核心类, wasm 就是解析到这个对象
type Module struct {
Version uint32 // wasm 的版本
Sections []Section // wasm 中所有的section 数组, 一个 wasm 文件, 就是由version 和多个 section 组成
Types *SectionTypes // wasm 中所有的函数描述
Import *SectionImports // wasm 中导入的函数
Function *SectionFunctions // wasm 中声明的函数,每个函数对应一个index 指向 Types内的函数类型
Table *SectionTables
Memory *SectionMemories
Global *SectionGlobals
Export *SectionExports // wasm 中导出的函数描述
Start *SectionStartFunction // 需要立刻执行的函数
Elements *SectionElements // 定义在 table 中的元素
Code *SectionCode // 该 module 的所有函数信息数据
Data *SectionData // 数据区, 比如一些字符串等数据, 会放在Data里, 用 offset 标记
Customs []*SectionCustom
// The function index space of the module
FunctionIndexSpace []Function // wasm 中所有的函数包括 SectionImports 和 SectionFunctions,函数中的 type 指向 Types 中的类型
GlobalIndexSpace []GlobalEntry
// function indices into the global function space
// the limit of each table is its capacity (cap)
TableIndexSpace [][]uint32
LinearMemoryIndexSpace [][]byte // 线性内存, Data 数据会存放在这里
imports struct {
Funcs []uint32 // 导入的函数
Globals int
Tables int
Memories int
}
}
- 读取 wasm 文件
// 从本地读取一个 wasm 文件,并返回 module
func ReadModule(r io.Reader, resolvePath ResolveFunc) (*Module, error) {
// 通过解析 二进制 wasm 文件,将数据解析道对应的 section 中去
m, err := DecodeModule(r)
...
if m.Import != nil && resolvePath != nil {
if m.Code == nil {
m.Code = &SectionCode{}
}
// 解析 导入 的 module
err := m.resolveImports(resolvePath)
}
for _, fn := range []func() error{
m.populateGlobals,
// 将内部函数转化为 Function 对象,并将 导入的函数也 一并添加到 FunctionIndexSpace 中
m.populateFunctions,
m.populateTables,
// 将 m.Data 放到线性内存中
m.populateLinearMemory,
} {
if err := fn(); err != nil {
return nil, err
}
}
return m, nil
}
func DecodeModule(r io.Reader) (*Module, error) {
reader := &readpos.ReadPos{
R: r,
CurPos: 0,
}
m := &Module{}
...
err = newSectionsReader(m).readSections(reader)
return m, nil
}
-
DecodeModule
新建sectionReader
, 并调用readSections
func (s *sectionsReader) readSections(r *readpos.ReadPos) error {
for {
// 循环读取section,知道读完
done, err := s.readSection(r)
switch {
case err != nil:
return err
case done:
return nil
}
}
}
// 从reader 中读取一个有效的 section. The first return value is true if and only if
// the module has been completely read.
func (sr *sectionsReader) readSection(r *readpos.ReadPos) (bool, error) {
m := sr.m
logger.Println("Reading section ID")
// 从 reader 中读取一个字节
id, err := r.ReadByte()
...
s := RawSection{ID: SectionID(id)}
logger.Println("Reading payload length")
// 读取实际 数据
payloadDataLen, err := leb128.ReadVarUint32(r)
if err != nil {
return false, err
}
logger.Printf("Section payload length: %d", payloadDataLen)
s.Start = r.CurPos
sectionBytes := new(bytes.Buffer)
sectionBytes.Grow(int(getInitialCap(payloadDataLen)))
sectionReader := io.LimitReader(io.TeeReader(r, sectionBytes), int64(payloadDataLen))
// 判断section 的类型,并将该类型空的 struct 赋值给 module 对应的属性
var sec Section
switch s.ID {
case SectionIDCustom:
logger.Println("section custom")
cs := &SectionCustom{}
m.Customs = append(m.Customs, cs)
sec = cs
case SectionIDType:
logger.Println("section type")
m.Types = &SectionTypes{}
sec = m.Types
case SectionIDImport:
logger.Println("section import")
m.Import = &SectionImports{}
sec = m.Import
case SectionIDFunction:
logger.Println("section function")
m.Function = &SectionFunctions{}
sec = m.Function
case SectionIDTable:
logger.Println("section table")
m.Table = &SectionTables{}
sec = m.Table
case SectionIDMemory:
logger.Println("section memory")
m.Memory = &SectionMemories{}
sec = m.Memory
case SectionIDGlobal:
logger.Println("section global")
m.Global = &SectionGlobals{}
sec = m.Global
case SectionIDExport:
logger.Println("section export")
m.Export = &SectionExports{}
sec = m.Export
case SectionIDStart:
logger.Println("section start")
m.Start = &SectionStartFunction{}
sec = m.Start
case SectionIDElement:
logger.Println("section element")
m.Elements = &SectionElements{}
sec = m.Elements
case SectionIDCode:
logger.Println("section code")
m.Code = &SectionCode{}
sec = m.Code
case SectionIDData:
logger.Println("section data")
m.Data = &SectionData{}
sec = m.Data
default:
return false, InvalidSectionIDError(s.ID)
}
// 从reader 中读取数据,存入 section (对应到 module 的某个变量中)
err = sec.ReadPayload(sectionReader)
if err != nil {
logger.Println(err)
return false, err
}
s.End = r.CurPos
s.Bytes = sectionBytes.Bytes()
// 将 raw s 保存到 对应的 xxxSection 中
*sec.GetRawSection() = s
...
// 保存 section
m.Sections = append(m.Sections, sec)
return false, nil
}
- 将文件读取到 module 中后,还需要加载 import 的模块
// 解析import 的函数
func (module *Module) resolveImports(resolve ResolveFunc) error {
if module.Import == nil {
return nil
}
modules := make(map[string]*Module)
var funcs uint32
// 遍历 module.Import 下的 ”入口“
for _, importEntry := range module.Import.Entries {
importedModule, ok := modules[importEntry.ModuleName]
if !ok {
var err error
// 如果不存在,就调用外部注入的 resolver 函数解析,并返回 module 对象
importedModule, err = resolve(importEntry.ModuleName)
if err != nil {
return err
}
// 将导入的 module 保存起来
modules[importEntry.ModuleName] = importedModule
}
if importedModule.Export == nil {
return ErrNoExportsInImportedModule
}
// 判断 导入的module 中是否暴露了 importEntry.FieldName(本module 需要调用的方法)
exportEntry, ok := importedModule.Export.Entries[importEntry.FieldName]
if !ok {
return ExportNotFoundError{importEntry.ModuleName, importEntry.FieldName}
}
// 判断 待导入函数类型, 与被导入模块的函数类型 是否一致
if exportEntry.Kind != importEntry.Type.Kind() {
return KindMismatchError{
FieldName: importEntry.FieldName,
ModuleName: importEntry.ModuleName,
Import: importEntry.Type.Kind(),
Export: exportEntry.Kind,
}
}
index := exportEntry.Index
switch exportEntry.Kind {
case ExternalFunction:
// 根据 exportEntry 对应的 functionIndex ,获取对应的 Function 类型
fn := importedModule.GetFunction(int(index))
if fn == nil {
return InvalidFunctionIndexError(index)
}
importIndex := importEntry.Type.(FuncImport).Type
// 下面就判断 待带入的function 和 别导入的 function 的类型是否一致
// 比较参数以及返回值长度
if len(fn.Sig.ReturnTypes) != len(module.Types.Entries[importIndex].ReturnTypes) || len(fn.Sig.ParamTypes) != len(module.Types.Entries[importIndex].ParamTypes) {
return InvalidImportError{importEntry.ModuleName, importEntry.FieldName, importIndex}
}
// 比较返回值类型
for i, typ := range fn.Sig.ReturnTypes {
if typ != module.Types.Entries[importIndex].ReturnTypes[i] {
return InvalidImportError{importEntry.ModuleName, importEntry.FieldName, importIndex}
}
}
// 比较参数类型
for i, typ := range fn.Sig.ParamTypes {
if typ != module.Types.Entries[importIndex].ParamTypes[i] {
return InvalidImportError{importEntry.ModuleName, importEntry.FieldName, importIndex}
}
}
// 将 Function 对象(被导入的函数),添加到 module 的 FunctionIndexSpace 数组中
module.FunctionIndexSpace = append(module.FunctionIndexSpace, *fn)
// 保存 Function 的函数体
module.Code.Bodies = append(module.Code.Bodies, *fn.Body)
// 将 Function 对象保存到 module 的 import.Funcs 数组中
module.imports.Funcs = append(module.imports.Funcs, funcs)
funcs++
case ExternalGlobal:
// todo ...
glb := importedModule.GetGlobal(int(index))
if glb == nil {
return InvalidGlobalIndexError(index)
}
if glb.Type.Mutable {
return ErrImportMutGlobal
}
module.GlobalIndexSpace = append(module.GlobalIndexSpace, *glb)
module.imports.Globals++
// In both cases below, index should be always 0 (according to the MVP)
// We check it against the length of the index space anyway.
case ExternalTable:
if int(index) >= len(importedModule.TableIndexSpace) {
return InvalidTableIndexError(index)
}
module.TableIndexSpace[0] = importedModule.TableIndexSpace[0]
module.imports.Tables++
case ExternalMemory:
if int(index) >= len(importedModule.LinearMemoryIndexSpace) {
return InvalidLinearMemoryIndexError(index)
}
module.LinearMemoryIndexSpace[0] = importedModule.LinearMemoryIndexSpace[0]
module.imports.Memories++
default:
return InvalidExternalError(exportEntry.Kind)
}
}
return nil
}
- populateFunctions
// 函数索引空间索引所有导入和内部定义的函数定义
func (m *Module) populateFunctions() error {
...
// 给内部定义的 func 构造 fn
// Add the functions from the wasm itself to the function list
numImports := len(m.FunctionIndexSpace)
for codeIndex, typeIndex := range m.Function.Types {
if int(typeIndex) >= len(m.Types.Entries) {
return InvalidFunctionIndexError(typeIndex)
}
// Create the main function structure
fn := Function{
Sig: &m.Types.Entries[typeIndex],
Body: &m.Code.Bodies[codeIndex],
Name: names[uint32(codeIndex+numImports)], // Add the name string if we have it
}
m.FunctionIndexSpace = append(m.FunctionIndexSpace, fn)
}
funcs := make([]uint32, 0, len(m.Function.Types)+len(m.imports.Funcs))
funcs = append(funcs, m.imports.Funcs...)
funcs = append(funcs, m.Function.Types...)
m.Function.Types = funcs
return nil
}
- 新建一个 VM
先看 vm 类型
// VM is the execution context for executing WebAssembly bytecode.
type VM struct {
ctx context // 执行上下文
type context struct {
stack []uint64 // 栈深度
locals []uint64 // 局部变量
code []byte // 函数的字节码
asm []asmBlock
pc int64 // 当前的字节码 index
curFunc int64 // 当前函数在 funcs 的index
}
module *wasm.Module
globals []uint64
memory []byte
funcs []function // 函数数组 compiledFunction or goFunction
funcTable [256]func() // 指令集,对应的解析函数
// RecoverPanic controls whether the `ExecCode` method
// recovers from a panic and returns it as an error
// instead.
// A panic can occur either when executing an invalid VM
// or encountering an invalid instruction, e.g. `unreachable`.
RecoverPanic bool
abort bool // Flag for host functions to terminate execution
nativeBackend *nativeCompiler
}
// 通过 module 对象和 options 构造一个 vm
func NewVM(module *wasm.Module, opts ...VMOption) (*VM, error) {
var vm VM
var options config
for _, opt := range opts {
opt(&options)
}
if module.Memory != nil && len(module.Memory.Entries) != 0 {
if len(module.Memory.Entries) > 1 {
return nil, ErrMultipleLinearMemories
}
vm.memory = make([]byte, uint(module.Memory.Entries[0].Limits.Initial)*wasmPageSize)
copy(vm.memory, module.LinearMemoryIndexSpace[0])
}
vm.funcs = make([]function, len(module.FunctionIndexSpace))
vm.globals = make([]uint64, len(module.GlobalIndexSpace))
vm.newFuncTable()
vm.module = module
nNatives := 0
for i, fn := range module.FunctionIndexSpace {
// 如果是 import 的原生 golang 方法,使用 goFunction 处理
if fn.IsHost() {
vm.funcs[i] = goFunction{
typ: fn.Host.Type(),
val: fn.Host,
}
nNatives++
continue
}
// 将function拆卸并封装成新的结构
disassembly, err := disasm.NewDisassembly(fn, module)
if err != nil {
return nil, err
}
totalLocalVars := 0
totalLocalVars += len(fn.Sig.ParamTypes)
for _, entry := range fn.Body.Locals {
totalLocalVars += int(entry.Count)
}
// 编译 字节码
code, meta := compile.Compile(disassembly.Code)
vm.funcs[i] = compiledFunction{
codeMeta: meta,
code: code,
branchTables: meta.BranchTables,
maxDepth: disassembly.MaxDepth,
totalLocalVars: totalLocalVars,
args: len(fn.Sig.ParamTypes),
returns: len(fn.Sig.ReturnTypes) != 0,
}
}
...
return &vm, nil
}
Interpreter
下面执行代码的过程,即是翻译代码的过程
// fnIndex 函数的index, args 是该函数的参数
func (vm *VM) ExecCode(fnIndex int64, args ...uint64) (rtrn interface{}, err error) {
...
if int(fnIndex) > len(vm.funcs) {
return nil, InvalidFunctionIndexError(fnIndex)
}
if len(vm.module.GetFunction(int(fnIndex)).Sig.ParamTypes) != len(args) {
return nil, ErrInvalidArgumentCount
}
compiled, ok := vm.funcs[fnIndex].(compiledFunction)
if !ok {
panic(fmt.Sprintf("exec: function at index %d is not a compiled function", fnIndex))
}
depth := compiled.maxDepth + 1
// 初始化执行栈
if cap(vm.ctx.stack) < depth {
vm.ctx.stack = make([]uint64, 0, depth)
} else {
vm.ctx.stack = vm.ctx.stack[:0]
}
vm.ctx.locals = make([]uint64, compiled.totalLocalVars)
vm.ctx.pc = 0
vm.ctx.code = compiled.code
vm.ctx.asm = compiled.asm
vm.ctx.curFunc = fnIndex
// 给函数的参数赋值
for i, arg := range args {
vm.ctx.locals[i] = arg
}
res := vm.execCode(compiled)
if compiled.returns {
rtrnType := vm.module.GetFunction(int(fnIndex)).Sig.ReturnTypes[0]
switch rtrnType {
case wasm.ValueTypeI32:
rtrn = uint32(res)
case wasm.ValueTypeI64:
rtrn = uint64(res)
case wasm.ValueTypeF32:
rtrn = math.Float32frombits(uint32(res))
case wasm.ValueTypeF64:
rtrn = math.Float64frombits(res)
default:
return nil, InvalidReturnTypeError(rtrnType)
}
}
return rtrn, nil
}
func (vm *VM) execCode(compiled compiledFunction) uint64 {
outer:
for int(vm.ctx.pc) < len(vm.ctx.code) && !vm.abort {
op := vm.ctx.code[vm.ctx.pc]
vm.ctx.pc++
switch op {
// 解析到 return 指令的时候,退出循环
case ops.Return:
break outer
// 省略一些不常用的case
...
default:
// 大部分会走这个case
vm.funcTable[op]()
}
}
if compiled.returns {
//如果有返回值,从栈中取出返回
return vm.ctx.stack[len(vm.ctx.stack)-1]
}
return 0
}
funcTable [256]func()
的初始化 ,一个指令(Op)
对应一个解析方法, 看看 Op
的结构
// Op describes a WASM operator.
type Op struct {
Code byte // The single-byte opcode
Name string // 该操作的名称
// Whether this operator is polymorphic.
// A polymorphic operator has a variable arity. call, call_indirect, and
// drop are examples of polymorphic operators.
Polymorphic bool // 是否是动态的, true:比如一些逻辑控制语句, 还有 get/setlocal 等
Args []wasm.ValueType // 该指令需要的参数类型(数量)(会从栈中pop出来)
Returns wasm.ValueType // 返回的参数类型
}
func (vm *VM) newFuncTable() {
vm.funcTable[ops.I32Clz] = vm.i32Clz
vm.funcTable[ops.I32Ctz] = vm.i32Ctz
vm.funcTable[ops.I32Popcnt] = vm.i32Popcnt
vm.funcTable[ops.I32Add] = vm.i32Add
vm.funcTable[ops.I32Sub] = vm.i32Sub
vm.funcTable[ops.I32Mul] = vm.i32Mul
....
....
vm.funcTable[ops.Drop] = vm.drop
vm.funcTable[ops.Select] = vm.selectOp
vm.funcTable[ops.GetLocal] = vm.getLocal
vm.funcTable[ops.SetLocal] = vm.setLocal
vm.funcTable[ops.TeeLocal] = vm.teeLocal
vm.funcTable[ops.GetGlobal] = vm.getGlobal
vm.funcTable[ops.SetGlobal] = vm.setGlobal
vm.funcTable[ops.Unreachable] = vm.unreachable
vm.funcTable[ops.Nop] = vm.nop
vm.funcTable[ops.Call] = vm.call
vm.funcTable[ops.CallIndirect] = vm.callIndirect
}
例如
// 从栈中pop 两个uint32 出来,相加后在push 到栈中
func (vm *VM) i32Add() {
vm.pushUint32(vm.popUint32() + vm.popUint32())
}
这里再讲一下 vm.call
func (vm *VM) call() {
index := vm.fetchUint32()
// 这里会从 funcs 数组里取出 Function(or goFunction) 对象,调用call
vm.funcs[index].call(vm, int64(index))
}
// goFunction 利用反射机制,执行函数
func (fn goFunction) call(vm *VM, index int64) {
// numIn = # of call inputs + vm, as the function expects
// an additional *VM argument
numIn := fn.typ.NumIn()
args := make([]reflect.Value, numIn)
proc := NewProcess(vm)
// 第一个参数必须是 *Process 类型
if reflect.ValueOf(proc).Kind() != fn.typ.In(0).Kind() {
panic(fmt.Sprintf("exec: the first argument of a host function was %s, expected %s", fn.typ.In(0).Kind(), reflect.ValueOf(vm).Kind()))
}
args[0] = reflect.ValueOf(proc)
// 给函数的参数赋值
for i := numIn - 1; i >= 1; i-- {
val := reflect.New(fn.typ.In(i)).Elem()
raw := vm.popUint64()
kind := fn.typ.In(i).Kind()
switch kind {
case reflect.Float64, reflect.Float32:
val.SetFloat(math.Float64frombits(raw))
case reflect.Uint32, reflect.Uint64:
val.SetUint(raw)
case reflect.Int32, reflect.Int64:
val.SetInt(int64(raw))
default:
panic(fmt.Sprintf("exec: args %d invalid kind=%v", i, kind))
}
args[i] = val
}
// 执行函数
rtrns := fn.val.Call(args)
// 将返回值 push 到栈中
for i, out := range rtrns {
kind := out.Kind()
switch kind {
case reflect.Float64, reflect.Float32:
vm.pushFloat64(out.Float())
case reflect.Uint32, reflect.Uint64:
vm.pushUint64(out.Uint())
case reflect.Int32, reflect.Int64:
vm.pushInt64(out.Int())
default:
panic(fmt.Sprintf("exec: return value %d invalid kind=%v", i, kind))
}
}
}
// 非原生函数实现
func (compiled compiledFunction) call(vm *VM, index int64) {
newStack := make([]uint64, 0, compiled.maxDepth+1)
locals := make([]uint64, compiled.totalLocalVars)
// 给参数赋值
for i := compiled.args - 1; i >= 0; i-- {
locals[i] = vm.popUint64()
}
//保存执行上下文
prevCtxt := vm.ctx
// 新建执行上下文
vm.ctx = context{
stack: newStack,
locals: locals,
code: compiled.code,
asm: compiled.asm,
pc: 0,
curFunc: index,
}
rtrn := vm.execCode(compiled)
//被调用函数执行完了,恢复上下文
vm.ctx = prevCtxt
if compiled.returns {
// 把返回值push到栈中
vm.pushUint64(rtrn)
}
}