scratch lenet(4): 开根号的C语言实现

文章目录

    • 1. 目的
    • 2 二分法求开根号
      • 2.1 数学原理:单调函数
      • 2.2 代码实现:注意事项
      • 2.3 代码实现: 完整代码
      • 2.4 验证结果
    • 3. 牛顿法
      • 3.1 数学原理:迭代求解
      • 3.2 代码实现
      • 3.3 结果
    • 4. 卡马克快速法
      • 4.1 原理
      • 4.2 代码实现
      • 4.3 结果
    • 5. 完整代码
    • 6. References

1. 目的

训练 lenet 需要初始化 kernel 的 weight 和 bias,而使用 Xavier Glorot 初始化则需要计算 sqrt ( 6.0 f a n i n + f a n o u t ) \text{sqrt}(\frac{6.0}{fan_{in} + fan_{out}}) sqrt(fanin+fanout6.0)(均匀分布) 或 sqrt ( 2.0 f a n i n + f a n o u t ) \text{sqrt}(\frac{2.0}{fan_{in} + fan_{out}}) sqrt(fanin+fanout2.0)(高斯分布).(参考[1]). 为了完全用 C 语言实现 lenet 的训练, 避免依赖 C 标准库的数学库函数 sqrt(), 考虑弄清楚 sqrt() 的原理, 手动实现一个"低配版": 精度有轻微误差,实现简单。

实现开根号的方法,粗略说有三种:

  • 二分法
  • 牛顿法
  • 卡马克公式快速法

本文只考虑 n , n ∈ R + \sqrt{n}, n \in \R^+ n ,nR+.

2 二分法求开根号

2.1 数学原理:单调函数

对于正实数 n ∈ R + n \in \R^+ nR+, 它的二次方根为 x = n x=\sqrt{n} x=n , 也就是使得 x 2 = n x^2=n x2=n 成立的数字。考察方程 f ( x ) = x 2 − n = 0 f(x)=x^2-n=0 f(x)=x2n=0 的解:

  • 如果 n > 1 n > 1 n>1, 则 n ∈ ( 1 , n ) \sqrt{n} \in (1, n) n (1,n), 是一个单调递增区间, s.t. f ( x ) \text{s.t.} f(x) s.t.f(x)有解
  • 如果 0 < n < 1 0 < n < 1 0<n<1, 则 n ∈ ( n , 1 ) \sqrt{n} \in (n, 1) n (n,1), 也是一个单调递增区间, s.t. f ( x ) \text{s.t.} f(x) s.t.f(x)有解
    单调性使得我们可以用二分法求解 f ( x ) = x 2 − n = 0 f(x)=x^2-n=0 f(x)=x2n=0, 从而得到答案 n = x \sqrt{n}=x n =x

2.2 代码实现:注意事项

比较相等

C语言使用 IEEE-754 标准来表示浮点数, 表示的数字可能和理论数字有误差, 因此判断浮点数相等时往往做差值的绝对值然后和 eps 比较, 小于eps就认为相等。

迭代求解

二分法是一个迭代求解算法, 可以手动设置迭代次数, 也可以设置比较精度 eps,迭代过程中精度误差小于 eps 就停止。本文的实现选择设置 eps 的方式。

防止溢出

本文给出的实现,是用 double 类型计算的。 计算两个数字中点时,有可能超过 double 类型最大值, 因此用先求差值的一半,再加到左端点的方式来计算中点。

特殊数字处理

n < 0 n < 0 n<0, 直接返回。
n = 0 n = 0 n=0 n = 1 n = 1 n=1, 直接返回。

2.3 代码实现: 完整代码

#include 
#include 

double m_fabs(double n)
{
    return n >= 0.0 ? n : -n;
}

double m_sqrt(double n)
{
    if (n == 0.0 || n == 1.0)
    {
        return n;
    }
    if (n < 0.f)
    {
        printf("Error: not supported n: %f\n", n);
        return -1;
    }

    double left, right;
    if (n > 1.0)
    {
        left = 1.0;
        right = n;
    }
    else
    {
        left = n;
        right = 1.0;
    }

    double left_v = left * left - n;
    double right_v = right * right - n;
    if (left_v * right_v > 0)
    {
        printf("Error: not exist sqrt for n=%f\n", n);
        return -2;
    }

    const double eps = 1e-5;
    while (left <= right)
    {
        printf("left=%f, right=%f\n", left, right);
        double mid = left + (right - left) / 2.0;
        double value = mid * mid;
        if (value - n > eps)
        {
            right = mid;
        }
        else if (value - n < -eps)
        {
            left = mid;
        }
        else
        {
            return mid;
        }
    }

    return 233;
}

int main()
{
    double n;
    while (true)
    {
        printf(">>> Please input an double number: ");
        scanf("%lf", &n);
        double ans = m_sqrt(n);
        printf("sqrt(%lf) = %lf\n", n, ans);
    }
    return 0;
}

2.4 验证结果

