以太坊源码阅读-数据结构篇-Trie树

encoding.go主要处理trie树的三种编码格式的相互转换工作,三种格式包括:KEYBYTES/HEX/COMPACT encoding。
// hexToKeybytes packs a sequence of hex nibbles back into key bytes, two
// nibbles per byte. A trailing terminator nibble, as detected by hasTerm, is
// dropped before packing. It panics when the remaining nibble count is odd.
func hexToKeybytes(hex []byte) []byte {
	nibbles := hex
	if hasTerm(nibbles) {
		nibbles = nibbles[:len(nibbles)-1]
	}
	if len(nibbles)%2 == 1 {
		panic("can't convert hex key of odd length")
	}
	packed := make([]byte, len(nibbles)/2)
	decodeNibbles(nibbles, packed)
	return packed
}

// compactToHex expands a compact-encoded key into hex nibbles.
//
// The first nibble of the expanded input is a flag nibble: a value below 2
// means the key carries no terminator, and its low bit is set when the key
// holds an odd number of nibbles. An empty input is returned unchanged.
func compactToHex(compact []byte) []byte {
	// Without this guard an empty key expands to just the terminator, and the
	// slice expression at the end would run past the end of base and panic.
	if len(compact) == 0 {
		return compact
	}
	base := keybytesToHex(compact)
	// keybytesToHex always appends a terminator; drop it when the flag nibble
	// says the key is not terminated.
	if base[0] < 2 {
		base = base[:len(base)-1]
	}
	// Skip the flag nibble, and also the padding nibble when the key had an
	// even number of nibbles (low flag bit clear).
	chop := 2 - base[0]&1
	return base[chop:]
}

// keybytesToHex expands key bytes into hex nibbles, high nibble first, and
// appends the terminator nibble 16. The result always holds len(str)*2+1
// entries, so an empty input yields a slice containing only the terminator.
func keybytesToHex(str []byte) []byte {
	l := len(str)*2 + 1
	nibbles := make([]byte, l)
	for i, b := range str {
		nibbles[i*2] = b / 16
		nibbles[i*2+1] = b % 16
	}
	nibbles[l-1] = 16
	return nibbles
}

// hexToKeybytes packs a sequence of hex nibbles back into key bytes, two
// nibbles per byte. A trailing terminator nibble, as detected by hasTerm, is
// dropped before packing. It panics when the remaining nibble count is odd.
func hexToKeybytes(hex []byte) []byte {
	nibbles := hex
	if hasTerm(nibbles) {
		nibbles = nibbles[:len(nibbles)-1]
	}
	if len(nibbles)%2 == 1 {
		panic("can't convert hex key of odd length")
	}
	packed := make([]byte, len(nibbles)/2)
	decodeNibbles(nibbles, packed)
	return packed
}

// decodeNibbles packs consecutive nibble pairs into bytes, writing the result
// into the caller-supplied slice. The first nibble of each pair becomes the
// high four bits of the output byte.
func decodeNibbles(nibbles []byte, bytes []byte) {
	for ni := 0; ni < len(nibbles); ni += 2 {
		hi, lo := nibbles[ni], nibbles[ni+1]
		bytes[ni/2] = hi<<4 | lo
	}
}

// prefixLen reports how many leading elements a and b have in common.
func prefixLen(a, b []byte) int {
	limit := len(a)
	if len(b) < limit {
		limit = len(b)
	}
	n := 0
	for n < limit && a[n] == b[n] {
		n++
	}
	return n
}

// hasTerm reports whether the hex key s ends with the terminator nibble (16).
// An empty key never has a terminator.
func hasTerm(s []byte) bool {
	if len(s) == 0 {
		return false
	}
	return s[len(s)-1] == 16
}

Trie树主要包含以下几种数据结构
node节点一共包含四种:fullNode(分支节点,branchNode)、shortNode(叶子节点leafNode或扩展节点extensionNode)、hashNode和valueNode。
// node is the interface shared by every kind of trie node.
type node interface {
// fstring renders the node as a string.
// NOTE(review): presumably the argument is an indentation prefix used for
// debug printing — confirm against the implementations.
fstring(string) string
// cache returns the node's cached hash; callers treat a nil hash as
// "not computed yet". NOTE(review): the meaning of the bool result is not
// visible here (presumably a dirty flag) — confirm.
cache() (hashNode, bool)
// canUnload reports whether the node may be dropped from memory given the
// current cache generation and the cache limit.
// NOTE(review): exact eviction rule is defined by the implementations.
canUnload(cachegen, cachelimit uint16) bool
}

