Thrift Source Code Analysis (II): Serialization Protocols

Overview

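For an RPC framework, defining how data is serialized on the wire is the most fundamental piece of work. Thrift's serialization protocols mainly include the following: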

  • TBinaryProtocol
  • TCompactProtocol
  • TJSONProtocol
  • TSimpleJSONProtocol
  • TTupleProtocol(继承自TCompactProtocol)

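As the class hierarchy diagram in Thrift Source Code Analysis (I): Overview of the Main Classes shows, all of these protocols extend the abstract class TProtocol (why Thrift chose an abstract class rather than an interface puzzles me too).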

Basic Methods of a Serialization Protocol

Working from the Thrift source code, this article mainly analyzes how TBinaryProtocol is implemented.

First, look at TProtocol, the base class of all these protocols. It declares the following abstract methods:

  /**
   * Writing methods.
   */

  public abstract void writeMessageBegin(TMessage message) throws TException;

  public abstract void writeMessageEnd() throws TException;

  public abstract void writeStructBegin(TStruct struct) throws TException;

  public abstract void writeStructEnd() throws TException;

  public abstract void writeFieldBegin(TField field) throws TException;

  public abstract void writeFieldEnd() throws TException;

  public abstract void writeFieldStop() throws TException;

  public abstract void writeMapBegin(TMap map) throws TException;

  public abstract void writeMapEnd() throws TException;

  public abstract void writeListBegin(TList list) throws TException;

  public abstract void writeListEnd() throws TException;

  public abstract void writeSetBegin(TSet set) throws TException;

  public abstract void writeSetEnd() throws TException;

  public abstract void writeBool(boolean b) throws TException;

  public abstract void writeByte(byte b) throws TException;

  public abstract void writeI16(short i16) throws TException;

  public abstract void writeI32(int i32) throws TException;

  public abstract void writeI64(long i64) throws TException;

  public abstract void writeDouble(double dub) throws TException;

  public abstract void writeString(String str) throws TException;

  public abstract void writeBinary(ByteBuffer buf) throws TException;

  /**
   * Reading methods.
   */

  public abstract TMessage readMessageBegin() throws TException;

  public abstract void readMessageEnd() throws TException;

  public abstract TStruct readStructBegin() throws TException;

  public abstract void readStructEnd() throws TException;

  public abstract TField readFieldBegin() throws TException;

  public abstract void readFieldEnd() throws TException;

  public abstract TMap readMapBegin() throws TException;

  public abstract void readMapEnd() throws TException;

  public abstract TList readListBegin() throws TException;

  public abstract void readListEnd() throws TException;

  public abstract TSet readSetBegin() throws TException;

  public abstract void readSetEnd() throws TException;

  public abstract boolean readBool() throws TException;

  public abstract byte readByte() throws TException;

  public abstract short readI16() throws TException;

  public abstract int readI32() throws TException;

  public abstract long readI64() throws TException;

  public abstract double readDouble() throws TException;

  public abstract String readString() throws TException;

  public abstract ByteBuffer readBinary() throws TException;

The first half are the write methods and the second half are the read methods: write is used for serialization, read for deserialization.
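To see how the begin/end pairs and the per-type methods fit together, the sketch below roughly mimics the order in which Thrift-generated Java code drives a protocol when writing a struct. MiniProtocol, RecordingProtocol, writeUser and the User struct are hypothetical simplifications made up for illustration (the real TProtocol takes TStruct/TField objects rather than plain names and ids). The read side mirrors this with readStructBegin, a readFieldBegin loop that stops at TType.STOP, and readStructEnd.

```java
import java.util.ArrayList;
import java.util.List;

public class CallOrderSketch {
    // Hypothetical, trimmed-down stand-in for the write half of TProtocol.
    interface MiniProtocol {
        void writeStructBegin(String name);
        void writeFieldBegin(String name, byte type, short id);
        void writeFieldEnd();
        void writeFieldStop();
        void writeStructEnd();
        void writeI32(int i32);
        void writeString(String str);
    }

    // Records each call instead of producing bytes, so the call order is visible.
    static class RecordingProtocol implements MiniProtocol {
        final List<String> calls = new ArrayList<>();
        public void writeStructBegin(String name) { calls.add("structBegin " + name); }
        public void writeFieldBegin(String name, byte type, short id) {
            calls.add("fieldBegin " + name + " type=" + type + " id=" + id);
        }
        public void writeFieldEnd() { calls.add("fieldEnd"); }
        public void writeFieldStop() { calls.add("fieldStop"); }
        public void writeStructEnd() { calls.add("structEnd"); }
        public void writeI32(int i32) { calls.add("i32 " + i32); }
        public void writeString(String str) { calls.add("string " + str); }
    }

    // Hypothetical IDL: struct User { 1: i32 id, 2: string name }
    static void writeUser(MiniProtocol prot, int id, String name) {
        prot.writeStructBegin("User");
        prot.writeFieldBegin("id", (byte) 8, (short) 1);    // 8 == TType.I32
        prot.writeI32(id);
        prot.writeFieldEnd();
        prot.writeFieldBegin("name", (byte) 11, (short) 2); // 11 == TType.STRING
        prot.writeString(name);
        prot.writeFieldEnd();
        prot.writeFieldStop();                               // marks the end of the fields
        prot.writeStructEnd();
    }

    public static void main(String[] args) {
        RecordingProtocol prot = new RecordingProtocol();
        writeUser(prot, 42, "hi");
        prot.calls.forEach(System.out::println);
    }
}
```

Swapping RecordingProtocol for a byte-producing implementation changes the output format without touching writeUser, which is exactly the decoupling the TProtocol abstraction provides.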

As we know, Thrift defines the following data types:

/**
 * Type constants in the Thrift protocol.
 */
public final class TType {
  public static final byte STOP   = 0;
  public static final byte VOID   = 1;
  public static final byte BOOL   = 2;
  public static final byte BYTE   = 3;
  public static final byte DOUBLE = 4;
  public static final byte I16    = 6;
  public static final byte I32    = 8;
  public static final byte I64    = 10;
  public static final byte STRING = 11;
  public static final byte STRUCT = 12;
  public static final byte MAP    = 13;
  public static final byte SET    = 14;
  public static final byte LIST   = 15;
  public static final byte ENUM   = 16;
}

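Correspondingly, TProtocol provides write/read methods for these types (STOP has no value method of its own; it is written by writeFieldStop).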

Next, let's look in detail at how TBinaryProtocol implements these methods.

TBinaryProtocol in Detail

Here is the main code of TBinaryProtocol:

  public void writeMessageBegin(TMessage message) throws TException {
    if (strictWrite_) {
      int version = VERSION_1 | message.type;
      writeI32(version);
      writeString(message.name);
      writeI32(message.seqid);
    } else {
      writeString(message.name);
      writeByte(message.type);
      writeI32(message.seqid);
    }
  }

  public void writeMessageEnd() {}

  public void writeStructBegin(TStruct struct) {}

  public void writeStructEnd() {}

  public void writeFieldBegin(TField field) throws TException {
    writeByte(field.type);
    writeI16(field.id);
  }

  public void writeFieldEnd() {}

  public void writeFieldStop() throws TException {
    writeByte(TType.STOP);
  }

  public void writeMapBegin(TMap map) throws TException {
    writeByte(map.keyType);
    writeByte(map.valueType);
    writeI32(map.size);
  }

  public void writeMapEnd() {}

  public void writeListBegin(TList list) throws TException {
    writeByte(list.elemType);
    writeI32(list.size);
  }

  public void writeListEnd() {}

  public void writeSetBegin(TSet set) throws TException {
    writeByte(set.elemType);
    writeI32(set.size);
  }

  public void writeSetEnd() {}

  public void writeBool(boolean b) throws TException {
    writeByte(b ? (byte)1 : (byte)0);
  }

  public void writeByte(byte b) throws TException {
    inoutTemp[0] = b;
    trans_.write(inoutTemp, 0, 1);
  }

  public void writeI16(short i16) throws TException {
    inoutTemp[0] = (byte)(0xff & (i16 >> 8));
    inoutTemp[1] = (byte)(0xff & (i16));
    trans_.write(inoutTemp, 0, 2);
  }

  public void writeI32(int i32) throws TException {
    inoutTemp[0] = (byte)(0xff & (i32 >> 24));
    inoutTemp[1] = (byte)(0xff & (i32 >> 16));
    inoutTemp[2] = (byte)(0xff & (i32 >> 8));
    inoutTemp[3] = (byte)(0xff & (i32));
    trans_.write(inoutTemp, 0, 4);
  }

  public void writeI64(long i64) throws TException {
    inoutTemp[0] = (byte)(0xff & (i64 >> 56));
    inoutTemp[1] = (byte)(0xff & (i64 >> 48));
    inoutTemp[2] = (byte)(0xff & (i64 >> 40));
    inoutTemp[3] = (byte)(0xff & (i64 >> 32));
    inoutTemp[4] = (byte)(0xff & (i64 >> 24));
    inoutTemp[5] = (byte)(0xff & (i64 >> 16));
    inoutTemp[6] = (byte)(0xff & (i64 >> 8));
    inoutTemp[7] = (byte)(0xff & (i64));
    trans_.write(inoutTemp, 0, 8);
  }

  public void writeDouble(double dub) throws TException {
    writeI64(Double.doubleToLongBits(dub));
  }

  public void writeString(String str) throws TException {
    try {
      byte[] dat = str.getBytes("UTF-8");
      writeI32(dat.length);
      trans_.write(dat, 0, dat.length);
    } catch (UnsupportedEncodingException uex) {
      throw new TException("JVM DOES NOT SUPPORT UTF-8");
    }
  }

  public void writeBinary(ByteBuffer bin) throws TException {
    int length = bin.limit() - bin.position();
    writeI32(length);
    trans_.write(bin.array(), bin.position() + bin.arrayOffset(), length);
  }

  /**
   * Reading methods.
   */

  public TMessage readMessageBegin() throws TException {
    int size = readI32();
    if (size < 0) {
      int version = size & VERSION_MASK;
      if (version != VERSION_1) {
        throw new TProtocolException(TProtocolException.BAD_VERSION, "Bad version in readMessageBegin");
      }
      return new TMessage(readString(), (byte)(size & 0x000000ff), readI32());
    } else {
      if (strictRead_) {
        throw new TProtocolException(TProtocolException.BAD_VERSION, "Missing version in readMessageBegin, old client?");
      }
      return new TMessage(readStringBody(size), readByte(), readI32());
    }
  }

  public void readMessageEnd() {}

  public TStruct readStructBegin() {
    return ANONYMOUS_STRUCT;
  }

  public void readStructEnd() {}

  public TField readFieldBegin() throws TException {
    byte type = readByte();
    short id = type == TType.STOP ? 0 : readI16();
    return new TField("", type, id);
  }

  public void readFieldEnd() {}

  public TMap readMapBegin() throws TException {
    TMap map = new TMap(readByte(), readByte(), readI32());
    checkContainerReadLength(map.size);
    return map;
  }

  public void readMapEnd() {}

  public TList readListBegin() throws TException {
    TList list = new TList(readByte(), readI32());
    checkContainerReadLength(list.size);
    return list;
  }

  public void readListEnd() {}

  public TSet readSetBegin() throws TException {
    TSet set = new TSet(readByte(), readI32());
    checkContainerReadLength(set.size);
    return set;
  }

  public void readSetEnd() {}

  public boolean readBool() throws TException {
    return (readByte() == 1);
  }

  public byte readByte() throws TException {
    if (trans_.getBytesRemainingInBuffer() >= 1) {
      byte b = trans_.getBuffer()[trans_.getBufferPosition()];
      trans_.consumeBuffer(1);
      return b;
    }
    readAll(inoutTemp, 0, 1);
    return inoutTemp[0];
  }

  public short readI16() throws TException {
    byte[] buf = inoutTemp;
    int off = 0;

    if (trans_.getBytesRemainingInBuffer() >= 2) {
      buf = trans_.getBuffer();
      off = trans_.getBufferPosition();
      trans_.consumeBuffer(2);
    } else {
      readAll(inoutTemp, 0, 2);
    }

    return
      (short)
      (((buf[off] & 0xff) << 8) |
       ((buf[off+1] & 0xff)));
  }

  public int readI32() throws TException {
    byte[] buf = inoutTemp;
    int off = 0;

    if (trans_.getBytesRemainingInBuffer() >= 4) {
      buf = trans_.getBuffer();
      off = trans_.getBufferPosition();
      trans_.consumeBuffer(4);
    } else {
      readAll(inoutTemp, 0, 4);
    }
    return
      ((buf[off] & 0xff) << 24) |
      ((buf[off+1] & 0xff) << 16) |
      ((buf[off+2] & 0xff) <<  8) |
      ((buf[off+3] & 0xff));
  }

  public long readI64() throws TException {
    byte[] buf = inoutTemp;
    int off = 0;

    if (trans_.getBytesRemainingInBuffer() >= 8) {
      buf = trans_.getBuffer();
      off = trans_.getBufferPosition();
      trans_.consumeBuffer(8);
    } else {
      readAll(inoutTemp, 0, 8);
    }

    return
      ((long)(buf[off]   & 0xff) << 56) |
      ((long)(buf[off+1] & 0xff) << 48) |
      ((long)(buf[off+2] & 0xff) << 40) |
      ((long)(buf[off+3] & 0xff) << 32) |
      ((long)(buf[off+4] & 0xff) << 24) |
      ((long)(buf[off+5] & 0xff) << 16) |
      ((long)(buf[off+6] & 0xff) <<  8) |
      ((long)(buf[off+7] & 0xff));
  }

  public double readDouble() throws TException {
    return Double.longBitsToDouble(readI64());
  }

  public String readString() throws TException {
    int size = readI32();

    checkStringReadLength(size);

    if (trans_.getBytesRemainingInBuffer() >= size) {
      try {
        String s = new String(trans_.getBuffer(), trans_.getBufferPosition(), size, "UTF-8");
        trans_.consumeBuffer(size);
        return s;
      } catch (UnsupportedEncodingException e) {
        throw new TException("JVM DOES NOT SUPPORT UTF-8");
      }
    }

    return readStringBody(size);
  }

  public String readStringBody(int size) throws TException {
    try {
      byte[] buf = new byte[size];
      trans_.readAll(buf, 0, size);
      return new String(buf, "UTF-8");
    } catch (UnsupportedEncodingException uex) {
      throw new TException("JVM DOES NOT SUPPORT UTF-8");
    }
  }

  public ByteBuffer readBinary() throws TException {
    int size = readI32();

    checkStringReadLength(size);

    if (trans_.getBytesRemainingInBuffer() >= size) {
      ByteBuffer bb = ByteBuffer.wrap(trans_.getBuffer(), trans_.getBufferPosition(), size);
      trans_.consumeBuffer(size);
      return bb;
    }

    byte[] buf = new byte[size];
    trans_.readAll(buf, 0, size);
    return ByteBuffer.wrap(buf);
  }
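The strict/non-strict branching in writeMessageBegin and readMessageBegin is easy to miss in the listing above, so here is a small standalone sketch of the two header layouts, using java.nio.ByteBuffer (big-endian by default) in place of a TTransport. The VERSION_MASK and VERSION_1 values match the constants in the Java TBinaryProtocol; the class and method names (MessageHeaderSketch, writeStrict, writeOld, read) are hypothetical, and only the lenient read path (strictRead_ == false) is modeled.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

public class MessageHeaderSketch {
    // Same values as the constants in Thrift's Java TBinaryProtocol.
    static final int VERSION_MASK = 0xffff0000;
    static final int VERSION_1 = 0x80010000;
    static final byte CALL = 1; // TMessageType.CALL

    // Strict layout: [VERSION_1 | type : i32][name length : i32][name bytes][seqid : i32]
    static byte[] writeStrict(String name, byte type, int seqid) {
        byte[] nameBytes = name.getBytes(StandardCharsets.UTF_8);
        ByteBuffer buf = ByteBuffer.allocate(4 + 4 + nameBytes.length + 4);
        buf.putInt(VERSION_1 | type);
        buf.putInt(nameBytes.length);
        buf.put(nameBytes);
        buf.putInt(seqid);
        return buf.array();
    }

    // Old (non-strict) layout: [name length : i32][name bytes][type : byte][seqid : i32]
    static byte[] writeOld(String name, byte type, int seqid) {
        byte[] nameBytes = name.getBytes(StandardCharsets.UTF_8);
        ByteBuffer buf = ByteBuffer.allocate(4 + nameBytes.length + 1 + 4);
        buf.putInt(nameBytes.length);
        buf.put(nameBytes);
        buf.put(type);
        buf.putInt(seqid);
        return buf.array();
    }

    // Mirrors the branching in readMessageBegin with strictRead_ == false; returns "name/type/seqid".
    static String read(byte[] data) {
        ByteBuffer buf = ByteBuffer.wrap(data);
        int size = buf.getInt();
        String name;
        byte type;
        if (size < 0) {
            // The top bit of VERSION_1 is set, so a versioned header always reads as a negative i32.
            if ((size & VERSION_MASK) != VERSION_1) {
                throw new IllegalStateException("Bad version in readMessageBegin");
            }
            type = (byte) (size & 0x000000ff);
            name = readUtf8(buf, buf.getInt());
        } else {
            // Old format: the first i32 is already the length of the name.
            name = readUtf8(buf, size);
            type = buf.get();
        }
        int seqid = buf.getInt();
        return name + "/" + type + "/" + seqid;
    }

    static String readUtf8(ByteBuffer buf, int length) {
        byte[] bytes = new byte[length];
        buf.get(bytes);
        return new String(bytes, StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        System.out.println(read(writeStrict("ping", CALL, 7))); // ping/1/7
        System.out.println(read(writeOld("ping", CALL, 7)));    // ping/1/7
    }
}
```

Because VERSION_1 has its top bit set, a strict header always decodes as a negative i32, while an old-style header begins with a non-negative name length; that sign test is what lets readMessageBegin accept both formats when strictRead_ is off.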

From the TBinaryProtocol code, we can see how it serializes each data type:

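  1. writeMessageBegin: What is a Message? As the name suggests, a Message in Thrift represents one interface call, the result of a call, or an exception. The implementation shows that a Message begins with its basic information: in strict mode (strictWrite_), a 4-byte word combining the version (VERSION_1) with the message type, followed by the message name and the seqid; in non-strict mode, the message name, a one-byte message type, and the seqid;
  2. writeMessageEnd: end of Message serialization, does nothing;
  3. writeStructBegin: start of struct serialization, does nothing;
  4. writeStructEnd: end of struct serialization, does nothing;
  5. writeFieldBegin: start of a field inside a struct; appends the field type (1 byte) and the field id (2 bytes);
  6. writeFieldEnd: end of a field, does nothing;
  7. writeFieldStop: after all fields are serialized, appends a TType.STOP, i.e. 0 (so in Thrift serialized data, a 0 in the field-type position marks the end of a struct);
  8. writeMapBegin: start of a Map; appends the Map's key type, value type, and size;
  9. writeMapEnd: end of a Map, does nothing;
  10. writeListBegin: start of a List; appends the element type and the List's size;
  11. writeListEnd: end of a List, does nothing;
  12. writeSetBegin: start of a Set; appends the element type and the Set's size;
  13. writeSetEnd: end of a Set, does nothing;
  14. writeBool: appends the value directly as one byte, 1 for true and 0 for false;
  15. writeByte: appends the byte directly;
  16. writeI16: serializes a short by appending it as two bytes, most significant byte first (big-endian);
  17. writeI32: serializes an int by appending it as four bytes, big-endian;
  18. writeI64: serializes a long by appending it as eight bytes, big-endian;
  19. writeDouble: converts the double to its IEEE 754 double-precision ("double format") bit layout with Double.doubleToLongBits, then appends that bit pattern as an eight-byte long;
  20. writeString: appends the length in bytes of the string's UTF-8 encoding (as an i32), then those bytes;
  21. writeBinary: appends the length of the data (as an i32), then the data itself;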
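Putting several of these rules together, the following sketch hand-encodes a small struct in the layout TBinaryProtocol produces, using java.io.DataOutputStream, which also writes multi-byte integers in big-endian order. The struct (User with i32 id = 42 and string name = "hi") and the class name are hypothetical, and no Thrift library is involved.

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;

public class BinaryEncodingSketch {
    static final byte STOP = 0, I32 = 8, STRING = 11; // values from TType

    // Hand-encodes the hypothetical struct User { 1: i32 id, 2: string name }.
    static byte[] encodeUser(int id, String name) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) { // big-endian, like writeI16/writeI32
            out.writeByte(I32);        // writeFieldBegin: field type (1 byte)
            out.writeShort(1);         //                  field id (2 bytes)
            out.writeInt(id);          // writeI32: 4 bytes
            out.writeByte(STRING);
            out.writeShort(2);
            byte[] utf8 = name.getBytes(StandardCharsets.UTF_8);
            out.writeInt(utf8.length); // writeString: byte length as i32 ...
            out.write(utf8);           // ... followed by the UTF-8 bytes
            out.writeByte(STOP);       // writeFieldStop
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    static String hex(byte[] data) {
        StringBuilder sb = new StringBuilder();
        for (byte b : data) {
            if (sb.length() > 0) sb.append(' ');
            sb.append(String.format("%02x", b & 0xff));
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(hex(encodeUser(42, "hi")));
        // prints: 08 00 01 00 00 00 2a 0b 00 02 00 00 00 02 68 69 00
    }
}
```

Decoding walks the same bytes in order: read the type byte, stop if it is 0 (TType.STOP), otherwise read the 2-byte field id and then the value, which is what readFieldBegin and the per-type read methods do.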

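The read methods are simply the inverse of the write methods above, so I won't repeat them. It is worth pointing out, however, that the read methods include some checks on data length, for example: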

  public TMap readMapBegin() throws TException {
    TMap map = new TMap(readByte(), readByte(), readI32());
    checkContainerReadLength(map.size);
    return map;
  }

When deserializing a map, it first reads the map's size and then calls checkContainerReadLength to check whether that size is valid:

  private void checkContainerReadLength(int length) throws TProtocolException {
    if (length < 0) {
      throw new TProtocolException(TProtocolException.NEGATIVE_SIZE,
                                   "Negative length: " + length);
    }
    if (containerLengthLimit_ != NO_LENGTH_LIMIT && length > containerLengthLimit_) {
      throw new TProtocolException(TProtocolException.SIZE_LIMIT,
                                   "Length exceeded max allowed: " + length);
    }
  }

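If the size is negative or larger than the configured limit (containerLengthLimit_), an exception is thrown. As I recall, much older versions (around 0.5.0, if memory serves) did not have this check, which could lead to the following situation: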

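When deserializing, if the source data has been corrupted, the size that is read may be a huge number. Without the check above, the program would allocate memory based on that size, which can cause an OOM and terminate the program abnormally!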

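Similarly, readListBegin and readSetBegin perform the same check, and readString and readBinary perform an analogous one through checkStringReadLength.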
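A minimal sketch of why the size must be validated before anything is allocated: readI32List decodes a list&lt;i32&gt; header and runs the same two checks as checkContainerReadLength before calling new int[size]. It uses a plain ByteBuffer in place of a TTransport and throws IllegalArgumentException instead of TProtocolException; LengthCheckSketch, readI32List and the local NO_LENGTH_LIMIT constant are hypothetical names for illustration.

```java
import java.nio.ByteBuffer;

public class LengthCheckSketch {
    static final int NO_LENGTH_LIMIT = -1; // hypothetical local constant meaning "no limit"
    static final byte I32 = 8;             // TType.I32

    // The same two checks as checkContainerReadLength: reject negative sizes and sizes over the limit.
    static void checkContainerReadLength(int length, int limit) {
        if (length < 0) {
            throw new IllegalArgumentException("Negative length: " + length);
        }
        if (limit != NO_LENGTH_LIMIT && length > limit) {
            throw new IllegalArgumentException("Length exceeded max allowed: " + length);
        }
    }

    // Reads a list<i32> in TBinaryProtocol layout: [elemType : byte][size : i32][size x i32].
    static int[] readI32List(ByteBuffer buf, int limit) {
        byte elemType = buf.get();
        if (elemType != I32) {
            throw new IllegalArgumentException("Unexpected element type: " + elemType);
        }
        int size = buf.getInt();
        checkContainerReadLength(size, limit); // validate the size BEFORE allocating anything
        int[] values = new int[size];
        for (int i = 0; i < size; i++) {
            values[i] = buf.getInt();
        }
        return values;
    }

    public static void main(String[] args) {
        // A corrupted size of 0x7fffffff would otherwise request an ~8 GB int array.
        ByteBuffer corrupted = ByteBuffer.allocate(5);
        corrupted.put(I32).putInt(Integer.MAX_VALUE);
        corrupted.flip();
        try {
            readI32List(corrupted, 10_000);
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

The key design point is ordering: the check runs on the raw size before the allocation, so a corrupted header fails fast with a protocol error instead of an OutOfMemoryError.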

Conclusion

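The above uses TBinaryProtocol as an example to walk through one of Thrift's serialization protocols in detail; for the other protocols, feel free to read through the code yourself.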
