最快计算Mandelbrot的Python代码

Python作为动态语言,现在越来越流行,但是在使用中却未必十全十美,其中运行的性能问题,便是其中之一。当程序中有很多for循环,并且迭代次数很多的情况下,性能问题尤其突出。当然,解决办法也有很多,比如使用Cython便是一个好的解决办法,也可以使用一些第三方的Python库,如,PyOpenCI,PyCuda,Numbia等,但相比之下使用Numbia更为简洁,本文使用Numba来计算并生成Mandelbrot集图片

例子说明

关于什么是mandelbrot,不属于本文讨论范畴,有兴趣参看维基百科,本文是在一个预设值的复合平面内,按照目标图片的宽高生成复数集合,并对其中每个复数进行Mandelbrot计算迭代,在规定的迭代次数内,检查其是否为Mandelbrot集合内的点,返回迭代次数,并根据迭代次数映射到对应的颜色,最后绘制到一张800*800的PNG图片上,并显示图片。以下就是一张生成的图片。
最快计算Mandelbrot的Python代码_第1张图片

1、一般的Python代码实现

import numpy as np
from PIL import Image
import time


def mandelbrot(c, maxiter):
    z = c
    for n in range(maxiter):
        if abs(z) > 2:
            return n
        z = z * z + c
    return 0


def mandelbrot_set(xmin, xmax, ymin, ymax, img, maxiter):
    width, height = img.size[0], img.size[1]
    r1 = np.linspace(xmin, xmax, width)
    r2 = np.linspace(ymin, ymax, height)

    [img.putpixel((idx1, idx2),
                  (mandelbrot(complex(r, i), maxiter) << 21) + (mandelbrot(complex(r, i), maxiter) << 10)
                  + mandelbrot(complex(r, i), maxiter) * 8) for idx1, r in enumerate(r1) for idx2, i in enumerate(r2)]


bitmap = Image.new("RGB", (800, 800), "white")

start = time.time()
mandelbrot_set(-2.0, 0.5, -1.25, 1.25, bitmap, 100)
print("执行时间 {} 秒".format(round(time.time() - start, 2)))

bitmap.show()

运行上述代码:

执行时间 12.37

2、 使用numba进行优化

import time
from src.utils import *
from numba import jit, guvectorize, complex128, int32
import math


def mandelbrot_set(xmin, xmax, ymin, ymax, width, height, maxiter):
    re = np.linspace(xmin, xmax, width, dtype=np.float64)
    im = np.linspace(ymin, ymax, height, dtype=np.float64)
    c = re + im[:, None]*1j

    n3 = mandelbrot_numpy(c, maxiter)

    # To handle row exchange issue.
    rows, row = n3.shape[0], math.floor(n3.shape[0]/2)
    for i in range(row):
        n3[[i, rows - 1 - i], :] = n3[[rows - 1 - i, i], :]

    return n3


@jit(int32(complex128, int32))
def mandelbrot(c, maxiter):
    real = 0
    imag = 0
    for n in range(maxiter):
        nreal = real * real - imag * imag + c.real
        imag = 2 * real * imag + c.imag
        real = nreal
        if real * real + imag * imag > 4.0:
            return n
    return 0


@guvectorize([(complex128[:], int32[:], int32[:])], '(n),()->(n)', target='parallel')
def mandelbrot_numpy(c, maxit, output):
    maxiter = maxit[0]
    for i in range(c.shape[0]):
        output[i] = mandelbrot(c[i], maxiter)


width = 800
height = 800
max_iter = 100

start = time.time()
n = mandelbrot_set(-2.0, 0.5, -1.25, 1.25, width, height, max_iter)
img = get_image(n, create_palette())
print("迭代执行时间 {} 秒".format(round(time.time() - start, 2)))

img.show()
迭代执行时间 0.26

可以看到运行效率提高100倍以上,而且电脑的风扇也不会像分机起飞一样发出难听的噪音。其核心概念时使用了向量化和并行计算的概念,你可以在Numbia的注解较容易的实现标量参数向量化计算,类似Numpy里的通用函数frompyfunc类似。

总结

当你在机器学习或者大规模计算时,涉及较多循环,或者迭代次数很高的时候,多考虑一下是否可以使用第三方Python库来提高计算效率,可用的手段有Cython,Numba,PyCuda,PyOpenCI等,使用这些手段往往会收到事半功倍的效果。

参考资料

https://numba.pydata.org/numba-doc/dev/user/vectorize.html

你可能感兴趣的:(Python基础,机器学习系列)