// Concrete node kinds stored in the trie.
type (
// fullNode is a branch node with 17 child slots; decodeFull fills slots
// 0-15 with child references and slot 16 with an optional value.
fullNode struct {
Children [17]node // Actual trie node data to encode/decode (needs custom encoder)
flags nodeFlag
}
// shortNode holds a key fragment together with a single child.
shortNode struct {
Key []byte
Val node // child node; its concrete type tells a valueNode (leaf) from a fullNode (decodeShort may also store a hashNode here)
flags nodeFlag
}
// hashNode is a reference to a node by its hash; it is loaded on demand
// through resolveHash.
hashNode []byte
// valueNode is the raw value stored at the end of a key.
valueNode []byte
)

// Trie holds the in-memory root of a trie together with the database that
// backs it and the cache-generation bookkeeping.
type Trie struct {
root node // current root node; left nil by New for an empty trie
db Database // backing store used to resolve hash nodes
originalRoot common.Hash // root hash the trie was opened with

// Cache generation values.
// cachegen increases by one with each commit operation.
// New nodes are tagged with the current generation and unloaded
// when their generation is older than cachegen-cachelimit.
cachegen, cachelimit uint16

}

Trie树的构造方法,接受一个根hash值和database参数:如果根hash非空(既不是零值也不是emptyRoot),则从数据库中加载对应的根节点,加载失败时返回错误;否则直接返回一棵空的Trie树。
// New creates a trie backed by db. When root is neither the zero hash nor
// emptyRoot, the root node is loaded from db and a load failure is returned
// as an error; otherwise an empty trie is returned. It panics if a non-empty
// root is requested without a database.
func New(root common.Hash, db Database) (*Trie, error) {
	t := &Trie{db: db, originalRoot: root}
	// Nothing to load for an empty trie.
	if root == (common.Hash{}) || root == emptyRoot {
		return t, nil
	}
	if db == nil {
		panic("trie.New: cannot use existing root without a database")
	}
	loaded, err := t.resolveHash(root[:], nil)
	if err != nil {
		return nil, err
	}
	t.root = loaded
	return t, nil
}

节点插入方法:n是当前插入位置的节点,prefix是已经处理完的key前缀,key是尚未处理的剩余部分(完整的key = prefix + key),value是待插入的值。返回值依次为:Trie树是否被修改、插入完成后用来替换n的节点、错误信息。
// insert stores value under the remaining key in the subtree rooted at n.
// prefix holds the part of the key already consumed on the way down to n.
// It returns whether the subtree changed, the node that must replace n in
// its parent, and any error hit while resolving a hash node.
func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error) {
// The whole key has been consumed: value belongs right here. When an equal
// value is already stored, report the subtree as unchanged.
if len(key) == 0 {
if v, ok := n.(valueNode); ok {
return !bytes.Equal(v, value.(valueNode)), value, nil
}
return true, value, nil
}
switch n := n.(type) {
case *shortNode:
matchlen := prefixLen(key, n.Key)
// If the whole key matches, keep this short node as is
// and only update the value.
if matchlen == len(n.Key) {
dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value)
if !dirty || err != nil {
return false, n, err
}
return true, &shortNode{n.Key, nn, t.newFlag()}, nil
}
// Otherwise branch out at the index where they differ.
branch := &fullNode{flags: t.newFlag()}
var err error
// Re-insert the existing child under the first nibble after the shared prefix.
_, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val)
if err != nil {
return false, nil, err
}
// Insert the new value under its own first differing nibble.
_, branch.Children[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value)
if err != nil {
return false, nil, err
}
// Replace this shortNode with the branch if it occurs at index 0.
if matchlen == 0 {
return true, branch, nil
}
// Otherwise, replace it with a short node leading up to the branch.
return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil

case *fullNode:
    // Descend into the child slot selected by the next key element.
    dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value)
    if !dirty || err != nil {
        return false, n, err
    }
    // Work on a copy so the original node is not modified in place.
    n = n.copy()
    n.flags = t.newFlag()
    n.Children[key[0]] = nn
    return true, n, nil

case nil:
    // Empty slot: the rest of the key and the value become a new short node.
    return true, &shortNode{key, value, t.newFlag()}, nil

