复杂的融合算子训练pass自动化的探讨

在 flash attention带来速度提升的同时,我们发现其训练过程必须重新来写,打破了pytorch 和 tensorflow 等引以为豪的自动化求导的机制,而必须加入新的求导算子;

那么,每次出现融合算子时,都需要向 pytorch中添加新的算子才能工作;

复杂融合算子的自动求导是否可能呢?

接下来做一些探讨

未完待续 ... ...

你可能感兴趣的:(自动求导,pytorch自动求导,flash-attention)