half(fp16)类型转float(fp32)类型的简单实现

half和float的数据格式

half (fp16)

half(fp16)类型转float(fp32)类型的简单实现_第1张图片

组成:符号位 1 bit + 指数位 5 bits + 小数位 10 bits。

指数位的表示范围是[2^-14, 2^15]。

float (fp32)

half(fp16)类型转float(fp32)类型的简单实现_第2张图片

组成:符号位 1 bit + 指数位 8 bits + 小数位 23 bits。

指数位的表示范围是[2^-126, 2^127]。

算法原理

只要half类型的值不是nan,则转换为float后,符号位不变。

特殊值处理

当half类型的绝对值为0.0 (0x0000)时

转化为float后绝对值依然为0.0 (0x00000000)。

当half类型的绝对值为INF (0x7C00)时

转化为float后绝对值为INF (0x7F800000)。

当half类型的绝对值大于INF (0x7C00)时

此时数值为NAN,转化为float后统一为NAN (0x7FC00000)。

其他情况

当half类型的指数位值不为0时

half(fp16)类型转float(fp32)类型的简单实现_第3张图片

指数位转换:当表示的指数为2^0=1时,half类型的指数位为0x10,float类型指数位为0x80。因此在half转float后,指数位需要加0x70。

小数位转换:由于half类型的小数位长度是 10 bits,float类型的小数位长度为 23 bits。因此在half转float后,小数位需要左移 13 bits。

当half类型的指数位值为0时

注意,当浮点数的指数位为0时,小数位表示的数值是 0.xxxxx,而当指数位不为0时,小数位表示的数值位 1.xxxxx。

由于half类型指数位最小(0x0)能表示的是2^-14,因此转化为float后指数位必然不为0。

因此小数位除了要左移13位之外,还需要继续左移至最高为的1省略掉,同时指数位也要减去 (额外的左移值 - 1)。

half(fp16)类型转float(fp32)类型的简单实现_第4张图片

指数位转换:half转float后,指数位需要加0x70,再减去 (额外的左移值 - 1)。如图所示,当half转换为float时,小数位共左移了 (13 + 4) bits,因此float类型的指数位值为 0 + 0x70 - (4 - 1) = 0x6D。

小数位转换:首先要确认half类型数值小数位的最高位1的位置,在转换为float后,需要左移直至该最高位的1被省去。如图所示,half小数位最高位的1位于第7 bit位,因此小数位需要左移 (13 + (10 - 6)) = (13 + 4) bits。

示例代码

float f16_to_f32(half __x) {
  unsigned short n = *((unsigned short *)&__x);
  unsigned int x = (unsigned int)n;
  x = x & 0xffff;
  unsigned int sign = x & 0x8000;                   //符号位
  unsigned int exponent_f16 = (x & 0x7c00) >> 10;   //half指数位
  unsigned int mantissa_f16 = x & 0x03ff;           //half小数位
  unsigned int y = sign << 16;
  unsigned int exponent_f32;                        //float指数位
  unsigned int mantissa_f32;                        //float小数位
  unsigned int first_1_pos = 0;                     //(half小数位)最高位1的位置
  unsigned int mask;
  unsigned int hx;
 
  hx = x & 0x7fff;
 
  if (hx == 0) {
    return *((float *)&y);
  }
  if (hx == 0x7c00) {
    y |= 0x7f800000;
    return *((float *)&y);
  }
  if (hx > 0x7c00) {
    y = 0x7fc00000;
    return *((float *)&y);
  }
 
  exponent_f32 = 0x70 + exponent_f16;
  mantissa_f32 = mantissa_f16 << 13;
 
  for (first_1_pos = 0; first_1_pos < 10; first_1_pos++) {
    if ((mantissa_f16 >> (first_1_pos + 1)) == 0) {
      break;
    }
  }
 
  if (exponent_f16 == 0) {
    mask = (1 << 23) - 1;
    exponent_f32 = exponent_f32 - (10 - first_1_pos) + 1;
    mantissa_f32 = mantissa_f32 << (10 - first_1_pos);
    mantissa_f32 = mantissa_f32 & mask;
  }
 
  y = y | (exponent_f32 << 23) | mantissa_f32;

  return *((float *)&y);
}

推荐:float-toy

你可能感兴趣的:(算法,数据结构)