【分治】大整数乘法Python实现

文章目录

    • @[toc]
      • 问题描述
      • 基础算法
        • 时间复杂性
      • 优化算法
        • 时间复杂性
      • `Python`实现

问题描述

  • X X X Y Y Y都是 n n n位二进制整数,计算它们的乘积 X Y XY XY

基础算法

  • n n n位二进制整数 X X X Y Y Y都分为 2 2 2段,每段的长为 n / 2 n / 2 n/2位(假设 n n n 2 2 2的幂)
    • X = A × 2 n / 2 + B X = A \times 2^{n / 2} + B X=A×2n/2+B
    • Y = C × 2 n / 2 + D Y = C \times 2^{n / 2} + D Y=C×2n/2+D
    • X Y = ( A × 2 n / 2 + B ) ( C × 2 n / 2 + D ) = A C × 2 n + ( A D + B C ) × 2 n / 2 + B D XY = (A \times 2^{n / 2} + B) (C \times 2^{n / 2} + D) = AC \times 2^{n} + (AD + BC) \times 2^{n / 2} + BD XY=(A×2n/2+B)(C×2n/2+D)=AC×2n+(AD+BC)×2n/2+BD
时间复杂性
  • 如果按此式计算 X Y XY XY,必须进行 4 4 4 n / 2 n / 2 n/2位整数的乘法, 3 3 3次不超过 2 n 2n 2n位的整数加法,以及 2 2 2次移位,所有这些加法和移位共用 O ( n ) O(n) O(n)步运算

T ( n ) = { O ( 1 ) n = 1 4 T ( n / 2 ) + O ( n ) n > 1 T(n) = \begin{cases} O(1) & n = 1 \\ 4 T(n / 2) + O(n) & n > 1 \end{cases} T(n)={O(1)4T(n/2)+O(n)n=1n>1

T ( n ) = O ( n 2 ) T(n) = O(n^{2}) T(n)=O(n2)


优化算法

  • 要想改进算法的计算复杂性,必须减少乘法次数,把 X Y XY XY写成另一种形式
    • X Y = A C × 2 n + ( ( A − B ) ( D − C ) + A C + B D ) × 2 n / 2 + B D XY = AC \times 2^{n} + ((A - B)(D - C) + AC + BD) \times 2^{n / 2} + BD XY=AC×2n+((AB)(DC)+AC+BD)×2n/2+BD
时间复杂性
  • 需做 3 3 3 n / 2 n / 2 n/2位整数的乘法, 6 6 6次加减法和 2 2 2次移位

T ( n ) = { O ( 1 ) n = 1 3 T ( n / 2 ) + O ( n ) n > 1 T(n) = \begin{cases} O(1) & n = 1 \\ 3 T(n / 2) + O(n) & n > 1 \end{cases} T(n)={O(1)3T(n/2)+O(n)n=1n>1

T ( n ) = O ( n log ⁡ 3 ) = O ( n 1.59 ) T(n) = O(n^{\log{3}}) = O(n^{1.59}) T(n)=O(nlog3)=O(n1.59)


Python实现

def karatsuba_multiply(x, y):
    # 如果乘数之一为 0, 则直接返回 0
    if x == 0 or y == 0:
        return 0

    # 将乘数转换为字符串, 并获取它们的位数
    x_str = str(x)
    y_str = str(y)

    n = max(len(x_str), len(y_str))

    # 达到基本情况时, 使用传统的乘法
    if n == 1:
        return x * y

    # 将乘数补齐到相同的位数
    x_str = x_str.zfill(n)
    y_str = y_str.zfill(n)

    # 将乘数划分为两部分
    m = n // 2

    high1, low1 = int(x_str[:m]), int(x_str[m:])
    high2, low2 = int(y_str[:m]), int(y_str[m:])

    # 递归地计算三个乘法
    z0 = karatsuba_multiply(low1, low2)
    z1 = karatsuba_multiply((low1 + high1), (low2 + high2))
    z2 = karatsuba_multiply(high1, high2)

    # 计算结果
    res = (z2 * 10 ** (2 * m)) + ((z1 - z2 - z0) * 10 ** m) + z0

    return res


x = 123456789012345678901234567890
y = 987654321098765432109876543210

res = karatsuba_multiply(x, y)

print(f'{x} * {y} = {res}')

你可能感兴趣的:(算法,分治算法,Python)