case hashNode:
    // We've hit a part of the trie that isn't loaded yet. Load
    // the node and insert into it. This leaves all child nodes on
    // the path to the value in the trie.
    rn, err := t.resolveHash(n, prefix)
    if err != nil {
        return false, nil, err
    }
    dirty, nn, err := t.insert(rn, prefix, key, value)
    if !dirty || err != nil {
        // Nothing changed (or the insert failed): hand back the resolved node.
        return false, rn, err
    }
    return true, nn, nil

default:
    panic(fmt.Sprintf("%T: invalid node: %v", n, n))
}

}

Trie树的Get方法,简单地遍历Trie树,来获取Key对应的值。

// tryGet looks up key in the subtree rooted at origNode, where pos is the
// number of key elements already consumed.
//
// It returns the value found (nil when the key is absent), the node that
// should replace origNode in its parent, whether a hash node was resolved
// from the database along the way, and any resolve error.
func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) {
switch n := (origNode).(type) {
case nil:
return nil, nil, false, nil
case valueNode:
return n, n, false, nil
case *shortNode:
if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) {
// key not found in trie
return nil, n, false, nil
}
value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key))
// A node below was loaded from the database: return a copy that points
// at the resolved child and carries the current cache generation.
if err == nil && didResolve {
n = n.copy()
n.Val = newnode
n.flags.gen = t.cachegen
}
return value, n, didResolve, err
case *fullNode:
value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1)
// Same copy-and-relink step as for short nodes.
if err == nil && didResolve {
n = n.copy()
n.flags.gen = t.cachegen
n.Children[key[pos]] = newnode
}
return value, n, didResolve, err
case hashNode:
// Not loaded yet: resolve it and continue the lookup at the same position.
child, err := t.resolveHash(n, key[:pos])
if err != nil {
return nil, n, true, err
}
value, newnode, _, err := t.tryGet(child, key, pos)
return value, newnode, true, err
default:
panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode))
}
}

Trie树的序列化和反序列化

store方法,如果一个node的所有子节点都替换成了子节点的hash值,那么直接调用rlp.Encode方法对这个节点进行编码,如果编码后的值小于32,并且这个节点不是根节点,那么就把他们直接存储在他们的父节点里面。调用h.sha.write方法进行hash运算,然后把hash值和编码后的数据存在db中,并返回hash值。进一步的追踪外部引用来自账号和存储trie
// store RLP-encodes n and decides how it is referenced by its parent.
//
// Hash nodes and nil are returned untouched. A node whose encoding is shorter
// than 32 bytes is returned as is unless force is set. Any other node is
// replaced by its hash, and when db is non-nil the hash and encoding are
// inserted into db. If h.onleaf is set, it is called for value-node children
// of the stored node.
func (h *hasher) store(n node, db *Database, force bool) (node, error) {
// Don't store hashes or empty nodes.
if _, isHash := n.(hashNode); n == nil || isHash {
return n, nil
}
// Generate the RLP encoding of the node
h.tmp.Reset()
if err := rlp.Encode(&h.tmp, n); err != nil {
panic("encode error: " + err.Error())
}
if len(h.tmp) < 32 && !force {
return n, nil // Nodes smaller than 32 bytes are stored inside their parent
}
// Larger nodes are replaced by their hash and stored in the database.
// Reuse the hash cached on the node when there is one.
hash, _ := n.cache()
if hash == nil {
hash = h.makeHashNode(h.tmp)
}

if db != nil {
    // We are pooling the trie nodes into an intermediate memory cache
    // This inner hash shadows the outer hashNode only inside this block;
    // the function still returns the outer hashNode.
    hash := common.BytesToHash(hash)

    db.lock.Lock()
    db.insert(hash, h.tmp, n)
    db.lock.Unlock()

    // Track external references from account->storage trie
    if h.onleaf != nil {
        switch n := n.(type) {
        case *shortNode:
            if child, ok := n.Val.(valueNode); ok {
                h.onleaf(child, hash)
            }
        case *fullNode:
            // Only the 16 branch slots are inspected; slot 16 is not visited.
            for i := 0; i < 16; i++ {
                if child, ok := n.Children[i].(valueNode); ok {
                    h.onleaf(child, hash)
                }
            }
        }
    }
}
return hash, nil

}

