scratch lenet(9): C语言实现tanh的计算

文章目录

    • 1. 目的
    • 2. tanh ⁡ ( x ) \tanh(x) tanh(x) 的 naive 实现
      • 2.1 数学公式
      • 2.2 naive 实现
    • 3. tanh ⁡ ( x ) \tanh(x) tanh(x) 的快速计算
      • 3.1 Maple 中的近似公式
      • 3.2 tan_c3()
      • 3.3 Gauss 连分数公式 (Continued Fraction)
    • 4. 最终代码和运行结果
      • 代码
      • 运行结果
    • 5. 其他
    • References

1. 目的

用于 LeNet-5 网络中 squashing function 中 tanh() 部分的计算。tanh() 是 hyperbolic tangent 双曲正切三角函数的意思。

LeNet-5 网络的 C1~F6, 每一层都需要对于输出结果应用 squashing function. 后世称作 activation function.

2. tanh ⁡ ( x ) \tanh(x) tanh(x) 的 naive 实现

2.1 数学公式

tanh ⁡ ( x ) = e x − e − x e x + e − x \tanh(x) = \frac{e^x-e^{-x}}{e^x+e^{-x}} tanh(x)=ex+exexex

2.2 naive 实现

直接翻译公式得到:

static double m_tanh(double x)
{
    double ep = m_exp(x);  // exponent of positive x
    double en = m_exp(-x); // exponent of negative x
    double up = ep - en;
    double down = ep + en;
    return up / down;
}

其中 m_exp 在前一篇博客1 scratch lenet(8): C语言实现 exp(x) 的计算 给出过:

static double m_fabs(double n)
{
    return n >= 0.0 ? n : -n;
}

double m_exp(double x)
{
    double res = 1;
    double eps = 1e-9;
    double up = 1;
    double down = 1;
    for (int i = 1; ;i++)
    {
        up *= x;
        down *= i;
        double delta = up / down;
        res += delta;
        // printf("i=%d, delta=%lf\n", i, delta);
        if (m_fabs(delta) < eps)
            break;
    }
    return res;
}

3. tanh ⁡ ( x ) \tanh(x) tanh(x) 的快速计算

StackOverFlow 上的一个问答2 给出了好几种近似计算方式。

3.1 Maple 中的近似公式

回答3 给出了一个公式(TL;DR 这一节的公式不靠谱,精度丢失比较多)

The best rational approximation to tanh(x) with numerator and denominator of degree 3 on the interval [0,3.1] (as provided by Maple’s minimax function) is

(-.67436811832e-5+(.2468149110712040+(.583691066395175e-1+.3357335044280075e-1*x)*x)*x)/(.2464845986383725+(.609347197060491e-1+(.1086202599228572+.2874707922475963e-1*x)*x)*x)

