11-16 周四 简单代码理解FlashAttention 分块计算softmax

下面的代码以一个 2×3 的数组为例进行演示:把每一行看作一个分块,先分别计算各块的 softmax 中间量,再合并得到整体的 softmax。

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import numpy as np


# Demo: softmax of a 6-element vector computed from two 3-element blocks.
# Each row of input_array is one block.
input_array = np.array([[1, 2, 3], [4, 9, 6]])

print("np.e:", np.e)
print("1/np.e:", 1/np.e)

# Per-block maximum, kept 2-D so it broadcasts against the rows.
max_values = np.max(input_array, axis=1, keepdims=True)
m1 = max_values[0]
m2 = max_values[1]
print(f"m1={m1}, m2={m2}")

# Shift every block by its own maximum before exponentiating. The shifted
# values go into a new name so input_array keeps the original data.
shifted = input_array - max_values
exp_values = np.exp(shifted)

f1 = exp_values[0]
f2 = exp_values[1]
# Per-block sum of exponentials (named row_sums rather than rebinding the
# builtin `max`).
row_sums = np.sum(exp_values, axis=1, keepdims=True)
sum1 = row_sums[0]
sum2 = row_sums[1]

print(f"f1={f1}\nf2={f2}")
print(f"sum1={sum1}, sum2={sum2}")
print("f1/sum1=", f1 / sum1)
print("f2/sum2=", f2 / sum2)
# Ordinary softmax of each block on its own.
softmax_output = exp_values / row_sums
print("sum: ", row_sums)

print("基础softmax_output:", softmax_output)

# Merge the two blocks. Each block was normalised against its own maximum,
# so rescale both to the shared maximum m_new. The factors are derived from
# m1 and m2 instead of being hard-coded; for this input they are exp(-6) and 1.
m_new = np.maximum(m1, m2)
scale1 = np.exp(m1 - m_new)
scale2 = np.exp(m2 - m_new)

# Global denominator: rescaled sum of the per-block sums.
L = scale1*sum1 + scale2*sum2

print(f"L={L}")

# Global numerators: rescaled per-block exponentials, laid end to end.
con = np.concatenate((scale1*f1, scale2*f2), axis=0)

print("con: ", con)
print("result:", con / L)


def softmax(input_array):
    """Return the softmax of ``input_array`` taken along axis 1.

    The maximum along axis 1 is subtracted before exponentiating; this
    leaves the result unchanged mathematically while keeping the
    exponentials from overflowing.
    """
    shifted = input_array - np.max(input_array, axis=1, keepdims=True)
    weights = np.exp(shifted)
    # Normalise so the values along axis 1 sum to one.
    return weights / np.sum(weights, axis=1, keepdims=True)


# Reference value: softmax over all six numbers as a single row.
whole_vector = np.array([[1, 2, 3, 4, 9, 6]])
print("直接计算: ", softmax(whole_vector))

 上述代码将张量分成两块分别计算,再用各块的最大值对中间结果重新缩放后合并。最后可以看到,采用逐步累加的方式得到的结果与直接计算的结果是相同的。

下面进一步优化程序,使其可以处理任意数量的分块,更加灵活。

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import numpy as np


# Demo: softmax of the flattened input, accumulated one row (block) at a time.
input_array = np.array([[1, 2, 3, 1], [4, 9, 6, 10], [5, 10, 7, 6], [100, 13, 20, 30]])

print("np.e:", np.e)
print("1/np.e:", 1/np.e)

# Running state across blocks:
#   m      - largest value seen so far (None until the first block arrives)
#   L      - sum of exp(x - m) over every value seen so far
#   result - exp(x - m) for every value seen so far, in input order
# Initialising L and result here keeps the final division well defined even
# when input_array has no rows.
m = None
L = 0.0
result = np.empty(0)

for arr in input_array:
    print(arr)
    if m is None:
        # First block: nothing to merge with, so it defines the state directly.
        m = np.max(arr)
        print(f"m更新为{m}")
        print("arr-m:", arr-m)
        result = np.exp(arr - m)
        L = np.sum(result)
        print("result:", result)
        continue

    # Local statistics of the incoming block, relative to its own maximum.
    m2 = np.max(arr)
    print(f"m2={m2}")
    print(f"arr-m2={arr-m2}")
    temp2 = np.exp(arr - m2)

    L2 = np.sum(temp2)
    print(f"temp2={temp2}")
    # np.maximum instead of the builtin max(): the builtin name may have been
    # rebound earlier in the same file.
    m_new = np.maximum(m, m2)
    print(f"L2 = {L2}")
    # Both the accumulated state and the new block must be rescaled to the
    # shared maximum m_new before they can be added or concatenated.
    scale_old = np.exp(m - m_new)
    scale_block = np.exp(m2 - m_new)
    L = scale_old * L + scale_block * L2
    print(f"L={L}")
    print(f"m-m_new: {m-m_new}, m2-m_new: {m2-m_new}")
    result = np.concatenate((scale_old*result, scale_block*temp2))

    print(f"result={result}")
    m = m_new
    print(f"m更新为: {m}")

# Dividing the accumulated numerators by the accumulated denominator gives
# the softmax over all values.
print(f"结果为: {result/L}")


