下面的代码以一个 2×3 的矩阵为例进行演示:把两行看作同一个向量的两个分块,先分别计算每块的 softmax 统计量,再合并得到整个向量的 softmax。
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import numpy as np
# Input: a 2x3 matrix. Each row is treated as one chunk of the single
# 6-element vector [1, 2, 3, 4, 9, 6]; the demo shows that the softmax of the
# whole vector can be assembled from per-chunk statistics.
input_array = np.array([[1, 2, 3], [4, 9, 6]])
print("np.e:", np.e)
print("1/np.e:", 1/np.e)

# Per-row (per-chunk) maximum, kept 2-D so it broadcasts against the rows.
max_values = np.max(input_array, axis=1, keepdims=True)
m1 = max_values[0]
m2 = max_values[1]
print(f"m1={m1}, m2={m2}")

# Subtract each row's own maximum so np.exp cannot overflow.
input_array = input_array - max_values

# Per-row exponentials (the unnormalised softmax of each chunk).
exp_values = np.exp(input_array)
f1 = exp_values[0]
f2 = exp_values[1]
# Named row_sums (not `max`) so the builtin max() is not shadowed and the
# name matches what the value actually is.
row_sums = np.sum(exp_values, axis=1, keepdims=True)
sum1 = row_sums[0]
sum2 = row_sums[1]
print(f"f1={f1}\nf2={f2}")
print(f"sum1={sum1}, sum2={sum2}")
print("f1/sum1=", f1 / sum1)
print("f2/sum2=", f2 / sum2)

# Row-wise softmax, i.e. each chunk normalised on its own.
softmax_output = exp_values / row_sums
print("sum: ", row_sums)
print("基础softmax_output:", softmax_output)

# Merge the two chunks into the softmax of the full vector: rescale each
# chunk's statistics from its local maximum to the global maximum.  The
# factors are computed instead of hard-coding np.exp(-6) (= 3 - 9), which was
# only correct for this particular input.
m_global = np.maximum(m1, m2)
scale1 = np.exp(m1 - m_global)
scale2 = np.exp(m2 - m_global)
L = scale1 * sum1 + scale2 * sum2
print(f"L={L}")
con = np.concatenate((scale1 * f1, scale2 * f2), axis=0)
print("con: ", con)
print("result:", con / L)
def softmax(input_array, axis=1):
    """Return the numerically stable softmax of ``input_array`` along ``axis``.

    Args:
        input_array: array-like of numbers; converted with ``np.asarray`` so
            nested lists are accepted as well as ndarrays.
        axis: axis along which values are normalised. Defaults to 1 (row-wise
            for a 2-D input), which is the original behaviour; pass ``axis=0``
            or ``axis=-1`` for a 1-D input.

    Returns:
        ndarray with the same shape as ``input_array`` whose slices along
        ``axis`` each sum to 1.
    """
    x = np.asarray(input_array)
    # Subtracting the per-slice maximum keeps np.exp from overflowing; softmax
    # is invariant to a constant shift, so the result is unchanged.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_values = np.exp(shifted)
    return exp_values / np.sum(exp_values, axis=axis, keepdims=True)


print("直接计算: ", softmax(np.array([[1, 2, 3, 4, 9, 6]])))
上述代码将张量分成两块分别计算,最后可以看到:采用分块逐步累加的方式得到的结果,与对整个向量直接计算 softmax 的结果相同。
下面进一步优化程序,使其可以处理任意数量的分块(输入规模可以自由变大),也更加灵活:
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import numpy as np
# Input matrix. Its rows are treated as consecutive chunks of ONE flattened
# vector; the softmax over all 16 elements is built up chunk by chunk.
input_array = np.array([[1, 2, 3, 1], [4, 9, 6, 10], [5, 10, 7, 6], [100, 13, 20, 30]])
print("np.e:", np.e)
print("1/np.e:", 1/np.e)

# Running state of the online softmax, valid after each processed chunk:
#   m      - maximum of all elements seen so far
#   L      - sum of exp(x - m) over all elements seen so far
#   result - exp(x - m) for all elements seen so far, in input order
# Initialised for "no chunks yet" so an input with no rows still reaches the
# final print instead of raising NameError on `result`.
m = -np.inf
L = 0
result = np.empty(0)
for i, arr in enumerate(input_array):
    print(arr)
    if i == 0:
        # First chunk: the running state is just this chunk's own statistics.
        m = np.max(arr)
        print(f"m更新为{m}")
        print("arr-m:", arr-m)
        result = np.exp(arr - m)
        L = np.sum(result)
        print("result:", result)
        continue

    # Statistics of the new chunk relative to its own maximum m2.
    m2 = np.max(arr)
    print(f"m2={m2}")
    print(f"arr-m2={arr-m2}")
    temp2 = np.exp(arr - m2)
    L2 = np.sum(temp2)
    print(f"temp2={temp2}")

    # Rescale both the old running state and the new chunk to the common
    # maximum m_new; each factor is exp(<=0), so nothing can overflow.
    m_new = np.maximum(m, m2)
    print(f"L2 = {L2}")
    L = np.exp(m - m_new) * L + np.exp(m2 - m_new) * L2
    print(f"L={L}")
    print(f"m-m_new: {m-m_new}, m2-m_new: {m2-m_new}")
    result = np.concatenate((np.exp(m-m_new)*result, np.exp(m2-m_new)*temp2))
    print(f"result={result}")
    m = m_new
    print(f"m更新为: {m}")

