在微调GPT/BERT模型时,会经常遇到“ cuda out of memory”的情况。这是因为transformer是内存密集型的模型,并且内存要求也随序列长度而增加。所以如果能对模型的内存要求进行粗略的估计将有助于估计任务所需的资源。
如果你想直接看结果,可以跳到本文最后。不过在阅读本文前请记住所有神经网络都是通过反向传播的方法进行训练的, 这一点对于我们计算内存的占用十分重要。
total_memory = memory_modal + memory_activations + memory_gradients
这里的memory_modal是指存储模型所有参数所需的内存。memory_activations是计算并存储在正向传播中的中间变量,在计算梯度时需要使用这些变量。因为模型中梯度的数量通常等于中间变量的数量,所以memory_activations= memory_gradients。因此可以写成:
total_memory = memory_modal + 2 * memory_activations
所以我们计算总体内存的要求时只需要找到memory_modal和memory_activations就可以了。
下面我们以GPT为例。GPT由许多transformer块组成(后面我用n_tr_blocks表示其数量)。每个transformer块都包含以下结构:
multi_headed_attention --> layer_normalization --> MLP -->layer_normalization
每个multi_headed_attention元素都由键,值和查询组成。其中包括n_head个注意力头和dim个维度。MLP是包含有n_head * dim的尺寸。这些权重都是要占用内存的,那么
memory_modal = memory of multi_headed_attention + memory of MLP
= memory of value + memory of key + memory of query + memory of MLP
= square_of(n_head * dim) + square_of(n_head * dim) + square_of(n_head * dim) + square_of(n_head * dim)
= 4*square_of(n_head * dim)
因为我们的模型包含了n个单元。所以最后内存就变为:
memory_modal = 4*n_tr_blocks*square_of(n_head * dim)
上面的估算没有考虑到偏差所需的内存,因为这大部分是静态的,不依赖于批大小、输入序列等。
多头注意力通常使用softmax,可以写成:
multi_headed_attention = softmax(query * key * sequence_length) * value
k,q,v的维度是:
[batch_size, n_head, sequence_length, dim]
multi_headed_attention操作会得出如下形状:
[batch_size, n_head, sequence_length, sequence_length]
所以最终得内存为:
memory_softmax = batch_size * n_head * square_of(sequence_length)
q* k * sequence_length操作乘以value的形状为[batch_size, n_head, sequence_length, dim]。MLP也有相同的维度:
memory of MLP = batch_size * n_head * sequence_length * dim
memory of value = batch_size * n_head * sequence_length * dim
我们把上面的整合在一起,单个transformer的中间变量为:
memory_activations = memory_softmax + memory_value + memory_MLP
= batch_size * n_head * square_of(sequence_length)
+ batch_size * n_head * sequence_length * dim
+ batch_size * n_head * sequence_length * dim
= batch_size * n_head * sequence_length * (sequence_length + 2*dim)
再乘以块的数量,模型所有的memory_activations就是:
n_tr_blocks * (batch_size * n_head * sequence_length * (sequence_length + 2*dim))
我们把上面两个公式进行归纳总结,想看结果的话直接看这里就行了。transformer模型所需的总内存为:
total_memory = memory_modal + 2 * memory_activations
模型参数的内存:
4*n_tr_blocks*square_of(n_head * dim)
中间变量内存:
n_tr_blocks * (batch_size * n_head * sequence_length * (sequence_length + 2*dim))
我们使用下面的符号可以更简洁地写出这些公式。
R = n_tr_blocks = transformer层堆叠的数量
N = n_head = 注意力头数量
D = dim = 注意力头的维度
B = batch_size = 批大小
S = sequence_length =输入序列的长度
memory modal = 4 * R * N^2 * D^2
memory activations = RBNS(S + 2D)
所以在训练模型时总的内存占用为:
M = (4 * R * N^2 * D^2) + RBNS(S + 2D)
因为内存的占用和序列长度又很大的关系,如果有一个很长的序列长度S >> D S + 2D <——> S,这时可以将计算变为:
M = (4 * R * N^2 * D^2) + RBNS(S) = 4*R*N^2*D^2 + RBNS^2
可以看到对于较大的序列,M与输入序列长度的平方成正比,与批大小成线性比例,这也就证明了序列长度和内存占用有很大的关系。
所以最终的内存占用的评估为:
总内存 = ((4 * R * N^2 * D^2) + RBNS(S + 2D)) * float64(以字节为单位)
https://avoid.overfit.cn/post/6724eec842b740d482f73386b1b8b012
作者:Schartz Rehan