ArrowColumnVector in Spark

ArrowColumnVector extends ColumnVector directly. ColumnVector exposes only get methods, so an ArrowColumnVector is a read-only view over an Arrow ValueVector.
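
For reference, the read-only surface that ColumnVector imposes looks roughly like the abridged sketch below; the method names match org.apache.spark.sql.vectorized.ColumnVector, but the body is trimmed and the full class declares more getters.

// Abridged sketch of ColumnVector's read-only API; there are no setters at all.
public abstract class ColumnVector implements AutoCloseable {
  protected final DataType type;

  protected ColumnVector(DataType type) {
    this.type = type;
  }

  // The Spark SQL type this column holds.
  public final DataType dataType() { return type; }

  public abstract boolean hasNull();
  public abstract int numNulls();
  public abstract boolean isNullAt(int rowId);
  public abstract int getInt(int rowId);
  public abstract UTF8String getUTF8String(int rowId);
  // ... getBoolean/getByte/getShort/getLong/getFloat/getDouble/
  // getDecimal/getBinary/getArray/getMap follow the same per-row shape.
}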

Constructor
public ArrowColumnVector(ValueVector vector) {
  // Derive the Spark DataType from the Arrow field, then wire up a typed accessor.
  this(ArrowUtils.fromArrowField(vector.getField()));
  initAccessor(vector);
}

ArrowColumnVector(DataType type) {
  super(type);
}
// Pick the accessor that matches the concrete Arrow vector class.
void initAccessor(ValueVector vector) {
    if (vector instanceof BitVector) {
      accessor = new BooleanAccessor((BitVector) vector);
    } else if (vector instanceof TinyIntVector) {
      accessor = new ByteAccessor((TinyIntVector) vector);
    } else if (vector instanceof SmallIntVector) {
      accessor = new ShortAccessor((SmallIntVector) vector);
    } else if (vector instanceof IntVector) {
      accessor = new IntAccessor((IntVector) vector);
    } else if (vector instanceof BigIntVector) {
      accessor = new LongAccessor((BigIntVector) vector);
    } else if (vector instanceof Float4Vector) {
      accessor = new FloatAccessor((Float4Vector) vector);
    } else if (vector instanceof Float8Vector) {
      accessor = new DoubleAccessor((Float8Vector) vector);
    } else if (vector instanceof DecimalVector) {
      accessor = new DecimalAccessor((DecimalVector) vector);
    } else if (vector instanceof VarCharVector) {
      accessor = new StringAccessor((VarCharVector) vector);
    } else if (vector instanceof VarBinaryVector) {
      accessor = new BinaryAccessor((VarBinaryVector) vector);
    } else if (vector instanceof DateDayVector) {
      accessor = new DateAccessor((DateDayVector) vector);
    } else if (vector instanceof TimeStampMicroTZVector) {
      accessor = new TimestampAccessor((TimeStampMicroTZVector) vector);
    } else if (vector instanceof TimeStampMicroVector) {
      accessor = new TimestampNTZAccessor((TimeStampMicroVector) vector);
    } else if (vector instanceof MapVector) {
      MapVector mapVector = (MapVector) vector;
      accessor = new MapAccessor(mapVector);
    } else if (vector instanceof ListVector) {
      ListVector listVector = (ListVector) vector;
      accessor = new ArrayAccessor(listVector);
    } else if (vector instanceof StructVector) {
      StructVector structVector = (StructVector) vector;
      accessor = new StructAccessor(structVector);

      childColumns = new ArrowColumnVector[structVector.size()];
      for (int i = 0; i < childColumns.length; ++i) {
        childColumns[i] = new ArrowColumnVector(structVector.getVectorById(i));
      }
    } else if (vector instanceof NullVector) {
      accessor = new NullAccessor((NullVector) vector);
    } else if (vector instanceof IntervalYearVector) {
      accessor = new IntervalYearAccessor((IntervalYearVector) vector);
    } else if (vector instanceof DurationVector) {
      accessor = new DurationAccessor((DurationVector) vector);
    } else {
      throw new UnsupportedOperationException();
    }
  }
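
Each accessor subclass pins the concrete Arrow vector type and overrides the single getter it supports. A sketch of IntAccessor, following the pattern used in ArrowColumnVector.java:

// Sketch: holding the vector under its concrete type lets get(rowId)
// run without casts on the hot path.
private static class IntAccessor extends ArrowVectorAccessor {

  private final IntVector accessor;

  IntAccessor(IntVector vector) {
    super(vector);
    this.accessor = vector;
  }

  @Override
  final int getInt(int rowId) {
    return accessor.get(rowId);
  }
}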
ArrowVectorAccessor

ArrowVectorAccessor wraps ValueVector behind one umbrella interface. The base class itself does little: every typed getter defaults to throwing UnsupportedOperationException, and each subclass overrides only the getter that matches its vector type.

  abstract static class ArrowVectorAccessor {

    final ValueVector vector;

    ArrowVectorAccessor(ValueVector vector) {
      this.vector = vector;
    }

    // TODO: should be final after removing ArrayAccessor workaround
    boolean isNullAt(int rowId) {
      return vector.isNull(rowId);
    }

    final int getNullCount() {
      return vector.getNullCount();
    }

    final void close() {
      vector.close();
    }

    boolean getBoolean(int rowId) {
      throw new UnsupportedOperationException();
    }

    byte getByte(int rowId) {
      throw new UnsupportedOperationException();
    }

    short getShort(int rowId) {
      throw new UnsupportedOperationException();
    }

    int getInt(int rowId) {
      throw new UnsupportedOperationException();
    }

    long getLong(int rowId) {
      throw new UnsupportedOperationException();
    }

    float getFloat(int rowId) {
      throw new UnsupportedOperationException();
    }

    double getDouble(int rowId) {
      throw new UnsupportedOperationException();
    }

    Decimal getDecimal(int rowId, int precision, int scale) {
      throw new UnsupportedOperationException();
    }

    UTF8String getUTF8String(int rowId) {
      throw new UnsupportedOperationException();
    }

    byte[] getBinary(int rowId) {
      throw new UnsupportedOperationException();
    }

    ColumnarArray getArray(int rowId) {
      throw new UnsupportedOperationException();
    }

    ColumnarMap getMap(int rowId) {
      throw new UnsupportedOperationException();
    }
  }
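
ArrowColumnVector's own public getters then reduce to one-line delegations to the accessor; roughly:

// Sketch of the delegation layer in ArrowColumnVector itself.
@Override
public boolean hasNull() {
  return accessor.getNullCount() > 0;
}

@Override
public int numNulls() {
  return accessor.getNullCount();
}

@Override
public int getInt(int rowId) {
  return accessor.getInt(rowId);
}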
Unit tests
  test("int") {
    val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue)
    val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true, null)
      .createVector(allocator).asInstanceOf[IntVector]
    vector.allocateNew()

    (0 until 10).foreach { i =>
      vector.setSafe(i, i)
    }
    vector.setNull(10)
    vector.setValueCount(11)

    val columnVector = new ArrowColumnVector(vector)
    assert(columnVector.dataType === IntegerType)
    assert(columnVector.hasNull)
    assert(columnVector.numNulls === 1)

    (0 until 10).foreach { i =>
      assert(columnVector.getInt(i) === i)
    }
    assert(columnVector.isNullAt(10))

    assert(columnVector.getInts(0, 10) === (0 until 10))

    columnVector.close()
    allocator.close()
  }
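
Note that getInts(0, 10) above is not an accessor method. ColumnVector ships batch getters as default implementations that simply loop over the per-row getter; a minimal sketch of that default:

// Sketch of ColumnVector's default bulk reader. ArrowColumnVector inherits
// this loop; a subclass could override it with a vectorized copy.
public int[] getInts(int rowId, int count) {
  int[] res = new int[count];
  for (int i = 0; i < count; i++) {
    if (!isNullAt(rowId + i)) {  // null slots keep the default 0
      res[i] = getInt(rowId + i);
    }
  }
  return res;
}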
  test("array") {
    val allocator = ArrowUtils.rootAllocator.newChildAllocator("array", 0, Long.MaxValue)
    val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true, null)
      .createVector(allocator).asInstanceOf[ListVector]
    vector.allocateNew()
    val elementVector = vector.getDataVector().asInstanceOf[IntVector]

    // [1, 2]
    vector.startNewValue(0)
    elementVector.setSafe(0, 1)
    elementVector.setSafe(1, 2)
    vector.endValue(0, 2)

    // [3, null, 5]
    vector.startNewValue(1)
    elementVector.setSafe(2, 3)
    elementVector.setNull(3)
    elementVector.setSafe(4, 5)
    vector.endValue(1, 3)

    // row 2: never started, so the entry stays null

    // []
    vector.startNewValue(3)
    vector.endValue(3, 0)

    elementVector.setValueCount(5)
    vector.setValueCount(4)

    val columnVector = new ArrowColumnVector(vector)
    assert(columnVector.dataType === ArrayType(IntegerType))
    assert(columnVector.hasNull)
    assert(columnVector.numNulls === 1)

    val array0 = columnVector.getArray(0)
    assert(array0.numElements() === 2)
    assert(array0.getInt(0) === 1)
    assert(array0.getInt(1) === 2)

    val array1 = columnVector.getArray(1)
    assert(array1.numElements() === 3)
    assert(array1.getInt(0) === 3)
    assert(array1.isNullAt(1))
    assert(array1.getInt(2) === 5)

    assert(columnVector.isNullAt(2))

    val array3 = columnVector.getArray(3)
    assert(array3.numElements() === 0)

    columnVector.close()
    allocator.close()
  }
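
ArrayAccessor shows why nested types need more than a plain getter: getArray reads the list offsets from the ListVector and returns a ColumnarArray view over one shared child ArrowColumnVector, without copying elements. A sketch following the pattern in ArrowColumnVector.java:

// Sketch: wrap the child data vector once, then serve each row as an
// offset/length view into it.
private static class ArrayAccessor extends ArrowVectorAccessor {

  private final ListVector accessor;
  private final ArrowColumnVector arrayData;

  ArrayAccessor(ListVector vector) {
    super(vector);
    this.accessor = vector;
    this.arrayData = new ArrowColumnVector(vector.getDataVector());
  }

  @Override
  final ColumnarArray getArray(int rowId) {
    int start = accessor.getElementStartIndex(rowId);
    int end = accessor.getElementEndIndex(rowId);
    return new ColumnarArray(arrayData, start, end - start);
  }
}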
  test("create columnar batch from Arrow column vectors") {
    val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue)
    val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true, null)
      .createVector(allocator).asInstanceOf[IntVector]
    vector1.allocateNew()
    val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true, null)
      .createVector(allocator).asInstanceOf[IntVector]
    vector2.allocateNew()

    (0 until 10).foreach { i =>
      vector1.setSafe(i, i)
      vector2.setSafe(i + 1, i)
    }
    vector1.setNull(10)
    vector1.setValueCount(11)
    vector2.setNull(0)
    vector2.setValueCount(11)

    val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2))

    val schema = StructType(Seq(StructField("int1", IntegerType), StructField("int2", IntegerType)))
    val batch = new ColumnarBatch(columnVectors.toArray)
    batch.setNumRows(11)

    assert(batch.numCols() == 2)
    assert(batch.numRows() == 11)

    val rowIter = batch.rowIterator().asScala
    rowIter.zipWithIndex.foreach { case (row, i) =>
      if (i == 10) {
        assert(row.isNullAt(0))
      } else {
        assert(row.getInt(0) == i)
      }
      if (i == 0) {
        assert(row.isNullAt(1))
      } else {
        assert(row.getInt(1) == i - 1)
      }
    }

    batch.close()
    allocator.close()
  }
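
Note that ColumnarBatch copies nothing: it only wraps the ColumnVector array, and batch.close() closes each ArrowColumnVector, which in turn releases the underlying Arrow buffers via ValueVector.close(). That is why the batch is closed before the allocator in the tests above.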