def softmax(input_array):
    """Compute softmax along axis 1 of ``input_array``.

    Each slice along axis 1 is shifted by its own maximum before the
    exponential is taken, then divided by the sum of its exponentials.
    """
    # np.max (the function, not a method) so non-ndarray inputs are accepted.
    peak = np.max(input_array, axis=1, keepdims=True)
    numerators = np.exp(input_array - peak)
    denominators = numerators.sum(axis=1, keepdims=True)
    return numerators / denominators


# Flatten the input so it can be shown and fed to softmax as one vector.
flat_values = input_array.reshape(-1)
print(flat_values)

# softmax reduces along axis 1, so the flat vector is wrapped as a single row.
print("直接计算: ", softmax([flat_values]))

# Two small stand-alone rows for reference.
first_block = np.array([[1, 2, 3]])
second_block = np.array([[4, 9, 6]])
print("直接计算[1, 2, 3]: ", softmax(first_block))
print("直接计算[4, 9, 6]: ", softmax(second_block))

 运算之后可以得到输出如下:

python softmax.py
np.e: 2.718281828459045
1/np.e: 0.36787944117144233
[1 2 3 1]
m更新为3
arr-m: [-2 -1  0 -2]
result: [ 0.13533528  0.36787944  1.          0.13533528]
[ 4  9  6 10]
m2=10
arr-m2=[-6 -1 -4  0]
temp2=[ 0.00247875  0.36787944  0.01831564  1.        ]
L2 = 1.3886738322368428
L=1.3901679964384732
m-m_new: -7, m2-m_new: 0
result=[  1.23409804e-04   3.35462628e-04   9.11881966e-04   1.23409804e-04
   2.47875218e-03   3.67879441e-01   1.83156389e-02   1.00000000e+00]
m更新为: 10
[ 5 10  7  6]
m2=10
arr-m2=[-5  0 -3 -4]
temp2=[ 0.00673795  1.          0.04978707  0.01831564]
L2 = 1.0748406542556834
L=2.4650086506941564
m-m_new: 0, m2-m_new: 0
result=[  1.23409804e-04   3.35462628e-04   9.11881966e-04   1.23409804e-04
   2.47875218e-03   3.67879441e-01   1.83156389e-02   1.00000000e+00
   6.73794700e-03   1.00000000e+00   4.97870684e-02   1.83156389e-02]
m更新为: 10
[100  13  20  30]
m2=100
arr-m2=[  0 -87 -80 -70]
temp2=[  1.00000000e+00   1.64581143e-38   1.80485139e-35   3.97544974e-31]
L2 = 1.0
L=1.0
m-m_new: -90, m2-m_new: 0
result=[  1.01122149e-43   2.74878501e-43   7.47197234e-43   1.01122149e-43
   2.03109266e-42   3.01440879e-40   1.50078576e-41   8.19401262e-40
   5.52108228e-42   8.19401262e-40   4.07955867e-41   1.50078576e-41
   1.00000000e+00   1.64581143e-38   1.80485139e-35   3.97544974e-31]
m更新为: 100
结果为: [  1.01122149e-43   2.74878501e-43   7.47197234e-43   1.01122149e-43
   2.03109266e-42   3.01440879e-40   1.50078576e-41   8.19401262e-40
   5.52108228e-42   8.19401262e-40   4.07955867e-41   1.50078576e-41
   1.00000000e+00   1.64581143e-38   1.80485139e-35   3.97544974e-31]
[  1   2   3   1   4   9   6  10   5  10   7   6 100  13  20  30]
直接计算:  [[  1.01122149e-43   2.74878501e-43   7.47197234e-43   1.01122149e-43
    2.03109266e-42   3.01440879e-40   1.50078576e-41   8.19401262e-40
    5.52108228e-42   8.19401262e-40   4.07955867e-41   1.50078576e-41
    1.00000000e+00   1.64581143e-38   1.80485139e-35   3.97544974e-31]]
直接计算[1, 2, 3]:  [[ 0.09003057  0.24472847  0.66524096]]
直接计算[4, 9, 6]:  [[ 0.00637746  0.94649912  0.04712342]]

 从上述的日志中,可以看到逐步累计的计算结果:

结果为: [  1.01122149e-43   2.74878501e-43   7.47197234e-43   1.01122149e-43
   2.03109266e-42   3.01440879e-40   1.50078576e-41   8.19401262e-40
   5.52108228e-42   8.19401262e-40   4.07955867e-41   1.50078576e-41
   1.00000000e+00   1.64581143e-38   1.80485139e-35   3.97544974e-31]

而直接计算的结果为:

[[  1.01122149e-43   2.74878501e-43   7.47197234e-43   1.01122149e-43
 2.03109266e-42   3.01440879e-40   1.50078576e-41   8.19401262e-40
 5.52108228e-42   8.19401262e-40   4.07955867e-41   1.50078576e-41
 1.00000000e+00   1.64581143e-38   1.80485139e-35   3.97544974e-31]]

 两者完全一致,因此验证了分块逐步累计得到的 softmax 与直接计算的结果相同,这也是 FlashAttention 能够分块而又精确地计算注意力的基础。

你可能感兴趣的:(python,机器学习)