源码分享-golang的二进制文件读写库

源码分享-golang的二进制文件读写库

  • 库功能
  • 库源码
    • decode.go
    • encode.go

库功能

功能类似golang标准库encoding/binary,用于二进制码流/文件的读写。对比标准库,本库对以下方面做了功能增强:

  1. 支持bit级别的结构体成员编解码
  2. 支持`bit`、`sort`两个结构体标签,分别用于指定结构体成员的bit数(1~64)和大小端属性(`big`/`little`)
  3. 支持map类型结构
  4. 支持string字符串
  5. 支持自定义方法编解码

库源码

decode.go

package binary

import (
	"fmt"
	"io"
	"math"
	"reflect"
)

// Unmarshaler is implemented by types that decode themselves from a Decoder.
// When a value (or its address) implements it, the decoder calls
// UnmarshalBinary with itself, the byte order in effect (isBig is true for
// big-endian) and the bit width passed down for the value, and skips its
// built-in decoding for that value. A non-nil error aborts the decode.
type Unmarshaler interface {
	UnmarshalBinary(dec *Decoder, isBig bool, bit int) error
}

// Unmarshal decodes buf into each of the values in e, in order, using
// big-endian byte order when isBig is true and little-endian otherwise.
// It stops at and returns the first decoding error.
func Unmarshal(buf []byte, isBig bool, e ...any) error {
	dec := NewDecoder(buf, 0)
	return dec.Unmarshal(isBig, e...)
}

// EOF is the error reported when a read is attempted with no data left in the
// buffer. It is a package-specific value, not io.EOF.
var EOF = newErr("read the end")

// Decoder reads values from an in-memory byte buffer, tracking the read
// position down to the individual bit.
type Decoder struct {
	buf     []byte // data being decoded
	prevOff int    // byte offset added to off when reporting Pos (set by NewDecoder/SubDecoder)
	off     int    // index into buf of the current byte
	bit     int    // number of bits already consumed from the current byte
	arg     any    // caller-supplied value, see SetArg/Arg
}

// NewDecoder returns a Decoder that reads from buf. prevOff is a byte offset
// that is added to the decoder's own offset when computing Pos and Seek
// positions.
func NewDecoder(buf []byte, prevOff int) *Decoder {
	dec := new(Decoder)
	dec.buf = buf
	dec.prevOff = prevOff
	return dec
}

// NewReaderDecoder reads r until io.ReadAll returns and wraps whatever bytes
// were read in a Decoder whose position starts at zero.
//
// NOTE(review): the error from io.ReadAll is discarded because the signature
// has no way to report it; a failing reader yields a decoder over the bytes
// read before the failure.
func NewReaderDecoder(r io.Reader) *Decoder {
	data, _ := io.ReadAll(r)
	return &Decoder{buf: data}
}

// Pos returns the current read position in bits, counted from the origin
// defined by prevOff.
func (dec *Decoder) Pos() int {
	byteIndex := dec.prevOff + dec.off
	return byteIndex*8 + dec.bit
}

// Seek moves the read position to pos, a bit position in the same coordinate
// system as Pos. It returns an error, leaving the decoder untouched, when pos
// lies before the start or past the end of the buffer.
func (dec *Decoder) Seek(pos int) error {
	first := dec.prevOff * 8
	last := (dec.prevOff + len(dec.buf)) * 8
	if pos < first || pos > last {
		return fmtErr("pos(%d.%d) illegal", pos/8, pos%8)
	}
	dec.off, dec.bit = pos/8-dec.prevOff, pos%8
	return nil
}

// SubDecoder returns a new Decoder over the next n bytes of the buffer,
// starting at the next whole byte (a partially consumed byte is skipped).
// When n is not positive, or reaches to or past the end of the buffer, the
// sub-decoder covers everything that remains. The receiver's own position is
// not changed, and the sub-decoder shares the receiver's buffer and Arg value.
func (dec *Decoder) SubDecoder(n int) *Decoder {
	start := dec.off
	if dec.bit > 0 {
		// Round up to the next byte boundary.
		start++
	}
	if n <= 0 || start+n >= len(dec.buf) {
		n = len(dec.buf) - start
	}

	sub := new(Decoder)
	sub.buf = dec.buf[start : start+n]
	sub.prevOff = dec.prevOff + start
	sub.arg = dec.arg
	return sub
}

// SetArg stores a on the decoder. The value is returned by Arg and is copied
// into decoders created with SubDecoder.
func (dec *Decoder) SetArg(a any) {
	dec.arg = a
}

// Arg returns the value previously stored with SetArg, or nil if none was set.
func (dec *Decoder) Arg() any {
	return dec.arg
}