反序列化
// resolveHash loads the node referenced by the hash node n from the trie's
// database. prefix is only used to describe the location in the returned
// MissingNodeError when the node cannot be found. Every call bumps
// cacheMissCounter.
func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) {
	cacheMissCounter.Inc(1)

	hash := common.BytesToHash(n)
	resolved := t.db.node(hash, t.cachegen)
	if resolved == nil {
		return nil, &MissingNodeError{NodeHash: hash, Path: prefix}
	}
	return resolved, nil
}

// node returns the trie node stored under hash, or nil when it is unavailable.
// The in-memory node map is consulted first under the read lock; on a miss the
// encoded node is fetched from the disk database and decoded.
func (db *Database) node(hash common.Hash, cachegen uint16) node {
	// Fast path: the node is still held in memory.
	db.lock.RLock()
	cached := db.nodes[hash]
	db.lock.RUnlock()

	if cached != nil {
		return cached.obj(hash, cachegen)
	}
	// Slow path: fall back to the persistent store.
	enc, err := db.diskdb.Get(hash[:])
	if err == nil && enc != nil {
		return mustDecodeNode(hash[:], enc, cachegen)
	}
	return nil
}

// mustDecodeNode decodes the encoded node in buf and panics, naming the node
// hash, when decoding fails.
func mustDecodeNode(hash, buf []byte, cachegen uint16) node {
	decoded, err := decodeNode(hash, buf, cachegen)
	if err == nil {
		return decoded
	}
	panic(fmt.Sprintf("node %x: %v", hash, err))
}

// decodeNode parses the RLP encoding of a trie node.
//
// buf must hold an RLP list: a list of 2 values is decoded as a short node
// and a list of 17 values as a full node. hash and cachegen are passed through
// to the decoded node's flags. An empty buf yields io.ErrUnexpectedEOF and any
// other element count is reported as an error.
func decodeNode(hash, buf []byte, cachegen uint16) (node, error) {
	if len(buf) == 0 {
		return nil, io.ErrUnexpectedEOF
	}
	elems, _, err := rlp.SplitList(buf)
	if err != nil {
		return nil, fmt.Errorf("decode error: %v", err)
	}
	// NOTE(review): the CountValues error is discarded; presumably a malformed
	// list produces a count that falls into the default branch — confirm.
	switch c, _ := rlp.CountValues(elems); c {
	case 2:
		// decodeShort and decodeFull are declared with (hash, buf, elems,
		// cachegen), so the raw buffer has to be passed along as well.
		n, err := decodeShort(hash, buf, elems, cachegen)
		return n, wrapError(err, "short")
	case 17:
		n, err := decodeFull(hash, buf, elems, cachegen)
		return n, wrapError(err, "full")
	default:
		return nil, fmt.Errorf("invalid number of list elements: %v", c)
	}
}

decodeShort方法,通过key是否有终结符号来判断到底是叶子节点还是中间节点。如果有终结符,那么就是叶子节点,通过SplitString方法解析出val,然后生成一个shortNode。如果没有终结符,那么说明是扩展节点,通过decodeRef来解析剩下的节点,然后生成一个shortNode。

// decodeShort decodes the two elements of a short node from elems. The first
// element is the compact-encoded key. When the expanded key carries a
// terminator the second element is read as a value, otherwise it is decoded
// as a child reference. buf is not used by this function.
func decodeShort(hash, buf, elems []byte, cachegen uint16) (node, error) {
	compactKey, remainder, err := rlp.SplitString(elems)
	if err != nil {
		return nil, err
	}
	hexKey := compactToHex(compactKey)
	flags := nodeFlag{hash: hash, gen: cachegen}

	if !hasTerm(hexKey) {
		// No terminator: the second element references a child node.
		child, _, err := decodeRef(remainder, cachegen)
		if err != nil {
			return nil, wrapError(err, "val")
		}
		return &shortNode{hexKey, child, flags}, nil
	}
	// Terminated key: the second element is the stored value.
	raw, _, err := rlp.SplitString(remainder)
	if err != nil {
		return nil, fmt.Errorf("invalid value node: %v", err)
	}
	return &shortNode{hexKey, append(valueNode{}, raw...), flags}, nil
}
decodeRef方法根据数据类型进行解析,如果类型是list,那么有可能是内容<32的值,那么调用decodeNode进行解析。 如果是空节点,那么返回空,如果是hash值,那么构造一个hashNode进行返回,注意的是这里没有继续进行解析,如果需要继续解析这个hashNode,那么需要继续调用resolveHash方法。
// decodeRef decodes one child reference from the front of buf and returns the
// decoded node together with the unconsumed remainder of buf.
//
// A list is an embedded node and is decoded in place, an empty string is a nil
// child, and a 32-byte string becomes a hashNode that is not resolved here.
// Anything else is rejected with an error.
func decodeRef(buf []byte, cachegen uint16) (node, []byte, error) {
	kind, val, rest, err := rlp.Split(buf)
	if err != nil {
		return nil, buf, err
	}
	if kind == rlp.List {
		// 'embedded' node reference. The encoding must be smaller
		// than a hash in order to be valid.
		size := len(buf) - len(rest)
		if size > hashLen {
			return nil, buf, fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen)
		}
		embedded, err := decodeNode(nil, buf, cachegen)
		return embedded, rest, err
	}
	if kind == rlp.String {
		switch len(val) {
		case 0:
			// empty node
			return nil, rest, nil
		case 32:
			return append(hashNode{}, val...), rest, nil
		}
	}
	return nil, nil, fmt.Errorf("invalid RLP string size %d (want 0 or 32)", len(val))
}