This (call it f(x)) has maximum error .2735944241730870e-4, which is considerably less than 2^(-8).
On the interval [−3.1,3.1], use sgn(x)f(|x|

double fast_tanh_by_maple(double x)
{
    return (-.67436811832e-5+(.2468149110712040+(.583691066395175e-1+.3357335044280075e-1*x)*x)*x)/(.2464845986383725+(.609347197060491e-1+(.1086202599228572+.2874707922475963e-1*x)*x)*x);
}
zz@Legion-R7000P% ./a.out 
Please input a real number: 0.345
 tanh(0.345000) = 0.331934
 tanh_c3(0.345000) = 0.331935
 m_tanh(0.345000) = 0.331934
 fast_tanh_by_maple(0.345000) = 0.331907

我尝试后,发现精度差的有点多,并不是所谓的“精度损失小于 .2735944241730870e-4”, 而是肉眼可见的有精度损失:

>>> e1 = .2735944241730870e-4
>>> e2 = 0.331934 - 0.331907
>>> e1 < e2
False
>>> 

3.2 tan_c3()

jenkas 给出了一个更好的近似公式和实现4.

float tanh_c3(float v)
{
    const float c1 = 0.03138777F;
    const float c2 = 0.276281267F;
    const float c_log2f = 1.442695022F;
    v *= c_log2f;
    int intPart = (int)v;
    float x = (v - intPart);
    float xx = x * x;
    float v1 = c_log2f + c2 * xx;
    float v2 = x + xx * c1 * x;
    float v3 = (v2 + v1);
    *((int*)&v3) += intPart << 24;
    float v4 = v2 - v1;
    return (v3 + v4) / (v3 - v4);
}

暂时没搞懂这个实现对应的公式

v = I + x // 整数部分 + 小数部分
xx = x * x // 小数部分的平方
v1 = c_log2f + c2 * xx
v2 = x + xx * c1 * x
v3 = v2 + v1 = x + xx * c1 * x + c_log2f + c2 * xx
             = c_log2f + x + c1 * x * x * x + c2 * x * x
v4 = v2 - v1 = x + xx * c1 * x - c_log2f - c2 * xx
             = -c_log2f + x - c2 * x * 2 + c1 * x * x * x

3.3 Gauss 连分数公式 (Continued Fraction)

1812年高斯给出的双曲正切函数 tanh ⁡ ( x ) \tanh(x) tanh(x) 的连分数展开公式 (continued fraction for the hyperbolic tangent 5)

tanh ⁡ ( x ) = x 1 + x 2 3 + x 2 5 + . . . \tanh(x) = \frac{x}{1 + \frac{x^2}{3 + \frac{x^2}{5 + ...}}} tanh(x)=1+3+5+...x2x2x

我们使用展开到 9 + x 2 11 9 + \frac{x^2}{11} 9+11x2 的这一项, 作为 tanh 的近似6
scratch lenet(9): C语言实现tanh的计算_第1张图片

发现结果非常准确(至少对于 x = 0.345 x=0.345 x=0.345 来说, 和 C 标准库结果一样)

double approx_tanh_by_continues_fraction(double x)
{
    double s = x * x;
    double y = 9 + s / 11;
    y = 7 + s / y;
    y = 5 + s / y;
    y = 3 + s / y;
    y = 1 + s / y;
    y = x / y;
    return y;
}

4. 最终代码和运行结果

代码

#include 
#include 
#include 

double tanh_c3(float v)
{
    const float c1 = 0.03138777F;
    const float c2 = 0.276281267F;
    const float c_log2f = 1.442695022F;
    v *= c_log2f;
    int intPart = (int)v;
    float x = (v - intPart);
    float xx = x * x;
    float v1 = c_log2f + c2 * xx;
    float v2 = x + xx * c1 * x;
    float v3 = (v2 + v1);
    *((int*)&v3) += intPart << 24;
    float v4 = v2 - v1;
    return (v3 + v4) / (v3 - v4);
}

static double m_fabs(double n)
{
    return n >= 0.0 ? n : -n;
}

double m_exp(double x)
{
    double res = 1;
    double eps = 1e-9;
    double up = 1;
    double down = 1;
    for (int i = 1; ;i++)
    {
        up *= x;
        down *= i;
        double delta = up / down;
        res += delta;
        // printf("i=%d, delta=%lf\n", i, delta);
        if (m_fabs(delta) < eps)
            break;
    }
    return res;
}

static double m_tanh(double x)
{
    double ep = m_exp(x);  // exponent of positive x
    double en = m_exp(-x); // exponent of negative x
    double up = ep - en;
    double down = ep + en;
    return up / down;
}

double fast_tanh_by_maple(double x)
{
    return (-.67436811832e-5+(.2468149110712040+(.583691066395175e-1+.3357335044280075e-1*x)*x)*x)/(.2464845986383725+(.609347197060491e-1+(.1086202599228572+.2874707922475963e-1*x)*x)*x);
}

double approx_tanh_by_continues_fraction(double x)
{
    double s = x * x;
    double y = 9 + s / 11;
    y = 7 + s / y;
    y = 5 + s / y;
    y = 3 + s / y;
    y = 1 + s / y;
    y = x / y;
    return y;
}

int main()
{
    double x;
    while (true)
    {
        printf("Please input a real number: ");
        scanf("%lf", &x);
        double y1 = tanh(x);
        double y2 = tanh_c3(x);
        double y3 = m_tanh(x);
        double y4 = fast_tanh_by_maple(x);
        double y5 = approx_tanh_by_continues_fraction(x);
        printf(" tanh(%lf) = %lf\n", x, y1);
        printf(" tanh_c3(%lf) = %lf\n", x, y2);
        printf(" m_tanh(%lf) = %lf\n", x, y3);
        printf(" fast_tanh_by_maple(%lf) = %lf\n", x, y4);
        printf(" approx_tanh_by_continues_fraction(%lf) = %lf\n", x, y5);
    }

    return 0;
}

运行结果

gcc tanh.c -lm
zz@Legion-R7000P% ./a.out 
Please input a real number: 0.345
 tanh(0.345000) = 0.331934
 tanh_c3(0.345000) = 0.331935
 m_tanh(0.345000) = 0.331934
 fast_tanh_by_maple(0.345000) = 0.331907
 approx_tanh_by_continues_fraction(0.345000) = 0.331934

也尝试了其他输入如 x=257, 整体上看 Gauss 给出的 Continued Fraction 的精度损失更小一些,速度也还算快,打算在 lenet-5 代码中使用它:

double approx_tanh_by_continues_fraction(double x)
{
    double s = x * x;
    double y = 9 + s / 11;
    y = 7 + s / y;
    y = 5 + s / y;
    y = 3 + s / y;
    y = 1 + s / y;
    y = x / y;
    return y;
}

5. 其他

K-Tanh 7 基于 AVX512 指令给出了5倍加速的实现。

[【Tanh的标量实现】]8 则考虑了 Inf/Nan 等情况, 并使用了
tanh ⁡ ( x ) = e 2 x − 1 e 2 x + 1 = 1 − 2 e 2 x + 1 \tanh(x) = \frac{e^{2x}-1}{e^{2x} + 1} = 1 - \frac{2}{e^{2x}+1} tanh(x)=e2x+1e2x1=1e2x+12
这一等效公式计算。

References


  1. scratch lenet(8): C语言实现 exp(x) 的计算 ↩︎

  2. Rapid approximation of tanh(x) ↩︎

  3. https://math.stackexchange.com/a/107302 ↩︎

  4. https://math.stackexchange.com/a/3485944 ↩︎

  5. continued fraction for the hyperbolic tangent ↩︎

  6. https://math.stackexchange.com/a/107295 ↩︎

  7. K-TANH: EFFICIENT TANH FOR DEEP LEARNING ↩︎

  8. 【Tanh的标量实现】 ↩︎

你可能感兴趣的:(C/C++,c语言,算法,tanh,深度学习,数值计算)