// Read copies up to len(p) bytes from the buffer into p and advances the
// position. A partially consumed byte is skipped first so the read starts on
// a byte boundary. It returns 0, nil for an empty p and 0, EOF (the package
// value, not io.EOF) when no bytes remain.
func (dec *Decoder) Read(p []byte) (n int, err error) {
	if dec.bit > 0 {
		dec.off++
		dec.bit = 0
	}
	if len(p) == 0 {
		return 0, nil
	}
	if dec.off >= len(dec.buf) {
		return 0, EOF
	}
	// copy already limits the count to the shorter of p and the remaining data.
	n = copy(p, dec.buf[dec.off:])
	dec.off += n
	return n, nil
}

// ReadByte returns the next whole byte, skipping a partially consumed byte
// first. It returns EOF when no bytes remain.
func (dec *Decoder) ReadByte() (byte, error) {
	if dec.bit > 0 {
		dec.off++
		dec.bit = 0
	}
	idx := dec.off
	if idx >= len(dec.buf) {
		return 0, EOF
	}
	dec.off = idx + 1
	return dec.buf[idx], nil
}

// ReadBytes returns the next n bytes as a sub-slice of the decoder's buffer
// (no copy is made) and advances past them. A partially consumed byte is
// skipped first. When n is not positive or exceeds the remaining data, all
// remaining bytes are returned, so the result may be shorter than n.
//
// It panics with EOF when no bytes remain.
func (dec *Decoder) ReadBytes(n int) []byte {
	if dec.bit > 0 {
		dec.off++
		dec.bit = 0
	}
	start := dec.off
	if start >= len(dec.buf) {
		panic(EOF)
	}
	end := start + n
	if n <= 0 || end > len(dec.buf) {
		end = len(dec.buf)
	}
	out := dec.buf[start:end]
	dec.off = end
	return out
}

// IsEof reports whether the decoder sits exactly at the end of its buffer.
// A decoder positioned in the middle of a byte is never at EOF.
func (dec *Decoder) IsEof() bool {
	return dec.bit == 0 && dec.off == len(dec.buf)
}

// Unmarshal decodes into each value of e in turn, continuing from the
// decoder's current position, and returns the first error encountered.
// isBig selects big-endian (true) or little-endian (false) byte order.
func (dec *Decoder) Unmarshal(isBig bool, e ...any) error {
	for i := range e {
		if err := dec.Decode(e[i], isBig, 0); err != nil {
			return err
		}
	}
	return nil
}

// Decode reads one value from the buffer into a.
//
// a is either a reflect.Value, which is decoded directly, or any other value,
// which is wrapped with reflect.ValueOf. Only settable values are written, so
// callers normally pass a pointer. isBig selects big-endian byte order; bit,
// when positive, is the bit width used for integer and bool values instead of
// their natural size. A nil a is a no-op.
//
// Decoding failures are raised internally as decErr panics and returned here
// as the error. Any other panic is propagated unchanged.
func (dec *Decoder) Decode(a any, isBig bool, bit int) (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		e, ok := r.(decErr)
		if !ok {
			panic(r)
		}
		err = e
	}()
	if a == nil {
		return nil
	}

	// A reflect.Value is the target itself; wrapping it again would make the
	// decoder walk the reflect.Value struct as well.
	v, ok := a.(reflect.Value)
	if !ok {
		v = reflect.ValueOf(a)
	}
	dec.decode(v, isBig, bit)
	return nil
}