decodeFull方法,跟decodeShort方法的流程差不多。

// decodeFull decodes the 17 elements of a full node from elems: sixteen child
// references followed by an optional value stored in slot 16. On error the
// partially filled node is returned alongside the error. buf is not used by
// this function.
func decodeFull(hash, buf, elems []byte, cachegen uint16) (*fullNode, error) {
	full := &fullNode{flags: nodeFlag{hash: hash, gen: cachegen}}
	remaining := elems
	for idx := 0; idx < 16; idx++ {
		child, tail, err := decodeRef(remaining, cachegen)
		if err != nil {
			return full, wrapError(err, fmt.Sprintf("[%d]", idx))
		}
		full.Children[idx] = child
		remaining = tail
	}
	value, _, err := rlp.SplitString(remaining)
	if err != nil {
		return full, err
	}
	// An empty string means the branch itself stores no value.
	if len(value) != 0 {
		full.Children[16] = append(valueNode{}, value...)
	}
	return full, nil
}

Trie树每一次调用Commit方法,会导致当前的cachegen增加1。
// CommitTo hashes the trie into db and returns the resulting root hash. On
// success the root is replaced by the cached node returned from hashRoot and
// the cache generation is advanced by one.
func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) {
	rootHash, cachedRoot, err := t.hashRoot(db)
	if err != nil {
		return common.Hash{}, err
	}
	t.root, t.cachegen = cachedRoot, t.cachegen+1
	return common.BytesToHash(rootHash.(hashNode)), nil
}

然后在Trie树插入的时候,会把当前的cachegen存放到节点中。
// Abridged excerpt of insert: the body is elided ("....") and only the return
// statement that tags a newly built node via t.newFlag() is shown.
func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error) {
....
// The replacement short node carries the flags produced by newFlag.
return true, &shortNode{n.Key, nn, t.newFlag()}, nil
}

// newFlag builds the flag value for a freshly created node: the node is marked
// dirty and tagged with the trie's current cache generation.
func (t *Trie) newFlag() nodeFlag {
	flag := nodeFlag{gen: t.cachegen}
	flag.dirty = true
	return flag
}

Prove方法,从根节点开始沿着key的路径向下查找,把经过的节点依次存入一个list中;然后对list中的每个节点计算hash,把"hash值 → 节点的RLP编码"逐个写入proofDb(跳过最前面的fromLevel个证明元素),最后返回。
// Prove writes a proof for key into proofDb.
//
// It walks from the root along key and collects every short and full node on
// the path, resolving hash nodes from the database as needed. Each collected
// node that is stored as a hash (plus the first node, unconditionally) is a
// proof element; the first fromLevel proof elements are skipped and the rest
// are written to proofDb as hash -> RLP encoding. A failure to resolve a hash
// node is logged and returned.
func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error {
// Collect all nodes on the path to key.
key = keybytesToHex(key)
nodes := []node{}
tn := t.root
for len(key) > 0 && tn != nil {
switch n := tn.(type) {
case *shortNode:
if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
// The trie doesn't contain the key.
tn = nil
} else {
tn = n.Val
key = key[len(n.Key):]
}
// The node is recorded even when the key diverges here.
nodes = append(nodes, n)
case *fullNode:
tn = n.Children[key[0]]
key = key[1:]
nodes = append(nodes, n)
case hashNode:
var err error
tn, err = t.resolveHash(n, nil)
if err != nil {
log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
return err
}
default:
panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
}
}
hasher := newHasher(0, 0, nil)
for i, n := range nodes {
// Don't bother checking for errors here since hasher panics
// if encoding doesn't work and we're not writing to any database.
n, _, _ = hasher.hashChildren(n, nil)
hn, _ := hasher.store(n, nil, false)
if hash, ok := hn.(hashNode); ok || i == 0 {
// If the node's database encoding is a hash (or is the
// root node), it becomes a proof element.
if fromLevel > 0 {
fromLevel--
} else {
enc, _ := rlp.EncodeToBytes(n)
// store returned the node itself rather than a hash, so the hash
// is computed here from the encoding.
if !ok {
hash = crypto.Keccak256(enc)
}
// NOTE(review): the result of Put is not checked.
proofDb.Put(hash, enc)
}
}
}
return nil
}

