Python作为动态语言,现在越来越流行,但是在使用中却未必十全十美,其中运行的性能问题,便是其中之一。当程序中有很多for循环,并且迭代次数很多的情况下,性能问题尤其突出。当然,解决办法也有很多,比如使用Cython便是一个好的解决办法,也可以使用一些第三方的Python库,如,PyOpenCI,PyCuda,Numbia等,但相比之下使用Numbia更为简洁,本文使用Numba来计算并生成Mandelbrot集图片
关于什么是mandelbrot,不属于本文讨论范畴,有兴趣参看维基百科,本文是在一个预设值的复合平面内,按照目标图片的宽高生成复数集合,并对其中每个复数进行Mandelbrot计算迭代,在规定的迭代次数内,检查其是否为Mandelbrot集合内的点,返回迭代次数,并根据迭代次数映射到对应的颜色,最后绘制到一张800*800的PNG图片上,并显示图片。以下就是一张生成的图片。
import numpy as np
from PIL import Image
import time
def mandelbrot(c, maxiter):
z = c
for n in range(maxiter):
if abs(z) > 2:
return n
z = z * z + c
return 0
def mandelbrot_set(xmin, xmax, ymin, ymax, img, maxiter):
width, height = img.size[0], img.size[1]
r1 = np.linspace(xmin, xmax, width)
r2 = np.linspace(ymin, ymax, height)
[img.putpixel((idx1, idx2),
(mandelbrot(complex(r, i), maxiter) << 21) + (mandelbrot(complex(r, i), maxiter) << 10)
+ mandelbrot(complex(r, i), maxiter) * 8) for idx1, r in enumerate(r1) for idx2, i in enumerate(r2)]
bitmap = Image.new("RGB", (800, 800), "white")
start = time.time()
mandelbrot_set(-2.0, 0.5, -1.25, 1.25, bitmap, 100)
print("执行时间 {} 秒".format(round(time.time() - start, 2)))
bitmap.show()
运行上述代码:
执行时间 12.37 秒
import time
from src.utils import *
from numba import jit, guvectorize, complex128, int32
import math
def mandelbrot_set(xmin, xmax, ymin, ymax, width, height, maxiter):
re = np.linspace(xmin, xmax, width, dtype=np.float64)
im = np.linspace(ymin, ymax, height, dtype=np.float64)
c = re + im[:, None]*1j
n3 = mandelbrot_numpy(c, maxiter)
# To handle row exchange issue.
rows, row = n3.shape[0], math.floor(n3.shape[0]/2)
for i in range(row):
n3[[i, rows - 1 - i], :] = n3[[rows - 1 - i, i], :]
return n3
@jit(int32(complex128, int32))
def mandelbrot(c, maxiter):
real = 0
imag = 0
for n in range(maxiter):
nreal = real * real - imag * imag + c.real
imag = 2 * real * imag + c.imag
real = nreal
if real * real + imag * imag > 4.0:
return n
return 0
@guvectorize([(complex128[:], int32[:], int32[:])], '(n),()->(n)', target='parallel')
def mandelbrot_numpy(c, maxit, output):
maxiter = maxit[0]
for i in range(c.shape[0]):
output[i] = mandelbrot(c[i], maxiter)
width = 800
height = 800
max_iter = 100
start = time.time()
n = mandelbrot_set(-2.0, 0.5, -1.25, 1.25, width, height, max_iter)
img = get_image(n, create_palette())
print("迭代执行时间 {} 秒".format(round(time.time() - start, 2)))
img.show()
迭代执行时间 0.26 秒
可以看到运行效率提高100倍以上,而且电脑的风扇也不会像分机起飞一样发出难听的噪音。其核心概念时使用了向量化和并行计算的概念,你可以在Numbia的注解较容易的实现标量参数向量化计算,类似Numpy里的通用函数frompyfunc类似。
当你在机器学习或者大规模计算时,涉及较多循环,或者迭代次数很高的时候,多考虑一下是否可以使用第三方Python库来提高计算效率,可用的手段有Cython,Numba,PyCuda,PyOpenCI等,使用这些手段往往会收到事半功倍的效果。
https://numba.pydata.org/numba-doc/dev/user/vectorize.html