// decode reads one value of v's kind from the buffer and stores it in v.
//
// A value (or its address) that implements Unmarshaler decodes itself.
// Otherwise the kind decides: scalars are read with the given byte order and,
// for integers and bools, the given bit width (0 means the natural size from
// kindSize); a nil slice is filled until the end of the buffer; a non-nil
// slice or an array is decoded element by element in place; non-nil pointers
// and interfaces are followed; other kinds are ignored. Scalars that are not
// settable are skipped without consuming input.
//
// Failures are reported by panicking with a decErr.
func (dec *Decoder) decode(v reflect.Value, isBig bool, bit int) {
	if !v.IsValid() {
		return
	}
	if v.CanInterface() && dec.handleMethods(v, isBig, bit) {
		return
	}
	if v.Kind() != reflect.Pointer && v.CanAddr() && dec.handleMethods(v.Addr(), isBig, bit) {
		return
	}

	switch v.Kind() {
	case reflect.Bool:
		if v.CanSet() {
			v.SetBool(dec.decodeInteger(8, isBig, bit) != 0)
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if v.CanSet() {
			size := kindSize(v.Kind())
			n := int64(dec.decodeInteger(size, isBig, bit))
			// Sign-extend full-width values. Without this a negative int,
			// which is stored in 32 bits, comes back as a large positive
			// number where int is 64 bits wide. Explicit bit fields keep
			// their raw, zero-extended value.
			if bit <= 0 && size < 64 {
				shift := uint(64 - size)
				n = n << shift >> shift
			}
			v.SetInt(n)
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		if v.CanSet() {
			v.SetUint(dec.decodeInteger(kindSize(v.Kind()), isBig, bit))
		}
	case reflect.Float32, reflect.Float64:
		if v.CanSet() {
			v.SetFloat(dec.decodeFloat(kindSize(v.Kind()), isBig))
		}
	case reflect.Complex64, reflect.Complex128:
		if v.CanSet() {
			v.SetComplex(dec.decodeComplex(kindSize(v.Kind()), isBig))
		}
	case reflect.Slice:
		if v.IsNil() {
			dec.decodeSlice(v, isBig, bit)
			break
		}
		// A non-nil slice has a known length and is decoded like an array.
		fallthrough
	case reflect.Array:
		l := v.Len()
		for i := 0; i < l; i++ {
			dec.decode(v.Index(i), isBig, bit)
		}
	case reflect.Interface, reflect.Pointer:
		if v.IsNil() {
			break
		}
		dec.decode(v.Elem(), isBig, bit)
	case reflect.Map:
		dec.decodeMap(v, isBig, bit)
	case reflect.String:
		if v.CanSet() {
			v.SetString(dec.decodeString())
		}
	case reflect.Struct:
		dec.decodeStruct(v, isBig)
	default:
		return
	}
}

// handleMethods lets v decode itself when it implements Unmarshaler. It
// returns false when v does not implement the interface and true when
// UnmarshalBinary succeeded. A failure is raised as a decErr panic: the
// error itself if it already is a decErr, otherwise a decErr carrying its
// message.
func (dec *Decoder) handleMethods(v reflect.Value, isBig bool, bit int) bool {
	u, ok := v.Interface().(Unmarshaler)
	if !ok {
		return false
	}
	err := u.UnmarshalBinary(dec, isBig, bit)
	if err == nil {
		return true
	}
	if de, isDecErr := err.(decErr); isDecErr {
		panic(de)
	}
	panic(newErr(err.Error()))
}

// decodeInteger reads an unsigned integer. A positive bit selects a bit-field
// read of that many bits; otherwise size bits are read as whole bytes. isBig
// chooses the big-endian variant of either read.
func (dec *Decoder) decodeInteger(size int, isBig bool, bit int) uint64 {
	switch {
	case bit > 0 && isBig:
		return dec.decodeBigBits(bit)
	case bit > 0:
		return dec.decodeLittleBits(bit)
	case isBig:
		return dec.decodeBigInteger(size)
	default:
		return dec.decodeLittleInteger(size)
	}
}

// decodeLittleBits reads size bits starting at the current bit position.
// Bits are taken from the low end of each byte upward, and the first bits read
// become the least significant bits of the result. The read may span several
// bytes. It panics with EOF if the buffer ends before size bits are read.
func (dec *Decoder) decodeLittleBits(size int) uint64 {
	v := uint64(0)
	bit := size // bits still to be read
	for bit > 0 {
		if dec.off >= len(dec.buf) {
			panic(EOF)
		}
		// Take whatever is left in the current byte, capped at what is needed.
		num := 8 - dec.bit
		if num > bit {
			num = bit
		}

		mask := byte((1 << num) - 1)
		// size - bit is the number of bits already collected, so this chunk is
		// placed directly above them.
		v |= uint64((dec.buf[dec.off]>>dec.bit)&mask) << (size - bit)
		dec.bit += num
		bit -= num

		// Step to the next byte once the current one is used up.
		if dec.bit >= 8 {
			dec.bit -= 8
			dec.off++
		}
	}
	return v
}

// decodeBigBits reads size bits starting at the current bit position. Bits
// are taken from the high end of each byte downward, and the first bits read
// become the most significant bits of the result. The read may span several
// bytes. It panics with EOF if the buffer ends before size bits are read.
func (dec *Decoder) decodeBigBits(size int) uint64 {
	v := uint64(0)
	bit := size // bits still to be read
	for bit > 0 {
		if dec.off >= len(dec.buf) {
			panic(EOF)
		}
		// Take whatever is left in the current byte, capped at what is needed.
		num := 8 - dec.bit
		if num > bit {
			num = bit
		}

		mask := byte((1 << num) - 1)
		// Make room below the bits collected so far, then append this chunk,
		// which sits just under the dec.bit bits already consumed from the top
		// of the byte.
		v <<= num
		v |= uint64((dec.buf[dec.off] >> (8 - dec.bit - num)) & mask)
		dec.bit += num
		bit -= num

		// Step to the next byte once the current one is used up.
		if dec.bit >= 8 {
			dec.bit -= 8
			dec.off++
		}
	}
	return v
}

// decodeLittleInteger reads size bits (a positive multiple of 8) as whole
// bytes and assembles them in little-endian order.
//
// It panics with a decErr for an unsupported size and with EOF when fewer
// than size/8 bytes remain in the buffer.
func (dec *Decoder) decodeLittleInteger(size int) uint64 {
	if size <= 0 || size%8 != 0 {
		panic(fmtErr("unsupport integer size: %d", size))
	}
	want := size / 8
	bs := dec.ReadBytes(want)
	// ReadBytes truncates at the end of the buffer instead of failing; report
	// the short read as EOF rather than indexing past the returned slice.
	if len(bs) < want {
		panic(EOF)
	}
	v := uint64(0)
	for i, b := range bs {
		v |= uint64(b) << (i * 8)
	}
	return v
}

// decodeBigInteger reads size bits (a positive multiple of 8) as whole bytes
// and assembles them in big-endian order.
//
// It panics with a decErr for an unsupported size and with EOF when fewer
// than size/8 bytes remain in the buffer.
func (dec *Decoder) decodeBigInteger(size int) uint64 {
	if size <= 0 || size%8 != 0 {
		panic(fmtErr("unsupport integer size: %d", size))
	}
	want := size / 8
	bs := dec.ReadBytes(want)
	// ReadBytes truncates at the end of the buffer instead of failing; report
	// the short read as EOF rather than indexing past the returned slice.
	if len(bs) < want {
		panic(EOF)
	}
	v := uint64(0)
	for _, b := range bs {
		v = v<<8 | uint64(b)
	}
	return v
}

// decodeFloat reads a 32- or 64-bit IEEE 754 value in the requested byte
// order and returns it as a float64. The integer read happens before the size
// is checked against 32/64, so any other size panics with a decErr from
// whichever check fails first.
func (dec *Decoder) decodeFloat(size int, isBig bool) float64 {
	var raw uint64
	if isBig {
		raw = dec.decodeBigInteger(size)
	} else {
		raw = dec.decodeLittleInteger(size)
	}

	if size == 32 {
		return float64(math.Float32frombits(uint32(raw)))
	}
	if size == 64 {
		return math.Float64frombits(raw)
	}
	panic(fmtErr("unsupport float size: %d", size))
}

// decodeComplex reads a complex number of size bits as two consecutive floats
// of size/2 bits each: the real part first, then the imaginary part.
func (dec *Decoder) decodeComplex(size int, isBig bool) complex128 {
	half := size / 2
	re := dec.decodeFloat(half, isBig)
	im := dec.decodeFloat(half, isBig)
	return complex(re, im)
}

// decodeString reads a zero-terminated string starting at the next byte
// boundary. The terminator is consumed but not returned. If no terminator is
// found, the rest of the buffer is returned. It panics with EOF when no bytes
// remain.
func (dec *Decoder) decodeString() string {
	if dec.bit > 0 {
		dec.off++
		dec.bit = 0
	}
	if dec.off >= len(dec.buf) {
		panic(EOF)
	}

	rest := dec.buf[dec.off:]
	end := len(rest)
	for i, b := range rest {
		if b == 0 {
			end = i
			break
		}
	}

	consumed := end
	if end < len(rest) {
		// Also step over the terminating zero byte.
		consumed++
	}
	dec.off += consumed
	return string(rest[:end])
}

// decodeSlice fills a settable slice with elements decoded until the end of
// the buffer; unsettable values are ignored.
//
// A byte slice with no bit width takes all remaining bytes at once; the
// result refers to the decoder's buffer rather than a copy. Other element
// types are decoded one at a time and appended to the existing contents.
//
// It panics with a decErr when an element decode consumes no input (for
// example pointer, channel or func elements, or structs without exported
// fields), because the loop could otherwise never reach the end of the buffer.
func (dec *Decoder) decodeSlice(v reflect.Value, isBig bool, bit int) {
	if !v.CanSet() {
		return
	}
	elemType := v.Type().Elem()
	if elemType.Kind() == reflect.Uint8 && bit == 0 {
		v.SetBytes(dec.ReadBytes(0))
		return
	}

	out := v
	for !dec.IsEof() {
		before := dec.Pos()
		elem := reflect.New(elemType).Elem()
		dec.decode(elem, isBig, bit)
		if dec.Pos() == before {
			panic(newErr("slice element consumes no data: " + elemType.String()))
		}
		out = reflect.Append(out, elem)
	}
	v.Set(out)
}

// decodeMap reads key/value pairs until the end of the buffer and stores them
// in the settable map v; unsettable values are ignored. Keys must be of an
// integer or string kind and are always decoded at their natural size; the
// bit width applies to the values only. A nil map is allocated when the first
// entry is decoded.
//
// It panics with a decErr for an unsupported key type.
func (dec *Decoder) decodeMap(v reflect.Value, isBig bool, bit int) {
	if !v.CanSet() {
		return
	}
	mapType := v.Type()
	keyType := mapType.Key()
	valType := mapType.Elem()
	switch keyType.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
	case reflect.String:
	default:
		panic(newErr("unsupported key type of map: " + keyType.String()))
	}

	for !dec.IsEof() {
		key := reflect.New(keyType).Elem()
		val := reflect.New(valType).Elem()
		dec.decode(key, isBig, 0)
		dec.decode(val, isBig, bit)
		// SetMapIndex panics on a nil map, so allocate it on first use. A nil
		// map stays nil when there is nothing to decode.
		if v.IsNil() {
			v.Set(reflect.MakeMap(mapType))
		}
		v.SetMapIndex(key, val)
	}
}

// decodeStruct decodes the exported fields of v in declaration order.
// Unexported fields are skipped and consume no input. Each field's `sort` and
// `bit` tags override the inherited byte order and set its bit width. An
// invalid tag panics with a decErr.
func (dec *Decoder) decodeStruct(v reflect.Value, isBig bool) {
	structType := v.Type()
	for i, n := 0, v.NumField(); i < n; i++ {
		field := structType.Field(i)
		if !field.IsExported() {
			continue
		}
		fv := v.Field(i)
		if field.Name == "_" && !fv.CanSet() {
			continue
		}

		big, bit, tagErr := getSortBitFromTag(&field.Tag, isBig)
		if tagErr != "" {
			panic(newErr(tagErr))
		}
		dec.decode(fv, big, bit)
	}
}

// decErr is the error type the decoder uses for its own failures. Decode
// recovers panics of this type and returns them as ordinary errors.
type decErr string

// Error returns the message, satisfying the error interface.
func (err decErr) Error() string {
	return string(err)
}

// newErr wraps a fixed message in a decErr.
func newErr(err string) decErr {
	return decErr(err)
}

// fmtErr builds a decErr from a fmt.Sprintf-style format and arguments.
func fmtErr(format string, a ...any) decErr {
	msg := fmt.Sprintf(format, a...)
	return newErr(msg)
}

encode.go

package binary

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"math"
	"reflect"
	"strconv"
)