// get follows key through the nodes reachable from tn without touching any
// database.
//
// It returns (nil, value) when a value node is reached, (rest, hashNode) when
// the walk stops at an unresolved hash node with rest being the unconsumed
// key, (rest, nil) when it runs into an empty slot, and (nil, nil) when a
// short node's key does not match.
func get(tn node, key []byte) ([]byte, node) {
	cur, rest := tn, key
	for {
		switch n := cur.(type) {
		case nil:
			return rest, nil
		case hashNode:
			return rest, n
		case valueNode:
			return nil, n
		case *fullNode:
			cur, rest = n.Children[rest[0]], rest[1:]
		case *shortNode:
			if len(rest) < len(n.Key) || !bytes.Equal(n.Key, rest[:len(n.Key)]) {
				return nil, nil
			}
			cur, rest = n.Val, rest[len(n.Key):]
		default:
			panic(fmt.Sprintf("%T: invalid node: %v", cur, cur))
		}
	}
}

为了避免刻意使用很长的key导致访问时间的增加, security_trie包装了一下trie树, 所有的key都转换成keccak256算法计算的hash值。同时在数据库里面存储hash值对应的原始的key。

// SecureTrie wraps a Trie and hashes every key before it reaches the
// underlying trie (see TryGet, which passes hashKey(key) to the trie).
type SecureTrie struct {
trie Trie // the underlying trie
hashKeyBuf [secureKeyLength]byte // buffer used when computing the key hash
secKeyBuf [200]byte // buffer for the database key prefix under which a hash's original key is stored
secKeyCache map[string][]byte // maps a key hash to its original key
secKeyCacheOwner *SecureTrie // Pointer to self, replace the key cache on mismatch
}

// NewSecure creates a SecureTrie over the trie identified by root in db and
// applies cachelimit to the underlying trie. It panics when db is nil and
// returns an error when the underlying trie cannot be opened.
func NewSecure(root common.Hash, db Database, cachelimit uint16) (*SecureTrie, error) {
	if db == nil {
		panic("NewSecure called with nil database")
	}
	inner, err := New(root, db)
	if err != nil {
		return nil, err
	}
	inner.SetCacheLimit(cachelimit)
	secure := &SecureTrie{trie: *inner}
	return secure, nil
}

// Get looks up key and returns the stored value. Lookup errors are logged
// rather than returned, and whatever TryGet produced is handed back.
// Callers must not modify the returned bytes.
func (t *SecureTrie) Get(key []byte) []byte {
	value, err := t.TryGet(key)
	if err == nil {
		return value
	}
	log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
	return value
}

// TryGet looks up key in the underlying trie under its hashed form and
// returns the stored value. Callers must not modify the returned bytes.
// A MissingNodeError is returned when a node is not found in the database.
func (t *SecureTrie) TryGet(key []byte) ([]byte, error) {
	hashed := t.hashKey(key)
	return t.trie.TryGet(hashed)
}
// CommitTo first writes every cached hash -> original key pair to db under the
// key produced by secKey, then resets the cache and commits the underlying
// trie to db. The first failing write aborts the commit and is returned.
func (t *SecureTrie) CommitTo(db DatabaseWriter) (root common.Hash, err error) {
	if len(t.getSecKeyCache()) > 0 {
		for hashed, preimage := range t.secKeyCache {
			putErr := db.Put(t.secKey([]byte(hashed)), preimage)
			if putErr != nil {
				return common.Hash{}, putErr
			}
		}
		t.secKeyCache = make(map[string][]byte)
	}
	return t.trie.CommitTo(db)
}

你可能感兴趣的:(以太坊源码阅读-数据结构篇-Trie树)