JAX 记录

测试官方sample里的resnet50,用的机器是单卡v100.

batch size 设置为32.
先测试了一下对update函数默认带了@jit的,也就是开启了XLA Jit优化。
执行时间大概是0.16s/step

然后关闭jit发现时间变成了2~3s/step。两者差异巨大。

于是用nvprof profile了一下。
使用jit的情况:
GPU kernel 执行情况:


image.png

API call 情况:


image.png

关闭jit的情况:

GPU kernel 情况:


image.png

API call 情况:


image.png

你可能感兴趣的:(JAX 记录)