// Marshaler is implemented by types that encode themselves to an Encoder.
// When a value (or its address) implements it, the encoder calls
// MarshalBinary with itself, the byte order in effect (isBig is true for
// big-endian) and the bit width passed down for the value, and returns its
// result in place of the built-in encoding.
type Marshaler interface {
	MarshalBinary(enc *Encoder, isBig bool, bit int) error
}

// Marshal encodes the values in e, in order, and returns the resulting bytes.
// isBig selects big-endian (true) or little-endian (false) byte order.
//
// If the encoded data ends in the middle of a byte (the last values were bit
// fields), the final byte is emitted with its unused bits set to zero, so
// that no encoded bits are lost.
func Marshal(isBig bool, e ...any) ([]byte, error) {
	buf := bytes.NewBuffer(nil)
	enc := NewEncoder(buf)
	if err := enc.Marshal(isBig, e...); err != nil {
		return nil, err
	}
	// Writing an empty slice pushes out a pending partial byte, if any.
	if _, err := enc.Write(nil); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// kindSize returns the encoded width in bits used for a value of the given
// kind, or 0 for kinds without a fixed width. Int and Uint are always
// treated as 32 bits and Uintptr as 64 bits, regardless of the platform.
func kindSize(kind reflect.Kind) int {
	switch kind {
	case reflect.Bool, reflect.Int8, reflect.Uint8:
		return 8
	case reflect.Int16, reflect.Uint16:
		return 16
	case reflect.Int, reflect.Uint, reflect.Int32, reflect.Uint32, reflect.Float32:
		return 32
	case reflect.Int64, reflect.Uint64, reflect.Uintptr, reflect.Float64, reflect.Complex64:
		return 64
	case reflect.Complex128:
		return 128
	}
	return 0
}

// getSortBitFromTag reads the `sort` and `bit` keys of a struct tag.
//
// isBit reports the byte order to use: true for `sort:"big"`, false for
// `sort:"little"`, and defBit when the key is absent. bit is the value of the
// `bit` key (1 to 64), or 0 when the key is absent. err is empty on success
// and otherwise describes the offending tag; bit is 0 in that case.
func getSortBitFromTag(tag *reflect.StructTag, defBit bool) (isBit bool, bit int, err string) {
	isBit = defBit
	switch order := tag.Get("sort"); order {
	case "":
		// Keep the inherited byte order.
	case "big":
		isBit = true
	case "little":
		isBit = false
	default:
		return isBit, 0, "unsupported tag: `sort:\"" + order + "\"`"
	}

	width := tag.Get("bit")
	if width == "" {
		return isBit, 0, ""
	}
	n, convErr := strconv.Atoi(width)
	if convErr != nil || n <= 0 || n > 64 {
		return isBit, 0, "unsupported tag: `bit:\"" + width + "\"`"
	}
	return isBit, n, ""
}

// Encoder writes encoded values to an underlying io.Writer, buffering
// bit-field output until a whole byte is available.
type Encoder struct {
	io.Writer        // destination of the encoded bytes
	bit       int    // number of bits currently held in bitBuf
	bitBuf    byte   // partially filled byte not yet written
	arg       any    // caller-supplied value, see SetArg/Arg
}

// NewEncoder returns an Encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
	enc := new(Encoder)
	enc.Writer = w
	return enc
}

