Diffie-Hellman密钥交换算法的原理及程序演示


在http://en.wikipedia.org/wiki/Diffie-Hellman上面给出了这个密钥交换协议的历史,原理,重要文献的链接,以及演示代码。它的数学基础就是离散对数这个数学难题。用它进行密钥交换的过程简述如下:
选取两个大数p和g并公开,其中p是一个素数,g是p的一个模p本原单位根(primitive root module p),所谓本原单位根就是指在模p乘法运算下,g的1次方,2次方……(p-1)次方这p-1个数互不相同,并且取遍1到p-1;
对于Alice(其中的一个通信者),随机产生一个整数a,a对外保密,计算Ka = g^a mod p,将Ka发送给Bob;
对于Bob(另一个通信者),随机产生一个整数b,b对外保密,计算Kb = g^b mod p,将Kb发送给Alice;
在Alice方面,收到Bob送来的Kb后,计算出密钥为:key = Kb^a mod p=(g^b)^a=g^(b*a) mod p;
对于Bob,收到Alice送来的Ka后,计算出密钥为:key = Ka ^ b mod p=(g^a)^b=g^(a*b) mod p。
攻击者知道p和g,并且截获了Ka和Kb,但是当它们都是非常大的数的时候,依靠这四个数来计算a和b非常困难,这就是离散对数数学难题。
要实现Diffie-Hellman密钥交换协议,需要能够快速计算大数模幂,在模幂算法中,仍需计算大数的乘法和模运算,所以整个过程需要三个算法:高精度乘法,高精度除法(用来同时求出一个大数除以另一个大数的商和余数),快速模幂算法。
高精度的乘法和除法可以程序模拟手算。快速模幂算法也是从手算中总结出规律来,例如:
5^8 = (5^2)^4 = (25)^4 = (25^2)^2 = (625)^2,这样,原来计算5^8需要做8次乘法,而现在则只需要三次乘法,分别是:5^2, 25^2, 625^2。这就是快速模幂算法的基础。将算法描述出来,那就是:
算法M:输入整数a,b,p,计算a^b mod p:
M1.初始化c = 1
M2.如果b为0, 则c就是所要计算的结果。返回c的值。算法结束。
M3.如果b为奇数,则令c = c *a mod p, 令b = b-1,转到M2。
M4.如果b为偶数,则令a = a * a mod p, 令b = b / 2,转到M2。
高精度试除法原理简单,但是代码实现起来需要仔细考虑一些细节。
我的演示代码如下:
高精度运算类:
  1. class SuperNumber {
  2. public:
  3.     SuperNumber() {
  4.         memset(data, 0, MAX_SIZE);
  5.         high = 0;
  6.     }
  7.     // 一般整型到SuperNumber的转换,该版本中不支持负数
  8.     SuperNumber(unsigned long l) {
  9.         memset(data, 0, MAX_SIZE);
  10.         high = 0;
  11.         while(l) {
  12.             data[++high] = l % 10;
  13.             l /= 10;
  14.         }
  15.     }
  16.     // str为字符串形式表示的十进制数
  17.     SuperNumber(const char* str) {
  18.         assert(str != NULL);
  19.         high = strlen(str);
  20.         for(int i = high, j = 0; i >= 1; i--, j++) {
  21.             data[i] = str[j] - '0';
  22.         }
  23.     }
  24.     SuperNumber(const SuperNumber& s) {
  25.         memcpy(data, s.data, MAX_SIZE);
  26.         high = s.high;
  27.     }
  28.     operator const char*() const {
  29.         return toString(10);
  30.     }
  31.     SuperNumber& operator=(const SuperNumber& s) {
  32.         if(this != &s) {
  33.             memcpy(data, s.data, MAX_SIZE);
  34.             high = s.high;
  35.         }
  36.         return *this;
  37.     }
  38.     // 将数据置为0
  39.     void reset() {
  40.         memset(data, 0, MAX_SIZE);
  41.         high = 0;
  42.     }
  43.     // str为字符串形式表示的十进制数
  44.     void setToStr(const char* str) {
  45.         assert(str != NULL);
  46.         high = strlen(str);
  47.         for(int i = high, j = 0; i >= 1; i--, j++) {
  48.             data[i] = str[j] - '0';
  49.         }
  50.     }
  51.     // 将数据转换成以base指定的进制的字符串形式,默认为十进制
  52.     const char* toString(int base = 10) const {
  53.         static char buf[MAX_SIZE];
  54.         const char table[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
  55.         if(high == 0) return "0";
  56.         assert(base >= 2);      // 指定的进制应该不小于2
  57.         // 进制转换
  58.         buf[MAX_SIZE-1] = '/0';
  59.         int begin = MAX_SIZE-1;
  60.         char temp[MAX_SIZE];
  61.         memcpy(temp, data, MAX_SIZE);
  62.         while(1) {
  63.             // 找最高位的起始位置
  64.             int h = high;
  65.             while(temp[h] == 0 && h >= 1) h--;
  66.             if(h == 0) break;
  67.             // 除基取余
  68.             int t = 0;
  69.             while(h >= 1) {
  70.                 t = t * 10 + temp[h];
  71.                 temp[h] = t / base;
  72.                 t = t % base;
  73.                 h--;
  74.             }
  75.             buf[--begin] = table[t];
  76.         }
  77.         return buf+begin;
  78.     }
  79.     // 乘法
  80.     SuperNumber operator*(const SuperNumber& s) const {
  81.         SuperNumber result;     // default set to 0
  82.         int i, j;
  83.         
  84.         // 相乘
  85.         for(i = 1; i <= high; i++) {
  86.             for(j = 1; j <= s.high; j++) {
  87.                 int k = data[i] * s.data[j] + result.data[i+j-1];
  88.                 result.data[i+j-1] = k % 10;
  89.                 result.data[i+j] += k / 10;
  90.             }
  91.         }
  92.         // 进位
  93.         for(i = 1; i < MAX_SIZE - 1; i++) {
  94.             if(result.data[i] >= 10) {
  95.                 result.data[i+1] += result.data[i] / 10;
  96.                 result.data[i] %= 10;
  97.             }
  98.         }
  99.         // 确定最高位
  100.         for(i = MAX_SIZE-1; i >= 1 && result.data[i] == 0; i--);
  101.         result.high = i;
  102.         return result;
  103.     }
  104.     // 除法,内部调用doDivide来实现
  105.     SuperNumber operator/(const SuperNumber& s) const {
  106.         SuperNumber q, r;
  107.         doDivide(s, q, r);
  108.         return q;
  109.     }
  110.     // 模运算,内部调用doDivide来实现
  111.     SuperNumber operator%(const SuperNumber& s) const {
  112.         SuperNumber q, r;
  113.         doDivide(s, q, r);
  114.         return r;
  115.     }
  116.     // 除法运算,一次除法运算中同时得到商和余数,运算符/和%的重载
  117.     // 内部会调用这个函数,dest为除数,Q为商,R为余数,算法使用试除法
  118.     void doDivide(const SuperNumber& dest, SuperNumber& Q, SuperNumber& R) const {
  119.         int i, j, t;
  120.         Q.reset();
  121.         Q.high = high - dest.high + 1; // 商的初始位数
  122.         R = *this;                     // 余数初始实为被除数
  123.         t = dest.high;
  124.         // 判断除法是否结束
  125.         while(R >= dest) {
  126.             // 循环相减进行试除
  127.             while(dest >= R.sub(1, t)) {
  128.                 Q.data[Q.high--] = 0;
  129.                 ++t;
  130.             }
  131.             while(R.sub(1, t) >= dest) {
  132.                 // i为相减时被除数最低下标,j为除数最低下标
  133.                 for(i=R.high-t+1,j=1; j<=dest.high; i++,j++) {
  134.                     R.data[i] -= dest.data[j];
  135.                     if(R.data[i] < 0) {
  136.                         R.data[i] += 10;
  137.                         R.data[i+1] -= 1;
  138.                     }
  139.                 }
  140.                 while(R.data[i] < 0 && i <= R.high) {
  141.                     R.data[i] += 10;
  142.                     R.data[i+1] -= 1;
  143.                     ++i;
  144.                 }
  145.                 Q.data[Q.high] += 1;
  146.             }
  147.             // 一次试除结束,更新商的最高位下标
  148.             Q.high -= 1;
  149.             // 更新被除数的最高位下标
  150.             while(R.data[R.high] == 0) {
  151.                 R.high--;
  152.                 t--;
  153.             }
  154.             t+=1;               // 下一位被除数
  155.         }
  156.         Q.high = high - dest.high + 1; 
  157.         while(Q.data[Q.high] == 0) Q.high -= 1;
  158.         R.high = high;
  159.         while(R.data[R.high] == 0) R.high -= 1;
  160.     }
  161.     // 大数模幂算法,很简单的自然算法,即将指数分解为二进制,换句
  162.     // 更简单的话来说,就是不断地找平方模幂,而不是全部乘方后再
  163.     // 做一次最终的模运算
  164.     // a.power_mod(p, n)计算a^p mod n
  165.     SuperNumber power_mod(int power, SuperNumber n) const {
  166.         SuperNumber c("1"), t(*this);
  167.         while(power) {
  168.             if(power % 2) {
  169.                 c = c * t % n;
  170.                 power -= 1;
  171.             } else {
  172.                 t = t * t % n;
  173.                 power /= 2;
  174.             }
  175.         }
  176.         return c%n;
  177.     }
  178.     bool operator>=(const SuperNumber& s) const {
  179.         if(high == s.high) {
  180.             int k = high;
  181.             while(data[k] == s.data[k] && k >= 1)k--;
  182.             if(k < 1) return true; // equal
  183.             return data[k] > s.data[k];
  184.         } else if(high > s.high) return true;
  185.         return false;        
  186.     }
  187.     bool operator<(const SuperNumber& s) const {
  188.         return !(*this >= s);
  189.     }
  190.     // 从十进制表示的最高位开始数起,数到第i位,从第i位开始截取连续
  191.     // 的c位数字出来组成一个新的数。例如:数据是12345678925698,则
  192.     // sub(3, 5)返回数字34567,如果数字不够取,例如在34567上运行
  193.     // sub(3, 5),因为34567从高位数起第3个数字是5,剩下的数字是567,
  194.     // 最多只有三个,不够取要求的5个,这个时候返回567,不报错。
  195.     SuperNumber sub(int i, int c) const {
  196.         SuperNumber ret;
  197.         assert(high >= i);   // 保证可截取
  198.         i = high - i + 1;    // 从高位数起的第i个数位的下标
  199.         if(i >= c) {
  200.             ret.high = c;
  201.             while(c >= 1) ret.data[c--] = data[i--];
  202.         } else {
  203.             ret.high = i;
  204.             while(i >= 1) {
  205.                 ret.data[i] = data[i];
  206.                 i--;
  207.             }
  208.         } 
  209.         // 过滤前导0
  210.         while(ret.data[ret.high] == 0) ret.high--;
  211.         return ret;
  212.     }
  213.     // I/O
  214.     friend istream& operator>>(istream& in, SuperNumber& s) {
  215.         char t[256];
  216.         in >> t;
  217.         s.setToStr(t);
  218.         return in;
  219.     }
  220.     friend ostream& operator<<(ostream& out, const SuperNumber& s) {
  221.         return out << s.toString(10);
  222.     }
  223. private:
  224.     enum {MAX_SIZE=256};        // 最大十进制位数
  225.     // 须注意,使用data[0]存储最高位所在下标是自己的一点小聪明,后来在
  226.     // 调试的时候发现这是一个极大的错误,不过对于此题目来说可以应付
  227.     char data[MAX_SIZE];        // 数据的内部表示,字符串形式的十进制
  228.                                 // 其中data[0]存储最高位所在下标,data[1]
  229.                                 // 存储数据的最低位,也就是个位
  230.     int high;
  231. };

主函数:
  1. int main(int argc, char **argv) {
  2.     freopen("in.txt", "r", stdin);
  3.     SuperNumberTest st;
  4. //    st.run();
  5.     // g和n都是超过2^127的素数。它们在DH算法中公开
  6.     SuperNumber g, n;
  7.     int a, b;
  8.     SuperNumber ka, kb, key;
  9.     srand(time(0));
  10.     cin >> g >> n;
  11.     cout << "g = " << g << endl
  12.          << "n = " << n << endl;
  13.     cout << "/nThis is Alice:/n";
  14.     a = rand();
  15.     cout << "Alice get a random integer a = " << a << endl;
  16.     cout << "Alice computer g^a mod n:/n";
  17.     ka = g.power_mod(a, n);
  18.     cout << "Alice compute out ka = " << ka << endl;
  19.     cout << "/nThis is Bob:/n";
  20.     b = rand();
  21.     cout << "Bob get a random integer b = " << b << endl;
  22.     cout << "Bob compute g^b mod n:/n";
  23.     kb = g.power_mod(b, n);
  24.     cout << "Bob compute out kb = " << kb << endl;
  25.     cout << "/nAlice get kb from Bob, she compute out key is:/n";
  26.     cout << kb.power_mod(a, n) << endl;
  27.     cout << "/nBob get ka from Alice, he compute out key is:/n";
  28.     cout << ka.power_mod(b, n) << endl;
  29.     return 0;
  30. }

运行结果:
g = 170141183460469231731687303715884105757
n = 170141183460469231731687303715884106309
 
This is Alice:
Alice get a random integer a = 20276
Alice computer g^a mod n:
Alice compute out ka = 102075421398841759242347870420481896337
 
This is Bob:
Bob get a random integer b = 28664
Bob compute g^b mod n:
Bob compute out kb = 62348451302684698452476840835428450852
 
Alice get kb from Bob, she compute out key is:
80402514625208456390620786920929643017
 
Bob get ka from Alice, he compute out key is:
80402514625208456390620786920929643017
在 http://oldpiewiki.yoonkn.com/cgi-bin/moin.cgi/DiffieHellmanKeyExchange 上面给出了使用GNU gmp库实现的这个算法的演示代码

你可能感兴趣的:(Algorithm,网络安全)