Pytorch amp(混合精度)的bfloat16和float16

bfloat16 格式使用 16 位表示浮点数,其中 1 位用于符号,8 位用于指数,7 位用于尾数。

float16 格式使用 16 位表示浮点数,其中 1 位用于符号,5 位用于指数,10 位用于尾数。

bfloat16 的表示范围比 float16 更广,但是精度更低

在训练大规模的CLIP的时候,用混合精度的float16,训练时会出现nan,换成bfloat16就可以解决,出自《Reproducible scaling laws for contrastive language-image learning》

你可能感兴趣的:(pytorch)