// Write writes p to the underlying writer on a byte boundary. A pending
// partial byte of bit-field output is written first, with its unused bits
// zero. Calling Write with an empty p therefore only flushes that byte.
//
// n counts bytes of p only. If the pending byte cannot be written, Write
// returns 0 and the writer's error (io.ErrShortWrite when the writer accepted
// nothing without reporting an error); nothing from p has been written then
// and the pending bits are kept.
func (enc *Encoder) Write(p []byte) (n int, err error) {
	if enc.bit > 0 {
		flushed, ferr := enc.Writer.Write([]byte{enc.bitBuf})
		if ferr != nil {
			return 0, ferr
		}
		if flushed < 1 {
			return 0, io.ErrShortWrite
		}
		enc.bitBuf = 0
		enc.bit = 0
	}
	if len(p) == 0 {
		return 0, nil
	}
	return enc.Writer.Write(p)
}

// WriteByte writes c on a byte boundary, first writing out a pending partial
// byte of bit-field output. It returns the writer's error, or
// io.ErrShortWrite when the writer accepts no data without reporting an
// error, so that a byte is never dropped silently.
func (enc *Encoder) WriteByte(c byte) error {
	if enc.bit > 0 {
		n, err := enc.Writer.Write([]byte{enc.bitBuf})
		if err != nil {
			return err
		}
		if n < 1 {
			return io.ErrShortWrite
		}
		enc.bitBuf = 0
		enc.bit = 0
	}
	n, err := enc.Writer.Write([]byte{c})
	if err != nil {
		return err
	}
	if n < 1 {
		return io.ErrShortWrite
	}
	return nil
}