# Normalising by the accumulated sum gives the softmax of the flattened input.
print(f"结果为: {result/L}")
def softmax(input_array, axis=1):
    """Return the numerically stable softmax of ``input_array`` along ``axis``.

    Args:
        input_array: array-like of numbers; converted with ``np.asarray`` so
            nested lists are accepted as well as ndarrays.
        axis: axis along which values are normalised. Defaults to 1 (row-wise
            for a 2-D input), which is the original behaviour; pass ``axis=0``
            or ``axis=-1`` for a 1-D input.

    Returns:
        ndarray with the same shape as ``input_array`` whose slices along
        ``axis`` each sum to 1.
    """
    x = np.asarray(input_array)
    # Subtracting the per-slice maximum keeps np.exp from overflowing; softmax
    # is invariant to a constant shift, so the result is unchanged.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_values = np.exp(shifted)
    return exp_values / np.sum(exp_values, axis=axis, keepdims=True)
# Flatten the chunked input once and reuse it for the direct computation.
flattened = input_array.reshape(-1)
print(flattened)
print("直接计算: ", softmax([flattened]))
# Softmax of the first two chunks on their own, for comparison with the log.
print("直接计算[1, 2, 3]: ", softmax(np.array([[1, 2, 3]])))
print("直接计算[4, 9, 6]: ", softmax(np.array([[4, 9, 6]])))
运行之后得到如下输出:
python softmax.py
np.e: 2.718281828459045
1/np.e: 0.36787944117144233
[1 2 3 1]
m更新为3
arr-m: [-2 -1 0 -2]
result: [ 0.13533528 0.36787944 1. 0.13533528]
[ 4 9 6 10]
m2=10
arr-m2=[-6 -1 -4 0]
temp2=[ 0.00247875 0.36787944 0.01831564 1. ]
L2 = 1.3886738322368428
L=1.3901679964384732
m-m_new: -7, m2-m_new: 0
result=[ 1.23409804e-04 3.35462628e-04 9.11881966e-04 1.23409804e-04
2.47875218e-03 3.67879441e-01 1.83156389e-02 1.00000000e+00]
m更新为: 10
[ 5 10 7 6]
m2=10
arr-m2=[-5 0 -3 -4]
temp2=[ 0.00673795 1. 0.04978707 0.01831564]
L2 = 1.0748406542556834
L=2.4650086506941564
m-m_new: 0, m2-m_new: 0
result=[ 1.23409804e-04 3.35462628e-04 9.11881966e-04 1.23409804e-04
2.47875218e-03 3.67879441e-01 1.83156389e-02 1.00000000e+00
6.73794700e-03 1.00000000e+00 4.97870684e-02 1.83156389e-02]
m更新为: 10
[100 13 20 30]
m2=100
arr-m2=[ 0 -87 -80 -70]
temp2=[ 1.00000000e+00 1.64581143e-38 1.80485139e-35 3.97544974e-31]
L2 = 1.0
L=1.0
m-m_new: -90, m2-m_new: 0
result=[ 1.01122149e-43 2.74878501e-43 7.47197234e-43 1.01122149e-43
2.03109266e-42 3.01440879e-40 1.50078576e-41 8.19401262e-40
5.52108228e-42 8.19401262e-40 4.07955867e-41 1.50078576e-41
1.00000000e+00 1.64581143e-38 1.80485139e-35 3.97544974e-31]
m更新为: 100
结果为: [ 1.01122149e-43 2.74878501e-43 7.47197234e-43 1.01122149e-43
2.03109266e-42 3.01440879e-40 1.50078576e-41 8.19401262e-40
5.52108228e-42 8.19401262e-40 4.07955867e-41 1.50078576e-41
1.00000000e+00 1.64581143e-38 1.80485139e-35 3.97544974e-31]
[ 1 2 3 1 4 9 6 10 5 10 7 6 100 13 20 30]
直接计算: [[ 1.01122149e-43 2.74878501e-43 7.47197234e-43 1.01122149e-43
2.03109266e-42 3.01440879e-40 1.50078576e-41 8.19401262e-40
5.52108228e-42 8.19401262e-40 4.07955867e-41 1.50078576e-41
1.00000000e+00 1.64581143e-38 1.80485139e-35 3.97544974e-31]]
直接计算[1, 2, 3]: [[ 0.09003057 0.24472847 0.66524096]]
直接计算[4, 9, 6]: [[ 0.00637746 0.94649912 0.04712342]]
从上述日志中,可以看到逐块累计计算得到的结果为:
结果为: [ 1.01122149e-43 2.74878501e-43 7.47197234e-43 1.01122149e-43
2.03109266e-42 3.01440879e-40 1.50078576e-41 8.19401262e-40
5.52108228e-42 8.19401262e-40 4.07955867e-41 1.50078576e-41
1.00000000e+00 1.64581143e-38 1.80485139e-35 3.97544974e-31]
而直接计算的结果为:
[[ 1.01122149e-43 2.74878501e-43 7.47197234e-43 1.01122149e-43
2.03109266e-42 3.01440879e-40 1.50078576e-41 8.19401262e-40
5.52108228e-42 8.19401262e-40 4.07955867e-41 1.50078576e-41
1.00000000e+00 1.64581143e-38 1.80485139e-35 3.97544974e-31]]
两者完全一致,因此验证了分块(在线)计算 softmax 的方式是精确的,而非近似,可用于精确的注意力计算。