训练 lenet 需要初始化 kernel 的 weight 和 bias,而使用 Xavier Glorot 初始化则需要计算 sqrt ( 6.0 f a n i n + f a n o u t ) \text{sqrt}(\frac{6.0}{fan_{in} + fan_{out}}) sqrt(fanin+fanout6.0)(均匀分布) 或 sqrt ( 2.0 f a n i n + f a n o u t ) \text{sqrt}(\frac{2.0}{fan_{in} + fan_{out}}) sqrt(fanin+fanout2.0)(高斯分布).(参考[1]). 为了完全用 C 语言实现 lenet 的训练, 避免依赖 C 标准库的数学库函数 sqrt()
, 考虑弄清楚 sqrt()
的原理, 手动实现一个"低配版": 精度有轻微误差,实现简单。
实现开根号的方法,粗略说有三种:
本文只考虑 n , n ∈ R + \sqrt{n}, n \in \R^+ n,n∈R+.
对于正实数 n ∈ R + n \in \R^+ n∈R+, 它的二次方根为 x = n x=\sqrt{n} x=n, 也就是使得 x 2 = n x^2=n x2=n 成立的数字。考察方程 f ( x ) = x 2 − n = 0 f(x)=x^2-n=0 f(x)=x2−n=0 的解:
比较相等
C语言使用 IEEE-754 标准来表示浮点数, 表示的数字可能和理论数字有误差, 因此判断浮点数相等时往往做差值的绝对值然后和 eps 比较, 小于eps就认为相等。
迭代求解
二分法是一个迭代求解算法, 可以手动设置迭代次数, 也可以设置比较精度 eps,迭代过程中精度误差小于 eps 就停止。本文的实现选择设置 eps 的方式。
防止溢出
本文给出的实现,是用 double 类型计算的。 计算两个数字中点时,有可能超过 double 类型最大值, 因此用先求差值的一半,再加到左端点的方式来计算中点。
特殊数字处理
n < 0 n < 0 n<0, 直接返回。
n = 0 n = 0 n=0 和 n = 1 n = 1 n=1, 直接返回。
#include
#include
double m_fabs(double n)
{
return n >= 0.0 ? n : -n;
}
double m_sqrt(double n)
{
if (n == 0.0 || n == 1.0)
{
return n;
}
if (n < 0.f)
{
printf("Error: not supported n: %f\n", n);
return -1;
}
double left, right;
if (n > 1.0)
{
left = 1.0;
right = n;
}
else
{
left = n;
right = 1.0;
}
double left_v = left * left - n;
double right_v = right * right - n;
if (left_v * right_v > 0)
{
printf("Error: not exist sqrt for n=%f\n", n);
return -2;
}
const double eps = 1e-5;
while (left <= right)
{
printf("left=%f, right=%f\n", left, right);
double mid = left + (right - left) / 2.0;
double value = mid * mid;
if (value - n > eps)
{
right = mid;
}
else if (value - n < -eps)
{
left = mid;
}
else
{
return mid;
}
}
return 233;
}
int main()
{
double n;
while (true)
{
printf(">>> Please input an double number: ");
scanf("%lf", &n);
double ans = m_sqrt(n);
printf("sqrt(%lf) = %lf\n", n, ans);
}
return 0;
}
base) zz@Legion-R7000P% gcc sqrt.c
(base) zz@Legion-R7000P% ./a.out
>>> Please input an double number: 9.0
left=1.000000, right=9.000000
left=1.000000, right=5.000000
sqrt(9.000000) = 3.000000
>>> Please input an double number: 0.04
left=0.040000, right=1.000000
left=0.040000, right=0.520000
left=0.040000, right=0.280000
left=0.160000, right=0.280000
left=0.160000, right=0.220000
left=0.190000, right=0.220000
left=0.190000, right=0.205000
left=0.197500, right=0.205000
left=0.197500, right=0.201250
left=0.199375, right=0.201250
left=0.199375, right=0.200313
left=0.199844, right=0.200313
left=0.199844, right=0.200078
left=0.199961, right=0.200078
sqrt(0.040000) = 0.200020
>>> Please input an double number: 0.01
left=0.010000, right=1.000000
left=0.010000, right=0.505000
left=0.010000, right=0.257500
left=0.010000, right=0.133750
left=0.071875, right=0.133750
left=0.071875, right=0.102813
left=0.087344, right=0.102813
left=0.095078, right=0.102813
left=0.098945, right=0.102813
left=0.098945, right=0.100879
left=0.099912, right=0.100879
left=0.099912, right=0.100396
left=0.099912, right=0.100154
sqrt(0.010000) = 0.100033
>>> Please input an double number: ^C
给定数字 a a a, 求 a \sqrt{a} a. 等价于求方程 f ( x ) = x 2 − a = 0 f(x)=x^2-a = 0 f(x)=x2−a=0 的解。
这个方程在 x 0 x_0 x0 点处的切线 L ( x 0 ) L(x_0) L(x0)方程为 f ( x ) − f ( x 0 ) = f ′ ( x 0 ) ( x − x 0 ) f(x)-f(x_0)=f'(x_0)(x-x_0) f(x)−f(x0)=f′(x0)(x−x0).
切线与 x x x 轴有交点, 也就是当 f ( x ) = 0 f(x)=0 f(x)=0, f ′ ( x 0 ) ( x − x 0 ) + f ( x 0 ) = 0 f'(x_0)(x-x_0) + f(x_0) = 0 f′(x0)(x−x0)+f(x0)=0
⇒ x − x 0 = − f ( x 0 ) / f ′ ( x 0 ) \Rightarrow x-x_0 = -f(x_0)/f'(x_0) ⇒x−x0=−f(x0)/f′(x0)
$\Rightarrow x = x_0 - f(x_0)/f’(x_0) = x_0 - (x_0^2-n)/2x_0 = (x_0 + a/x_0)/2 $
得到 x \sqrt{x} x 的第一个近似解 x 1 = ( x 0 + a x 0 ) / 2 x_1=(x_0+\frac{a}{x_0})/2 x1=(x0+x0a)/2.
通常 x 1 x_1 x1 的精度不足,也就是 x 1 2 x_1 ^ 2 x12 和 a a a 相差比较多,因此还需要继续迭代。迭代到第 n n n 次时:
$\Rightarrow x_{n+1} = x_{n} - \frac{f_n}{f’(x_n)} = \frac{1}{2} (x_n + \frac{a}{x_n}) $
只要此时 x n 2 {x_n}^2 xn2 和 a a a 足够接近, 或者迭代次数 n n n 足够大, 都可以停止迭代, 用 x n x_n xn 作为 a \sqrt{a} a.
#include
#include
double m_fabs(double n)
{
return n >= 0.0 ? n : -n;
}
double m_sqrt_newton(double a)
{
// x_{n+1} = \frac{1}{2} (x_n + \frac{a}{x_n})
double x = 1.0; // why?
double eps = 1e-5;
while (m_fabs(x * x - a) > eps)
{
printf("x = %lf\n", x);
x = (x + a / x) / 2.0;
}
return x;
}
int main()
{
double n;
while (true)
{
printf(">>> Please input an double number: ");
scanf("%lf", &n);
//double ans = m_sqrt(n);
//printf("sqrt(%lf) = %lf\n", n, ans);
double ans_newton = m_sqrt_newton(n);
printf("sqrt_newton(%lf) = %lf\n", n, ans_newton);
}
return 0;
}
zz@Legion-R7000P% gcc sqrt.c
zz@Legion-R7000P% ./a.out
>>> Please input an double number: 9.0
x = 1.000000
x = 5.000000
x = 3.400000
x = 3.023529
x = 3.000092
sqrt_newton(9.000000) = 3.000000
>>> Please input an double number: 0.04
x = 1.000000
x = 0.520000
x = 0.298462
x = 0.216241
x = 0.200610
sqrt_newton(0.040000) = 0.200001
>>> Please input an double number: 0.01
x = 1.000000
x = 0.505000
x = 0.262401
x = 0.150255
x = 0.108404
x = 0.100326
sqrt_newton(0.010000) = 0.100001
>>> Please input an double number: ^C
卡马克在雷神之锤游戏中给出了求平方根倒数的一种非常trick的代码实现。把它再求倒数, 就得到开根号结果。
它其实是一种混合方法: 一部分是牛顿法, 另一部分是对数函数的近似。其中牛顿迭代部分用于提升精度, 对数函数的逼近则和 IEEE-754 浮点数表示法紧密结合。
使用的近似公式是 l o g 2 ( 1 + x ) ≈ x + k log_2(1+x) \approx x + k log2(1+x)≈x+k. 见参考[4].
由于 Carmack 快速求平方根的倒数法, 本身目的就是要尽可能快, 因此使用 float 类型而不是 double 类型。
#include
double m_sqrt_carmack(double n)
{
int i;
float x2, y;
const float threehalfs = 1.5f;
x2 = n * 0.5f;
y = (float)n;
i = *(int*)&y;
i = 0x5f3759df - (i >> 1);
y = *(float *)&i;
y = y * (threehalfs - (x2 * y * y)); // 1st iteration
y = y * (threehalfs - (x2 * y * y)); // 2nd iteration
return 1.0 / y;
}
int main()
{
double n;
while (true)
{
printf(">>> Please input an double number: ");
scanf("%lf", &n);
double ans_carmack = m_sqrt_carmack(n);
printf("sqrt_carmack(%lf) = %lf\n", n, ans_carmack);
}
return 0;
}
zz@Legion-R7000P% gcc sqrt.c
zz@Legion-R7000P% ./a.out
>>> Please input an double number: 9.0
sqrt_carmack(9.000000) = 3.000006
>>> Please input an double number: 0.04
sqrt_carmack(0.040000) = 0.200001
>>> Please input an double number: 0.01
sqrt_carmack(0.010000) = 0.100000
>>> Please input an double number: ^C
// Author: Zhuo Zhang
// Homepage: https://github.com/zchrissirhcz
#include
#include
double m_fabs(double n)
{
return n >= 0.0 ? n : -n;
}
double m_sqrt(double n)
{
if (n == 0.0 || n == 1.0)
{
return n;
}
if (n < 0.f)
{
printf("Error: not supported n: %f\n", n);
return -1;
}
double left, right;
if (n > 1.0)
{
left = 1.0;
right = n;
}
else
{
left = n;
right = 1.0;
}
double left_v = left * left - n;
double right_v = right * right - n;
if (left_v * right_v > 0)
{
printf("Error: not exist sqrt for n=%f\n", n);
return -2;
}
const double eps = 1e-5;
while (left <= right)
{
printf("left=%f, right=%f\n", left, right);
double mid = left + (right - left) / 2.0;
double value = mid * mid;
if (value - n > eps)
{
right = mid;
}
else if (value - n < -eps)
{
left = mid;
}
else
{
return mid;
}
}
return 233;
}
double m_sqrt_newton(double a)
{
// x_{n+1} = \frac{1}{2} (x_n + \frac{a}{x_n})
double x = 1.0; // why?
double eps = 1e-5;
while (m_fabs(x * x - a) > eps)
{
printf("x = %lf\n", x);
x = (x + a / x) / 2.0;
}
return x;
}
double m_sqrt_carmack(double n)
{
int i;
float x2, y;
const float threehalfs = 1.5f;
x2 = n * 0.5f;
y = (float)n;
i = *(int*)&y;
i = 0x5f3759df - (i >> 1);
y = *(float *)&i;
y = y * (threehalfs - (x2 * y * y)); // 1st iteration
y = y * (threehalfs - (x2 * y * y)); // 2nd iteration
return 1.0 / y;
}
int main()
{
double n;
while (true)
{
printf(">>> Please input an double number: ");
scanf("%lf", &n);
double ans = m_sqrt(n);
printf("sqrt(%lf) = %lf\n", n, ans);
double ans_newton = m_sqrt_newton(n);
printf("sqrt_newton(%lf) = %lf\n", n, ans_newton);
double ans_carmack = m_sqrt_carmack(n);
printf("sqrt_carmack(%lf) = %lf\n", n, ans_carmack);
}
return 0;
}