// WriteBytes writes all of p, retrying after short writes. It calls Write at
// least once, so an empty p still flushes a pending partial byte.
//
// It returns the first write error. If a write makes no progress without
// reporting an error, it gives up with a "partial write" error instead of
// retrying forever.
func (enc *Encoder) WriteBytes(p []byte) error {
	written := 0
	for {
		n, err := enc.Write(p[written:])
		written += n
		if err != nil {
			return err
		}
		if written >= len(p) {
			return nil
		}
		if n <= 0 {
			return fmt.Errorf("partial(%d) write, all(%d)", written, len(p))
		}
	}
}

// SetArg stores a on the encoder for later retrieval with Arg.
func (enc *Encoder) SetArg(a any) {
	enc.arg = a
}

// Arg returns the value previously stored with SetArg, or nil if none was set.
func (enc *Encoder) Arg() any {
	return enc.arg
}

// Marshal encodes each value of e in turn and returns the first error
// encountered. isBig selects big-endian (true) or little-endian (false) byte
// order. Bits left over from trailing bit fields stay buffered in the encoder
// until the next byte-aligned write.
func (enc *Encoder) Marshal(isBig bool, e ...any) error {
	for i := range e {
		if err := enc.Encode(e[i], isBig, 0); err != nil {
			return err
		}
	}
	return nil
}