base) zz@Legion-R7000P% gcc sqrt.c
(base) zz@Legion-R7000P% ./a.out
>>> Please input an double number: 9.0
left=1.000000, right=9.000000
left=1.000000, right=5.000000
sqrt(9.000000) = 3.000000
>>> Please input an double number: 0.04
left=0.040000, right=1.000000
left=0.040000, right=0.520000
left=0.040000, right=0.280000
left=0.160000, right=0.280000
left=0.160000, right=0.220000
left=0.190000, right=0.220000
left=0.190000, right=0.205000
left=0.197500, right=0.205000
left=0.197500, right=0.201250
left=0.199375, right=0.201250
left=0.199375, right=0.200313
left=0.199844, right=0.200313
left=0.199844, right=0.200078
left=0.199961, right=0.200078
sqrt(0.040000) = 0.200020
>>> Please input an double number: 0.01
left=0.010000, right=1.000000
left=0.010000, right=0.505000
left=0.010000, right=0.257500
left=0.010000, right=0.133750
left=0.071875, right=0.133750
left=0.071875, right=0.102813
left=0.087344, right=0.102813
left=0.095078, right=0.102813
left=0.098945, right=0.102813
left=0.098945, right=0.100879
left=0.099912, right=0.100879
left=0.099912, right=0.100396
left=0.099912, right=0.100154
sqrt(0.010000) = 0.100033
>>> Please input an double number: ^C

3. 牛顿法

3.1 数学原理:迭代求解

给定数字 a a a, 求 a \sqrt{a} a . 等价于求方程 f ( x ) = x 2 − a = 0 f(x)=x^2-a = 0 f(x)=x2a=0 的解。

这个方程在 x 0 x_0 x0 点处的切线 L ( x 0 ) L(x_0) L(x0)方程为 f ( x ) − f ( x 0 ) = f ′ ( x 0 ) ( x − x 0 ) f(x)-f(x_0)=f'(x_0)(x-x_0) f(x)f(x0)=f(x0)(xx0).

切线与 x x x 轴有交点, 也就是当 f ( x ) = 0 f(x)=0 f(x)=0, f ′ ( x 0 ) ( x − x 0 ) + f ( x 0 ) = 0 f'(x_0)(x-x_0) + f(x_0) = 0 f(x0)(xx0)+f(x0)=0

⇒ x − x 0 = − f ( x 0 ) / f ′ ( x 0 ) \Rightarrow x-x_0 = -f(x_0)/f'(x_0) xx0=f(x0)/f(x0)

$\Rightarrow x = x_0 - f(x_0)/f’(x_0) = x_0 - (x_0^2-n)/2x_0 = (x_0 + a/x_0)/2 $

得到 x \sqrt{x} x 的第一个近似解 x 1 = ( x 0 + a x 0 ) / 2 x_1=(x_0+\frac{a}{x_0})/2 x1=(x0+x0a)/2.

通常 x 1 x_1 x1 的精度不足,也就是 x 1 2 x_1 ^ 2 x12 a a a 相差比较多,因此还需要继续迭代。迭代到第 n n n 次时:
$\Rightarrow x_{n+1} = x_{n} - \frac{f_n}{f’(x_n)} = \frac{1}{2} (x_n + \frac{a}{x_n}) $

只要此时 x n 2 {x_n}^2 xn2 a a a 足够接近, 或者迭代次数 n n n 足够大, 都可以停止迭代, 用 x n x_n xn 作为 a \sqrt{a} a .

3.2 代码实现

#include 
#include 

double m_fabs(double n)
{
    return n >= 0.0 ? n : -n;
}

double m_sqrt_newton(double a)
{
    // x_{n+1} = \frac{1}{2} (x_n + \frac{a}{x_n})
    double x = 1.0; // why?
    double eps = 1e-5;
    while (m_fabs(x * x - a) > eps)
    {
        printf("x = %lf\n", x);
        x = (x + a / x) / 2.0;
    }
    return x;
}

int main()
{
    double n;
    while (true)
    {
        printf(">>> Please input an double number: ");
        scanf("%lf", &n);
        //double ans = m_sqrt(n);
        //printf("sqrt(%lf) = %lf\n", n, ans);

        double ans_newton = m_sqrt_newton(n);
        printf("sqrt_newton(%lf) = %lf\n", n, ans_newton);
    }
    return 0;
}

3.3 结果

zz@Legion-R7000P% gcc sqrt.c 
zz@Legion-R7000P% ./a.out 
>>> Please input an double number: 9.0
x = 1.000000
x = 5.000000
x = 3.400000
x = 3.023529
x = 3.000092
sqrt_newton(9.000000) = 3.000000
>>> Please input an double number: 0.04
x = 1.000000
x = 0.520000
x = 0.298462
x = 0.216241
x = 0.200610
sqrt_newton(0.040000) = 0.200001
>>> Please input an double number: 0.01
x = 1.000000
x = 0.505000
x = 0.262401
x = 0.150255
x = 0.108404
x = 0.100326
sqrt_newton(0.010000) = 0.100001
>>> Please input an double number: ^C

