功能类似 golang 标准库 `encoding/binary`，用于二进制码流/文件的读写。对比标准库，本库在以下方面做了功能增强：

- `bit`、`sort` 结构体标签：用于指定结构体成员的 bit 数、大小端属性
- `map` 类型结构
- `string` 字符串

package binary
import (
"fmt"
"io"
"math"
"reflect"
)
// Unmarshaler is implemented by types that decode themselves from a Decoder.
// The decoder calls UnmarshalBinary instead of its reflection-based logic
// whenever the target value (or its address) implements this interface.
//
// The boolean is the byte/bit order in effect for the value (true = big
// endian) and the int is the bit width requested by a `bit` struct tag, or 0
// when none applies. A non-nil error aborts the surrounding Decode call.
type Unmarshaler interface {
	UnmarshalBinary(d *Decoder, bigEndian bool, bits int) error
}
// Unmarshal decodes buf into each element of e in order, using isBig as the
// default byte order. It is shorthand for NewDecoder(buf, 0).Unmarshal.
func Unmarshal(buf []byte, isBig bool, e ...any) error {
	dec := NewDecoder(buf, 0)
	return dec.Unmarshal(isBig, e...)
}
var EOF = newErr("read the end")
// Decoder reads values out of an in-memory byte slice with bit-level
// granularity.
type Decoder struct {
	buf []byte // data being decoded
	// prevOff is the byte offset of buf[0] in the caller's coordinate frame
	// (non-zero for decoders created by SubDecoder); Pos and Seek use it.
	prevOff int
	// off indexes the current byte in buf; bit counts the bits of buf[off]
	// that have already been consumed (0-7).
	off, bit int
	arg      any // user value stored with SetArg and read back with Arg
}
// NewDecoder returns a Decoder over buf. prevOff is the byte offset of buf[0]
// in the caller's coordinate frame; it only affects the positions reported by
// Pos and accepted by Seek. Pass 0 for a standalone buffer.
func NewDecoder(buf []byte, prevOff int) *Decoder {
	dec := new(Decoder)
	dec.buf = buf
	dec.prevOff = prevOff
	return dec
}
// NewReaderDecoder reads all of r into memory and returns a Decoder over the
// result, starting at offset 0.
func NewReaderDecoder(r io.Reader) *Decoder {
	// The signature has no error result, so a read error is dropped and
	// decoding proceeds on whatever io.ReadAll returned before the failure.
	data, _ := io.ReadAll(r)
	return NewDecoder(data, 0)
}
// Pos returns the current read position in bits, in the coordinate frame
// defined by prevOff (so a SubDecoder reports positions in its parent's frame).
func (dec *Decoder) Pos() int {
	byteOff := dec.prevOff + dec.off
	return byteOff*8 + dec.bit
}
// Seek moves the read cursor to bit position pos, expressed in the same frame
// as Pos. A position before the start or past the end of this decoder's buffer
// is rejected with an error and leaves the cursor unchanged.
func (dec *Decoder) Seek(pos int) error {
	lo := dec.prevOff * 8
	hi := (dec.prevOff + len(dec.buf)) * 8
	if pos < lo || pos > hi {
		return fmtErr("pos(%d.%d) illegal", pos/8, pos%8)
	}
	dec.off, dec.bit = pos/8-dec.prevOff, pos%8
	return nil
}
// SubDecoder returns a Decoder over the next n bytes. The window starts at the
// first byte boundary at or after the current position. If n <= 0, or the
// window would reach the end of the buffer, the sub-decoder covers everything
// that is left. The parent's cursor is not moved. The child shares the
// parent's backing array and inherits its Arg.
func (dec *Decoder) SubDecoder(n int) *Decoder {
	start := dec.off
	if dec.bit > 0 {
		start++ // skip the partially consumed byte
	}
	rest := len(dec.buf) - start
	if n <= 0 || n >= rest {
		n = rest
	}
	return &Decoder{
		buf:     dec.buf[start : start+n],
		prevOff: dec.prevOff + start,
		arg:     dec.arg,
	}
}
// SetArg attaches an arbitrary user value to the decoder. It can be read back
// with Arg (for example from an UnmarshalBinary method) and is copied into
// decoders created by SubDecoder.
func (dec *Decoder) SetArg(v any) { dec.arg = v }
// Arg returns the value stored by SetArg (or inherited through SubDecoder),
// or nil if none was set.
func (dec *Decoder) Arg() any { return dec.arg }
// Read copies up to len(p) whole bytes into p. Any partially consumed byte is
// skipped first, so reading always resumes on a byte boundary.
//
// Read has the io.Reader signature but signals end of data with the package's
// EOF error rather than io.EOF, so helpers that test for io.EOF will treat it
// as a failure.
func (dec *Decoder) Read(p []byte) (n int, err error) {
	if dec.bit > 0 {
		dec.off, dec.bit = dec.off+1, 0
	}
	if len(p) == 0 {
		return 0, nil
	}
	if avail := len(dec.buf) - dec.off; avail <= 0 {
		return 0, EOF
	}
	n = copy(p, dec.buf[dec.off:])
	dec.off += n
	return n, nil
}
// ReadByte returns the next whole byte, first skipping any partially consumed
// byte. It returns the package's EOF error when no byte is left.
func (dec *Decoder) ReadByte() (byte, error) {
	if dec.bit > 0 {
		dec.bit = 0
		dec.off++
	}
	if dec.off < len(dec.buf) {
		c := dec.buf[dec.off]
		dec.off++
		return c, nil
	}
	return 0, EOF
}
// ReadBytes returns the next n bytes, after skipping any partially consumed
// byte, as a sub-slice of the decoder's buffer; no copy is made. If n <= 0 or
// fewer than n bytes remain, everything left is returned, so the result may
// be shorter than n. It panics with EOF when no bytes are left; inside Decode
// that panic is recovered and returned as an error.
func (dec *Decoder) ReadBytes(n int) []byte {
	if dec.bit > 0 {
		dec.off++
		dec.bit = 0
	}
	avail := len(dec.buf) - dec.off
	if avail <= 0 {
		panic(EOF)
	}
	if n <= 0 || n > avail {
		n = avail
	}
	start := dec.off
	dec.off += n
	return dec.buf[start:dec.off]
}
// IsEof reports whether the cursor sits exactly at the end of the buffer on a
// byte boundary, i.e. every bit has been consumed.
func (dec *Decoder) IsEof() bool {
	return dec.bit <= 0 && dec.off == len(dec.buf)
}
// Unmarshal decodes into each element of e in turn, using isBig as the
// default byte order, and stops at the first error.
func (dec *Decoder) Unmarshal(isBig bool, e ...any) error {
	for i := range e {
		if err := dec.Decode(e[i], isBig, 0); err != nil {
			return err
		}
	}
	return nil
}
// Decode decodes one value into a. isBig selects the byte/bit order, and bit
// is the bit width (0 means the natural width of each value's kind).
//
// a should normally be a pointer, because non-settable scalars are skipped.
// A reflect.Value is decoded as the value it wraps. A nil a is a no-op.
//
// Decoding failures raised internally as decErr panics (including EOF) are
// recovered and returned as err; any other panic propagates unchanged.
func (dec *Decoder) Decode(a any, isBig bool, bit int) (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		e, ok := r.(decErr)
		if !ok {
			panic(r)
		}
		err = e
	}()
	switch v := a.(type) {
	case nil:
		return nil
	case reflect.Value:
		// Decode only the wrapped value. The original code also went on to
		// decode reflect.ValueOf(a), i.e. the reflect.Value struct itself.
		dec.decode(v, isBig, bit)
	default:
		dec.decode(reflect.ValueOf(a), isBig, bit)
	}
	return nil
}
// decode decodes into v according to its kind. isBig selects the byte/bit
// order, and bit is the bit width (0 means the kind's default width from
// kindSize).
//
// Behaviour by kind:
//   - Values that implement Unmarshaler, directly or through their address,
//     decode themselves.
//   - Non-settable scalars are skipped without consuming input.
//   - Floats and complexes ignore bit.
//   - A nil slice is grown by decodeSlice; a non-nil slice or an array is
//     filled element by element.
//   - Nil pointers and interfaces are skipped; non-nil ones are followed.
//   - Unsupported kinds (chan, func, unsafe.Pointer) are ignored.
func (dec *Decoder) decode(v reflect.Value, isBig bool, bit int) {
	if !v.IsValid() {
		return
	}
	if v.CanInterface() && dec.handleMethods(v, isBig, bit) {
		return
	}
	if v.Kind() != reflect.Pointer && v.CanAddr() && dec.handleMethods(v.Addr(), isBig, bit) {
		return
	}
	switch v.Kind() {
	case reflect.Bool:
		if v.CanSet() {
			v.SetBool(dec.decodeInteger(8, isBig, bit) != 0)
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if v.CanSet() {
			size := kindSize(v.Kind())
			u := dec.decodeInteger(size, isBig, bit)
			if bit == 0 && size < 64 {
				// Sign-extend from the encoded width. Int is always encoded
				// in 32 bits, so without this a negative int decodes as a
				// large positive number on 64-bit platforms. For the
				// fixed-width kinds this yields the same result as SetInt's
				// truncation. Bit fields keep their previous unsigned
				// interpretation.
				shift := 64 - size
				v.SetInt(int64(u<<shift) >> shift)
			} else {
				v.SetInt(int64(u))
			}
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		if v.CanSet() {
			v.SetUint(dec.decodeInteger(kindSize(v.Kind()), isBig, bit))
		}
	case reflect.Float32, reflect.Float64:
		if v.CanSet() {
			v.SetFloat(dec.decodeFloat(kindSize(v.Kind()), isBig))
		}
	case reflect.Complex64, reflect.Complex128:
		if v.CanSet() {
			v.SetComplex(dec.decodeComplex(kindSize(v.Kind()), isBig))
		}
	case reflect.Slice:
		if v.IsNil() {
			dec.decodeSlice(v, isBig, bit)
			break
		}
		fallthrough
	case reflect.Array:
		l := v.Len()
		for i := 0; i < l; i++ {
			dec.decode(v.Index(i), isBig, bit)
		}
	case reflect.Interface, reflect.Pointer:
		if v.IsNil() {
			break
		}
		dec.decode(v.Elem(), isBig, bit)
	case reflect.Map:
		dec.decodeMap(v, isBig, bit)
	case reflect.String:
		if v.CanSet() {
			v.SetString(dec.decodeString())
		}
	case reflect.Struct:
		dec.decodeStruct(v, isBig)
	default:
		return
	}
}
// handleMethods lets v decode itself when it implements Unmarshaler. It
// returns false when v does not implement the interface and true on success.
// When UnmarshalBinary fails, it panics with a decErr, which unwinds to
// Decode. A decErr is passed through as-is; other errors are converted by
// message. v must satisfy CanInterface.
func (dec *Decoder) handleMethods(v reflect.Value, isBig bool, bit int) bool {
	u, ok := v.Interface().(Unmarshaler)
	if !ok {
		return false
	}
	switch e := u.UnmarshalBinary(dec, isBig, bit).(type) {
	case nil:
		return true
	case decErr:
		panic(e)
	default:
		panic(newErr(e.Error()))
	}
}
// decodeInteger reads an unsigned value. With bit > 0, exactly bit bits are
// read (high bits first when isBig, low bits first otherwise) and size is
// ignored. Otherwise size bits, which must be whole bytes, are read in the
// requested byte order.
func (dec *Decoder) decodeInteger(size int, isBig bool, bit int) uint64 {
	switch {
	case bit > 0 && isBig:
		return dec.decodeBigBits(bit)
	case bit > 0:
		return dec.decodeLittleBits(bit)
	case isBig:
		return dec.decodeBigInteger(size)
	default:
		return dec.decodeLittleInteger(size)
	}
}
// decodeLittleBits reads size bits starting at the current bit cursor and
// returns them as an unsigned value.
//
// Within each byte, bits are taken from the least-significant end first
// (dec.bit counts the low bits already consumed). The first bit read becomes
// bit 0 of the result. The cursor may be left in the middle of a byte.
//
// It panics with EOF if the buffer runs out before size bits have been read.
// Callers are expected to keep size <= 64; getSortBitFromTag enforces this for
// struct tags.
func (dec *Decoder) decodeLittleBits(size int) uint64 {
	v := uint64(0)
	bit := size // bits still to read
	for bit > 0 {
		if dec.off >= len(dec.buf) {
			panic(EOF)
		}
		// Consume what is left of the current byte, but no more than needed.
		num := 8 - dec.bit
		if num > bit {
			num = bit
		}
		// Mask selecting the low num bits (0xFF when num == 8).
		mask := byte((1 << num) - 1)
		// Drop the already-consumed low bits, keep the next num bits, and
		// place them above the (size - bit) bits collected so far.
		v |= uint64((dec.buf[dec.off]>>dec.bit)&mask) << (size - bit)
		dec.bit += num
		bit -= num
		// Move to the next byte once the current one is exhausted.
		if dec.bit >= 8 {
			dec.bit -= 8
			dec.off++
		}
	}
	return v
}
// decodeBigBits reads size bits starting at the current bit cursor and
// returns them as an unsigned value.
//
// Within each byte, bits are taken from the most-significant end first
// (dec.bit counts the high bits already consumed). Earlier bits end up more
// significant in the result. The cursor may be left in the middle of a byte.
//
// It panics with EOF if the buffer runs out before size bits have been read.
// Callers are expected to keep size <= 64; getSortBitFromTag enforces this for
// struct tags.
func (dec *Decoder) decodeBigBits(size int) uint64 {
	v := uint64(0)
	bit := size // bits still to read
	for bit > 0 {
		if dec.off >= len(dec.buf) {
			panic(EOF)
		}
		// Consume what is left of the current byte, but no more than needed.
		num := 8 - dec.bit
		if num > bit {
			num = bit
		}
		// Mask selecting the low num bits (0xFF when num == 8).
		mask := byte((1 << num) - 1)
		// Make room for num new bits, then take the num bits that follow the
		// dec.bit already-consumed high bits of the current byte.
		v <<= num
		v |= uint64((dec.buf[dec.off] >> (8 - dec.bit - num)) & mask)
		dec.bit += num
		bit -= num
		// Move to the next byte once the current one is exhausted.
		if dec.bit >= 8 {
			dec.bit -= 8
			dec.off++
		}
	}
	return v
}
// decodeLittleInteger reads size/8 bytes, starting at the next byte boundary,
// as a little-endian unsigned integer. size must be a positive multiple of 8.
//
// It panics with a decErr for an unsupported size, and with EOF when fewer
// than size/8 bytes remain. ReadBytes silently returns a shorter slice in that
// case, so the length is checked here. The original indexed past the end of
// that slice, which caused a runtime panic that Decode does not recover.
func (dec *Decoder) decodeLittleInteger(size int) uint64 {
	if size <= 0 || size%8 != 0 {
		panic(fmtErr("unsupport integer size: %d", size))
	}
	n := size / 8
	bs := dec.ReadBytes(n)
	if len(bs) < n {
		panic(EOF)
	}
	v := uint64(0)
	for i, b := range bs {
		v |= uint64(b) << (i * 8)
	}
	return v
}
// decodeBigInteger reads size/8 bytes, starting at the next byte boundary, as
// a big-endian unsigned integer. size must be a positive multiple of 8.
//
// It panics with a decErr for an unsupported size, and with EOF when fewer
// than size/8 bytes remain. ReadBytes silently returns a shorter slice in that
// case, so the length is checked here. The original indexed past the end of
// that slice, which caused a runtime panic that Decode does not recover.
func (dec *Decoder) decodeBigInteger(size int) uint64 {
	if size <= 0 || size%8 != 0 {
		panic(fmtErr("unsupport integer size: %d", size))
	}
	n := size / 8
	bs := dec.ReadBytes(n)
	if len(bs) < n {
		panic(EOF)
	}
	v := uint64(0)
	for _, b := range bs {
		v = v<<8 | uint64(b)
	}
	return v
}
// decodeFloat reads an IEEE-754 float of size bits (32 or 64) in the requested
// byte order. The bytes are consumed before the size is validated. Any other
// size panics with a decErr.
func (dec *Decoder) decodeFloat(size int, isBig bool) float64 {
	var bits uint64
	if isBig {
		bits = dec.decodeBigInteger(size)
	} else {
		bits = dec.decodeLittleInteger(size)
	}
	if size == 32 {
		return float64(math.Float32frombits(uint32(bits)))
	}
	if size == 64 {
		return math.Float64frombits(bits)
	}
	panic(fmtErr("unsupport float size: %d", size))
}
// decodeComplex reads a complex number of total width size bits: the real
// part followed by the imaginary part, each a size/2-bit float.
func (dec *Decoder) decodeComplex(size int, isBig bool) complex128 {
	half := size / 2
	re := dec.decodeFloat(half, isBig)
	im := dec.decodeFloat(half, isBig)
	return complex(re, im)
}
// decodeString reads a NUL-terminated string starting at the next byte
// boundary. The terminator is consumed but not returned. If the buffer ends
// before a NUL, the rest of the buffer is returned. It panics with EOF when
// no bytes are left.
func (dec *Decoder) decodeString() string {
	if dec.bit > 0 {
		dec.off++
		dec.bit = 0
	}
	if dec.off >= len(dec.buf) {
		panic(EOF)
	}
	rest := dec.buf[dec.off:]
	for i, c := range rest {
		if c == 0 {
			dec.off += i + 1
			return string(rest[:i])
		}
	}
	dec.off += len(rest)
	return string(rest)
}
// decodeSlice fills the nil slice v with elements decoded until the buffer is
// exhausted. Non-settable slices are left untouched.
//
// A []byte (element kind uint8) with no bit width takes all remaining bytes in
// one step, aliasing the decoder's buffer. If nothing is left, it stays nil,
// like slices of other element types. This lets a trailing empty []byte, for
// which the encoder writes nothing, round-trip instead of failing with EOF.
//
// It panics with a decErr if an element decodes without consuming any input.
// Such an element would otherwise loop forever: for example []*T, whose fresh
// elements are nil pointers that decode skips, or structs with no exported
// fields.
func (dec *Decoder) decodeSlice(v reflect.Value, isBig bool, bit int) {
	if !v.CanSet() {
		return
	}
	t := v.Type().Elem()
	if t.Kind() == reflect.Uint8 && bit == 0 {
		if dec.IsEof() {
			return
		}
		v.SetBytes(dec.ReadBytes(0))
		return
	}
	out := v
	for !dec.IsEof() {
		start := dec.Pos()
		elem := reflect.New(t).Elem()
		dec.decode(elem, isBig, bit)
		if dec.Pos() == start {
			panic(fmtErr("slice element type %s consumes no input", t))
		}
		out = reflect.Append(out, elem)
	}
	v.Set(out)
}
// decodeMap reads key/value pairs into map v until the buffer is exhausted.
//
// Keys must be integer or string kinds and are decoded at their natural width;
// bit applies to values only. A later duplicate key overwrites an earlier one.
// A nil map is allocated when there is data to decode and stays nil
// otherwise. Non-settable maps are left untouched; an unsupported key kind
// panics with a decErr.
func (dec *Decoder) decodeMap(v reflect.Value, isBig bool, bit int) {
	if !v.CanSet() {
		return
	}
	kt := v.Type().Key()
	vt := v.Type().Elem()
	switch kt.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
		reflect.String:
	default:
		panic(newErr("unsupported key type of map: " + kt.String()))
	}
	// SetMapIndex on a nil map raises a runtime panic (not a decErr), which
	// would escape Decode, so allocate the map before inserting.
	if v.IsNil() && !dec.IsEof() {
		v.Set(reflect.MakeMap(v.Type()))
	}
	for !dec.IsEof() {
		kv := reflect.New(kt).Elem()
		vv := reflect.New(vt).Elem()
		dec.decode(kv, isBig, 0)
		dec.decode(vv, isBig, bit)
		v.SetMapIndex(kv, vv)
	}
}
// decodeStruct decodes the exported fields of struct v in declaration order.
// A field's `sort` ("big"/"little") and `bit` tags override the byte order and
// bit width, and a malformed tag panics with a decErr. Unexported fields,
// which include blank "_" fields, are skipped and consume no input.
func (dec *Decoder) decodeStruct(v reflect.Value, isBig bool) {
	typ := v.Type()
	for i, n := 0, v.NumField(); i < n; i++ {
		field := typ.Field(i)
		if !field.IsExported() {
			continue
		}
		big, bit, msg := getSortBitFromTag(&field.Tag, isBig)
		if msg != "" {
			panic(newErr(msg))
		}
		dec.decode(v.Field(i), big, bit)
	}
}
// decErr is the error type the decoder panics with internally. Decode
// recovers it and returns it as an ordinary error.
type decErr string

// newErr wraps a message as a decErr.
func newErr(err string) decErr { return decErr(err) }

// fmtErr builds a decErr from a fmt.Sprintf-style format and arguments.
func fmtErr(format string, a ...any) decErr {
	msg := fmt.Sprintf(format, a...)
	return decErr(msg)
}

// Error implements the error interface by returning the message unchanged.
func (err decErr) Error() string { return string(err) }
package binary
import (
"bytes"
"errors"
"fmt"
"io"
"math"
"reflect"
"strconv"
)
// Marshaler is implemented by types that encode themselves through an
// Encoder. The encoder calls MarshalBinary instead of its reflection-based
// logic when a value (or its address) implements this interface. The boolean
// is the byte/bit order in effect (true = big endian) and the int is the bit
// width from a `bit` struct tag, or 0 when none applies.
type Marshaler interface {
	MarshalBinary(e *Encoder, bigEndian bool, bits int) error
}
// Marshal encodes each element of e in order into a new byte slice, using
// isBig as the default byte order.
//
// If the last value leaves a partially filled byte (from `bit` fields), that
// byte is appended with its unused bits zero. Previously those trailing bits
// stayed in the encoder's bit buffer and were silently lost.
func Marshal(isBig bool, e ...any) ([]byte, error) {
	buf := bytes.NewBuffer(nil)
	enc := NewEncoder(buf)
	if err := enc.Marshal(isBig, e...); err != nil {
		return nil, err
	}
	// Encoder.Write with an empty slice writes out the pending bit byte, if
	// any, and is a no-op otherwise.
	if _, err := enc.Write(nil); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
// kindSize returns the default encoded width in bits for a value of the given
// kind, or 0 for kinds without a fixed width. Int and Uint are always 32 bits,
// whatever the platform's native int size.
func kindSize(kind reflect.Kind) int {
	switch kind {
	case reflect.Bool, reflect.Int8, reflect.Uint8:
		return 8
	case reflect.Int16, reflect.Uint16:
		return 16
	case reflect.Int, reflect.Uint, reflect.Int32, reflect.Uint32, reflect.Float32:
		return 32
	case reflect.Int64, reflect.Uint64, reflect.Uintptr, reflect.Float64, reflect.Complex64:
		return 64
	case reflect.Complex128:
		return 128
	}
	return 0
}
// getSortBitFromTag reads a struct field's `sort` and `bit` tags.
//
// isBig starts as defBit (the inherited byte order) and is overridden by
// sort:"big" or sort:"little". bit is the value of the bit tag (0 when
// absent) and must be in 1..64. err is a non-empty message for an unknown sort
// value or an invalid bit value. The sort tag is checked first, and on error
// isBig keeps whatever value it had reached and bit is 0.
func getSortBitFromTag(tag *reflect.StructTag, defBit bool) (isBig bool, bit int, err string) {
	switch order := tag.Get("sort"); order {
	case "":
		isBig = defBit
	case "big":
		isBig = true
	case "little":
		isBig = false
	default:
		return defBit, 0, "unsupported tag: `sort:\"" + order + "\"`"
	}
	s := tag.Get("bit")
	if s == "" {
		return isBig, 0, ""
	}
	n, convErr := strconv.Atoi(s)
	if convErr != nil || n < 1 || n > 64 {
		return isBig, 0, "unsupported tag: `bit:\"" + s + "\"`"
	}
	return isBig, n, ""
}
// Encoder writes binary-encoded values to an underlying io.Writer with
// bit-level granularity. Encoder's own Write method shadows the embedded
// writer's Write.
type Encoder struct {
	io.Writer // destination for encoded bytes

	bit    int  // number of bits already placed in bitBuf
	bitBuf byte // partially filled byte that has not been written yet
	arg    any  // user value stored with SetArg and read back with Arg
}
// NewEncoder returns an Encoder that writes to w with an empty bit buffer.
func NewEncoder(w io.Writer) *Encoder {
	enc := new(Encoder)
	enc.Writer = w
	return enc
}
// Write writes p after first flushing any partially filled bit byte, so p
// always starts on a byte boundary. Calling it with an empty p only flushes.
//
// The returned n counts bytes of p only. If the flush fails, nothing from p is
// written and n is 0. A flush that writes zero bytes without an error is
// reported as io.ErrShortWrite; previously it returned (0, nil), which made
// WriteBytes retry forever.
func (enc *Encoder) Write(p []byte) (n int, err error) {
	if enc.bit > 0 {
		wn, werr := enc.Writer.Write([]byte{enc.bitBuf})
		if werr != nil {
			return 0, werr
		}
		if wn < 1 {
			return 0, io.ErrShortWrite
		}
		enc.bitBuf = 0
		enc.bit = 0
	}
	if len(p) == 0 {
		return 0, nil
	}
	return enc.Writer.Write(p)
}
// WriteByte writes c after first flushing any partially filled bit byte, so c
// always lands on a byte boundary. A write that stores zero bytes without
// reporting an error is returned as io.ErrShortWrite instead of silently
// dropping the data, matching Encoder.Write.
func (enc *Encoder) WriteByte(c byte) error {
	if enc.bit > 0 {
		n, err := enc.Writer.Write([]byte{enc.bitBuf})
		if err != nil {
			return err
		}
		if n < 1 {
			return io.ErrShortWrite
		}
		enc.bitBuf = 0
		enc.bit = 0
	}
	n, err := enc.Writer.Write([]byte{c})
	if err == nil && n < 1 {
		err = io.ErrShortWrite
	}
	return err
}
// WriteBytes writes all of p, flushing any pending bit byte first. It keeps
// calling Write on the unwritten remainder until everything has been written
// or an error occurs.
func (enc *Encoder) WriteBytes(p []byte) error {
	var total int
	var err error
	for {
		var n int
		n, err = enc.Write(p[total:])
		total += n
		if err != nil || total >= len(p) {
			break
		}
	}
	if err != nil {
		return err
	}
	if total < len(p) {
		return fmt.Errorf("partial(%d) write, all(%d)", total, len(p))
	}
	return nil
}
// SetArg attaches an arbitrary user value to the encoder. It can be read back
// with Arg, for example from a MarshalBinary method.
func (enc *Encoder) SetArg(v any) { enc.arg = v }
// Arg returns the value stored by SetArg, or nil if none was set.
func (enc *Encoder) Arg() any { return enc.arg }
// Marshal encodes each element of e in order, using isBig as the default
// byte order, and stops at the first error. A trailing partial byte produced
// by bit fields stays buffered in the encoder and is not written here.
func (enc *Encoder) Marshal(isBig bool, e ...any) error {
	for i := range e {
		if err := enc.Encode(e[i], isBig, 0); err != nil {
			return err
		}
	}
	return nil
}
// Encode writes a. isBig selects the byte/bit order, and bit is the bit width
// (0 means each value's natural width).
//
// Special cases:
//   - []byte and *[]byte with bit == 0 are written verbatim.
//   - A reflect.Value is encoded as the value it holds.
//   - A nil a writes nothing.
//   - A nil *[]byte writes nothing, matching how encode skips nil pointers.
//     Previously it panicked on the dereference.
func (enc *Encoder) Encode(a any, isBig bool, bit int) error {
	switch v := a.(type) {
	case nil:
		return nil
	case []byte:
		if bit == 0 {
			return enc.WriteBytes(v)
		}
	case *[]byte:
		if v == nil {
			return nil
		}
		if bit == 0 {
			return enc.WriteBytes(*v)
		}
	case reflect.Value:
		return enc.encode(v, isBig, bit)
	}
	return enc.encode(reflect.ValueOf(a), isBig, bit)
}
// encode writes v according to its kind.
//
//   - Values that implement Marshaler, directly or through their address,
//     encode themselves.
//   - Scalars use their kind's default width from kindSize unless bit > 0;
//     floats and complexes ignore bit.
//   - Strings are written followed by a single 0 byte.
//   - Slices and arrays are written element by element.
//   - Maps are written as key/value pairs in map iteration order (which is
//     not deterministic); keys always use their natural width.
//   - Structs are written field by field.
//   - Nil slices, maps, pointers and interfaces, and unsupported kinds,
//     produce no output.
func (enc *Encoder) encode(v reflect.Value, isBig bool, bit int) error {
	if !v.IsValid() {
		return nil
	}
	var m Marshaler
	var ok bool
	if v.CanInterface() {
		m, ok = v.Interface().(Marshaler)
	}
	if !ok && v.Kind() != reflect.Pointer && v.CanAddr() {
		m, ok = v.Addr().Interface().(Marshaler)
	}
	if ok {
		return m.MarshalBinary(enc, isBig, bit)
	}

	switch k := v.Kind(); k {
	case reflect.Bool:
		var u uint64
		if v.Bool() {
			u = 1
		}
		return enc.encodeInteger(u, 8, isBig, bit)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return enc.encodeInteger(uint64(v.Int()), kindSize(k), isBig, bit)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return enc.encodeInteger(v.Uint(), kindSize(k), isBig, bit)
	case reflect.Float32, reflect.Float64:
		return enc.encodeFloat(v.Float(), kindSize(k), isBig)
	case reflect.Complex64, reflect.Complex128:
		return enc.encodeComplex(v.Complex(), kindSize(k), isBig)
	case reflect.Slice, reflect.Array:
		if k == reflect.Slice && v.IsNil() {
			return nil
		}
		n := v.Len()
		for i := 0; i < n; i++ {
			if err := enc.encode(v.Index(i), isBig, bit); err != nil {
				return err
			}
		}
		return nil
	case reflect.Interface, reflect.Pointer:
		if v.IsNil() {
			return nil
		}
		return enc.encode(v.Elem(), isBig, bit)
	case reflect.Map:
		if v.IsNil() {
			return nil
		}
		for it := v.MapRange(); it.Next(); {
			val := it.Value()
			if !val.IsValid() {
				continue
			}
			if err := enc.encode(it.Key(), isBig, 0); err != nil {
				return err
			}
			if err := enc.encode(val, isBig, bit); err != nil {
				return err
			}
		}
		return nil
	case reflect.String:
		return enc.WriteBytes(append([]byte(v.String()), 0))
	case reflect.Struct:
		return enc.encodeStruct(v, isBig)
	}
	return nil
}
// encodeInteger writes the unsigned value v. With bit > 0, exactly the low bit
// bits are written (high bits first when isBig, low bits first otherwise) and
// size is ignored. Otherwise size bits, which must be whole bytes, are written
// in the requested byte order.
func (enc *Encoder) encodeInteger(v uint64, size int, isBig bool, bit int) error {
	if bit > 0 {
		if isBig {
			return enc.encodeBigBits(v, bit)
		}
		return enc.encodeLittleBits(v, bit)
	}
	if isBig {
		return enc.encodeBigInteger(v, size)
	}
	return enc.encodeLittleInteger(v, size)
}
// encodeLittleBits appends the low size bits of v to the bit stream. Each
// byte is filled from its least-significant end (enc.bit counts the low bits
// already used), and bit 0 of v goes first; this mirrors decodeLittleBits.
//
// Each completed byte is written directly to the underlying io.Writer. A
// trailing partial byte stays in bitBuf. It returns an error if that write
// fails or writes nothing.
func (enc *Encoder) encodeLittleBits(v uint64, size int) error {
	bit := size // bits still to write
	for bit > 0 {
		// Fill what is left of the current byte, but no more than needed.
		num := 8 - enc.bit
		if num > bit {
			num = bit
		}
		mask := uint64(1<<num) - 1
		// Clear the destination bits, then drop the next num bits of v into
		// them just above the enc.bit bits already used.
		enc.bitBuf &= ^byte(mask << enc.bit)
		enc.bitBuf |= byte((v & mask) << enc.bit)
		v >>= num
		enc.bit += num
		bit -= num
		// The current byte is full: emit it and start a fresh one.
		if enc.bit >= 8 {
			enc.bit -= 8
			n, err := enc.Writer.Write([]byte{enc.bitBuf})
			if err != nil {
				return err
			} else if n < 1 {
				return errors.New("write bit buffer error")
			}
			enc.bitBuf = 0
		}
	}
	return nil
}
// encodeBigBits appends the low size bits of v to the bit stream, most
// significant of those bits first. Each byte is filled from its
// most-significant end (enc.bit counts the high bits already used); this
// mirrors decodeBigBits.
//
// Each completed byte is written directly to the underlying io.Writer. A
// trailing partial byte stays in bitBuf. It returns an error if that write
// fails or writes nothing.
func (enc *Encoder) encodeBigBits(v uint64, size int) error {
	bit := size // bits still to write
	for bit > 0 {
		// Fill what is left of the current byte, but no more than needed.
		num := 8 - enc.bit
		if num > bit {
			num = bit
		}
		mask := uint64(1<<num) - 1
		// Clear the num destination bits just below the enc.bit used high
		// bits, then store the top num of the remaining bits of v there.
		enc.bitBuf &= ^byte(mask << (8 - enc.bit - num))
		enc.bitBuf |= byte(((v >> (bit - num)) & mask) << (8 - enc.bit - num))
		enc.bit += num
		bit -= num
		// The current byte is full: emit it and start a fresh one.
		if enc.bit >= 8 {
			enc.bit -= 8
			n, err := enc.Writer.Write([]byte{enc.bitBuf})
			if err != nil {
				return err
			} else if n < 1 {
				return errors.New("write bit buffer error")
			}
			enc.bitBuf = 0
		}
	}
	return nil
}
// encodeLittleInteger writes v as size/8 little-endian bytes via WriteBytes,
// so any pending bit byte is flushed first. size must be a positive multiple
// of 8; bytes beyond the eighth are zero.
func (enc *Encoder) encodeLittleInteger(v uint64, size int) error {
	if size <= 0 || size%8 != 0 {
		return fmt.Errorf("unsupport integer size: %d", size)
	}
	out := make([]byte, size/8)
	for i := range out {
		out[i] = byte(v >> (8 * i))
	}
	return enc.WriteBytes(out)
}
// encodeBigInteger writes v as size/8 big-endian bytes via WriteBytes, so any
// pending bit byte is flushed first. size must be a positive multiple of 8;
// leading bytes beyond the eighth from the end are zero.
func (enc *Encoder) encodeBigInteger(v uint64, size int) error {
	if size <= 0 || size%8 != 0 {
		return fmt.Errorf("unsupport integer size: %d", size)
	}
	n := size / 8
	out := make([]byte, n)
	for i := range out {
		out[n-1-i] = byte(v >> (8 * i))
	}
	return enc.WriteBytes(out)
}
// encodeFloat writes v as an IEEE-754 float of size bits (32 or 64) in the
// requested byte order. Any other size is rejected with an error before
// anything is written.
func (enc *Encoder) encodeFloat(v float64, size int, isBig bool) error {
	var bits uint64
	switch size {
	case 32:
		bits = uint64(math.Float32bits(float32(v)))
	case 64:
		bits = math.Float64bits(v)
	default:
		return fmt.Errorf("unsupport float size: %d", size)
	}
	if isBig {
		return enc.encodeBigInteger(bits, size)
	}
	return enc.encodeLittleInteger(bits, size)
}
// encodeComplex writes v as its real part followed by its imaginary part, each
// a float of size/2 bits.
func (enc *Encoder) encodeComplex(v complex128, size int, isBig bool) error {
	half := size / 2
	if err := enc.encodeFloat(real(v), half, isBig); err != nil {
		return err
	}
	return enc.encodeFloat(imag(v), half, isBig)
}
// encodeStruct writes the exported fields of struct v in declaration order.
// A field's `sort` ("big"/"little") and `bit` tags override the byte order and
// bit width, and a malformed tag is returned as an error. Unexported fields,
// which include blank "_" fields, are skipped.
func (enc *Encoder) encodeStruct(v reflect.Value, isBig bool) error {
	typ := v.Type()
	for i, n := 0, v.NumField(); i < n; i++ {
		field := typ.Field(i)
		if !field.IsExported() {
			continue
		}
		big, bit, msg := getSortBitFromTag(&field.Tag, isBig)
		if msg != "" {
			return errors.New(msg)
		}
		if err := enc.encode(v.Field(i), big, bit); err != nil {
			return err
		}
	}
	return nil
}