// Encode writes one value. isBig selects big-endian byte order; bit, when
// positive, is the bit width used for integer and bool values instead of
// their natural size.
//
// A nil a is a no-op, and so is a nil *[]byte. A []byte or *[]byte with no
// bit width is written verbatim; a reflect.Value is encoded directly; any
// other value is encoded through reflection.
func (enc *Encoder) Encode(a any, isBig bool, bit int) (err error) {
	if a == nil {
		return nil
	}

	switch v := a.(type) {
	case []byte:
		if bit == 0 {
			return enc.WriteBytes(v)
		}
	case *[]byte:
		// A typed nil pointer is non-nil as an interface; do not dereference it.
		if v == nil {
			return nil
		}
		if bit == 0 {
			return enc.WriteBytes(*v)
		}
	case reflect.Value:
		return enc.encode(v, isBig, bit)
	}
	return enc.encode(reflect.ValueOf(a), isBig, bit)
}

// encode writes one value according to its kind.
//
// A value (or its address) that implements Marshaler encodes itself.
// Otherwise scalars are written with the given byte order and, for integers
// and bools, the given bit width (0 means the natural size from kindSize).
// Slices and arrays are written element by element; nil slices, maps,
// pointers and interfaces write nothing. Maps are written as key/value pairs
// in map iteration order, with keys at their natural size. Strings are
// written followed by a zero byte. Other kinds write nothing.
func (enc *Encoder) encode(v reflect.Value, isBig bool, bit int) error {
	if !v.IsValid() {
		return nil
	}
	if v.CanInterface() {
		if m, ok := v.Interface().(Marshaler); ok {
			return m.MarshalBinary(enc, isBig, bit)
		}
	}
	if v.Kind() != reflect.Pointer && v.CanAddr() {
		if m, ok := v.Addr().Interface().(Marshaler); ok {
			return m.MarshalBinary(enc, isBig, bit)
		}
	}

	kind := v.Kind()
	switch kind {
	case reflect.Bool:
		var flag uint64
		if v.Bool() {
			flag = 1
		}
		return enc.encodeInteger(flag, 8, isBig, bit)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return enc.encodeInteger(uint64(v.Int()), kindSize(kind), isBig, bit)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return enc.encodeInteger(v.Uint(), kindSize(kind), isBig, bit)
	case reflect.Float32, reflect.Float64:
		return enc.encodeFloat(v.Float(), kindSize(kind), isBig)
	case reflect.Complex64, reflect.Complex128:
		return enc.encodeComplex(v.Complex(), kindSize(kind), isBig)
	case reflect.Slice, reflect.Array:
		if kind == reflect.Slice && v.IsNil() {
			return nil
		}
		for i, n := 0, v.Len(); i < n; i++ {
			if err := enc.encode(v.Index(i), isBig, bit); err != nil {
				return err
			}
		}
		return nil
	case reflect.Interface, reflect.Pointer:
		if v.IsNil() {
			return nil
		}
		return enc.encode(v.Elem(), isBig, bit)
	case reflect.Map:
		if v.IsNil() {
			return nil
		}
		for iter := v.MapRange(); iter.Next(); {
			val := iter.Value()
			if !val.IsValid() {
				continue
			}
			if err := enc.encode(iter.Key(), isBig, 0); err != nil {
				return err
			}
			if err := enc.encode(val, isBig, bit); err != nil {
				return err
			}
		}
		return nil
	case reflect.String:
		return enc.WriteBytes(append([]byte(v.String()), 0))
	case reflect.Struct:
		return enc.encodeStruct(v, isBig)
	}
	return nil
}

// encodeInteger writes v as an unsigned integer. A positive bit selects a
// bit-field write of that many bits; otherwise size bits are written as whole
// bytes. isBig chooses the big-endian variant of either write.
func (enc *Encoder) encodeInteger(v uint64, size int, isBig bool, bit int) error {
	switch {
	case bit > 0 && isBig:
		return enc.encodeBigBits(v, bit)
	case bit > 0:
		return enc.encodeLittleBits(v, bit)
	case isBig:
		return enc.encodeBigInteger(v, size)
	default:
		return enc.encodeLittleInteger(v, size)
	}
}

// encodeLittleBits appends the low size bits of v to the output, least
// significant bits first. Bits fill the pending byte from its low end upward;
// each time the byte is full it is written to the underlying writer. Bits
// that do not complete a byte stay buffered in the encoder.
func (enc *Encoder) encodeLittleBits(v uint64, size int) error {
	bit := size // bits still to be written
	for bit > 0 {
		// Fill whatever room is left in the pending byte, capped at what remains.
		num := 8 - enc.bit
		if num > bit {
			num = bit
		}

		mask := uint64(1<<num) - 1
		// Clear the target bits, then store the next num low bits of v there.
		enc.bitBuf &= ^byte(mask << enc.bit)
		enc.bitBuf |= byte((v & mask) << enc.bit)
		v >>= num
		enc.bit += num
		bit -= num

		// Emit the byte once it is complete.
		if enc.bit >= 8 {
			enc.bit -= 8
			n, err := enc.Writer.Write([]byte{enc.bitBuf})
			if err != nil {
				return err
			} else if n < 1 {
				return errors.New("write bit buffer error")
			}
			enc.bitBuf = 0
		}
	}
	return nil
}