4. 卡马克快速法

4.1 原理

卡马克在雷神之锤游戏中给出了求平方根倒数的一种非常trick的代码实现。把它再求倒数, 就得到开根号结果。

它其实是一种混合方法: 一部分是牛顿法, 另一部分是对数函数的近似。其中牛顿迭代部分用于提升精度, 对数函数的逼近则和 IEEE-754 浮点数表示法紧密结合。

使用的近似公式是 l o g 2 ( 1 + x ) ≈ x + k log_2(1+x) \approx x + k log2(1+x)x+k. 见参考[4].

4.2 代码实现

由于 Carmack 快速求平方根的倒数法, 本身目的就是要尽可能快, 因此使用 float 类型而不是 double 类型。

#include 

double m_sqrt_carmack(double n)
{
    int i;
    float x2, y;
    const float threehalfs = 1.5f;

    x2 = n * 0.5f;
    y = (float)n;

    i = *(int*)&y;
    i = 0x5f3759df - (i >> 1);
    y = *(float *)&i;
    y = y * (threehalfs - (x2 * y * y)); // 1st iteration
    y = y * (threehalfs - (x2 * y * y)); // 2nd iteration
    return 1.0 / y;
}

int main()
{
    double n;
    while (true)
    {
        printf(">>> Please input an double number: ");
        scanf("%lf", &n);

        double ans_carmack = m_sqrt_carmack(n);
        printf("sqrt_carmack(%lf) = %lf\n", n, ans_carmack);
    }
    return 0;
}

4.3 结果

zz@Legion-R7000P% gcc sqrt.c 
zz@Legion-R7000P% ./a.out 
>>> Please input an double number: 9.0
sqrt_carmack(9.000000) = 3.000006
>>> Please input an double number: 0.04
sqrt_carmack(0.040000) = 0.200001
>>> Please input an double number: 0.01
sqrt_carmack(0.010000) = 0.100000
>>> Please input an double number: ^C

5. 完整代码

// Author: Zhuo Zhang 
// Homepage: https://github.com/zchrissirhcz
#include 
#include 

double m_fabs(double n)
{
    return n >= 0.0 ? n : -n;
}

double m_sqrt(double n)
{
    if (n == 0.0 || n == 1.0)
    {
        return n;
    }
    if (n < 0.f)
    {
        printf("Error: not supported n: %f\n", n);
        return -1;
    }

    double left, right;
    if (n > 1.0)
    {
        left = 1.0;
        right = n;
    }
    else
    {
        left = n;
        right = 1.0;
    }

    double left_v = left * left - n;
    double right_v = right * right - n;
    if (left_v * right_v > 0)
    {
        printf("Error: not exist sqrt for n=%f\n", n);
        return -2;
    }

    const double eps = 1e-5;
    while (left <= right)
    {
        printf("left=%f, right=%f\n", left, right);
        double mid = left + (right - left) / 2.0;
        double value = mid * mid;
        if (value - n > eps)
        {
            right = mid;
        }
        else if (value - n < -eps)
        {
            left = mid;
        }
        else
        {
            return mid;
        }
    }

    return 233;
}

double m_sqrt_newton(double a)
{
    // x_{n+1} = \frac{1}{2} (x_n + \frac{a}{x_n})
    double x = 1.0; // why?
    double eps = 1e-5;
    while (m_fabs(x * x - a) > eps)
    {
        printf("x = %lf\n", x);
        x = (x + a / x) / 2.0;
    }
    return x;
}

double m_sqrt_carmack(double n)
{
    int i;
    float x2, y;
    const float threehalfs = 1.5f;

    x2 = n * 0.5f;
    y = (float)n;

    i = *(int*)&y;
    i = 0x5f3759df - (i >> 1);
    y = *(float *)&i;
    y = y * (threehalfs - (x2 * y * y)); // 1st iteration
    y = y * (threehalfs - (x2 * y * y)); // 2nd iteration
    return 1.0 / y;
}

int main()
{
    double n;
    while (true)
    {
        printf(">>> Please input an double number: ");
        scanf("%lf", &n);
        double ans = m_sqrt(n);
        printf("sqrt(%lf) = %lf\n", n, ans);

        double ans_newton = m_sqrt_newton(n);
        printf("sqrt_newton(%lf) = %lf\n", n, ans_newton);

        double ans_carmack = m_sqrt_carmack(n);
        printf("sqrt_carmack(%lf) = %lf\n", n, ans_carmack);

    }
    return 0;
}

6. References

  • [1] https://www.bookstack.cn/read/paddlepaddle-1.6/3f4d0d9266a7a5c8.md
  • [2] https://www.cnblogs.com/wangkundentisy/p/8118007.html
  • [3] https://blog.csdn.net/plm199513100/article/details/124072422
  • [4] 【回归本源】番外1-雷神之锤3的那段代码

你可能感兴趣的:(C/C++,c语言,算法,开发语言,开根号,数学)