// encodeBigBits appends the low size bits of v to the output, most
// significant bits first. Bits fill the pending byte from its high end
// downward; each time the byte is full it is written to the underlying
// writer. Bits that do not complete a byte stay buffered in the encoder.
func (enc *Encoder) encodeBigBits(v uint64, size int) error {
	bit := size // bits still to be written
	for bit > 0 {
		// Fill whatever room is left in the pending byte, capped at what remains.
		num := 8 - enc.bit
		if num > bit {
			num = bit
		}

		mask := uint64(1<<num) - 1
		// Clear the target bits, then store the top num of the remaining bits
		// of v just below the enc.bit bits already used at the top of the byte.
		enc.bitBuf &= ^byte(mask << (8 - enc.bit - num))
		enc.bitBuf |= byte(((v >> (bit - num)) & mask) << (8 - enc.bit - num))
		enc.bit += num
		bit -= num

		// Emit the byte once it is complete.
		if enc.bit >= 8 {
			enc.bit -= 8
			n, err := enc.Writer.Write([]byte{enc.bitBuf})
			if err != nil {
				return err
			} else if n < 1 {
				return errors.New("write bit buffer error")
			}
			enc.bitBuf = 0
		}
	}
	return nil
}

// encodeLittleInteger writes the low size bits of v as size/8 bytes in
// little-endian order. size must be a positive multiple of 8.
func (enc *Encoder) encodeLittleInteger(v uint64, size int) error {
	if size <= 0 || size%8 != 0 {
		return fmt.Errorf("unsupport integer size: %d", size)
	}
	out := make([]byte, size/8)
	for i := range out {
		// Byte i carries bits [8i, 8i+8) of v.
		out[i] = byte(v >> (8 * i))
	}
	return enc.WriteBytes(out)
}

// encodeBigInteger writes the low size bits of v as size/8 bytes in
// big-endian order. size must be a positive multiple of 8.
func (enc *Encoder) encodeBigInteger(v uint64, size int) error {
	if size <= 0 || size%8 != 0 {
		return fmt.Errorf("unsupport integer size: %d", size)
	}
	count := size / 8
	out := make([]byte, count)
	for i := range out {
		// The least significant byte goes last.
		out[count-1-i] = byte(v >> (8 * i))
	}
	return enc.WriteBytes(out)
}

// encodeFloat writes v as a 32- or 64-bit IEEE 754 value in the requested
// byte order. For size 32 the value is first converted to float32. Any other
// size is an error.
func (enc *Encoder) encodeFloat(v float64, size int, isBig bool) error {
	var raw uint64
	switch size {
	case 32:
		raw = uint64(math.Float32bits(float32(v)))
	case 64:
		raw = math.Float64bits(v)
	default:
		return fmt.Errorf("unsupport float size: %d", size)
	}
	if isBig {
		return enc.encodeBigInteger(raw, size)
	}
	return enc.encodeLittleInteger(raw, size)
}

// encodeComplex writes a complex number of size bits as two consecutive
// floats of size/2 bits each: the real part first, then the imaginary part.
func (enc *Encoder) encodeComplex(v complex128, size int, isBig bool) error {
	half := size / 2
	if err := enc.encodeFloat(real(v), half, isBig); err != nil {
		return err
	}
	return enc.encodeFloat(imag(v), half, isBig)
}

// encodeStruct encodes the exported fields of v in declaration order.
// Unexported fields are skipped and produce no output. Each field's `sort`
// and `bit` tags override the inherited byte order and set its bit width.
// An invalid tag or a failing field aborts the encoding with an error.
func (enc *Encoder) encodeStruct(v reflect.Value, isBig bool) error {
	structType := v.Type()
	for i, n := 0, v.NumField(); i < n; i++ {
		field := structType.Field(i)
		if !field.IsExported() {
			continue
		}
		fv := v.Field(i)
		if field.Name == "_" && !fv.CanSet() {
			continue
		}

		big, bit, tagErr := getSortBitFromTag(&field.Tag, isBig)
		if tagErr != "" {
			return errors.New(tagErr)
		}
		if err := enc.encode(fv, big, bit); err != nil {
			return err
		}
	}
	return nil
}

你可能感兴趣的:(源码分析,golang,开发语言,后端)