python结构体_numba从入门到精通(4)—numba和numpy的结构体

numba对于numpy的支持是最完善的,对于python中的list、dict、tuple等数据类型要不就是不支持优化要不就是支持优化但是使用存在一定的局限性,所以比较建议尽量把输入用numpy的方式表示起来。

有时候为了方便,我们希望numba所修饰的函数能够接受结构体类型的参数该怎么办,因为python中没有显式的struct函数,只有class函数可以用来充当struct来用。

import numpy
class Point():
    """    
    Arguments:
        domain: the domain of random generated coordinates x,y,z, 
                default=1.0
    
    Attributes:
        x, y, z: coordinates of the point
    """
    def __init__(self, domain=1.0):
        self.x = domain * numpy.random.random()
        self.y = domain * numpy.random.random()
        self.z = domain * numpy.random.random()
            
    def distance(self, other):
        return ((self.x - other.x)**2 + 
                (self.y - other.y)**2 + 
                (self.z - other.z)**2)**.5

class Particle(Point):
    """    
    Attributes:
        m: mass of the particle
        phi: the potential of the particle
    """
    
    def __init__(self, domain=1.0, m=1.0):
        Point.__init__(self, domain) ##父类的初始化否则的话一般是 Particle.__inti__()
        self.m = m
        self.phi = 0.

这里我们就定义了一个叫Particle的对象,直接调用他的属性就会产生类似结构体的功能了。接下来我们产生1000个这样的结构体表示1000个数据集,然后放入普通函数中计算

n = 1000
particles = [Particle(m = 1 / n) for i in range(n)]
def direct_sum(particles):
    """
    Calculate the potential at each particle
    using direct summation method.

    Arguments:
        particles: the list of particles

    """
    for i, target in enumerate(particles):
        for source in (particles[:i] + particles[i+1:]):
            r = target.distance(source)
            target.phi += source.m / r

orig_time = %timeit -o direct_sum(particles)

通过上面的计时函数得到的结果:

a5cb177d0ba48affdf07f3b77c9cd36e.png
@jit(nopython=True)
def direct_sum(particles):
    """
    Calculate the potential at each particle
    using direct summation method.

    Arguments:
        particles: the list of particles

    """
    for i, target in enumerate(particles):
        for source in (particles[:i] + particles[i+1:]):
            r = target.distance(source)
            target.phi += source.m / r

orig_time = %timeit -o direct_sum(particles)

python结构体_numba从入门到精通(4)—numba和numpy的结构体_第1张图片

报错,numba无法识别python中的类class

那么如何解决这个问题?

其实,numpy有一个很有意思的功能可以用来实现类似结构体的功能,而且调用的效率要比通过class来定义的结构体高太多。

particle_dtype = numpy.dtype({'names':['x','y','z','m','phi'], 
                             'formats':[numpy.double, 
                                        numpy.double, 
                                        numpy.double, 
                                        numpy.double, 
                                        numpy.double]})

上面定义了一个numpy的数据类型“particle_dtype”,类似于c中的struct,首先是组员的名字分为:'x','y','z','m','phi',然后通过formats定义每一个组员的数据类型。

def create_n_random_particles(n, m, domain=1):
    '''
    Creates `n` particles with mass `m` with random coordinates
    between 0 and `domain`
    '''
    parts = numpy.zeros((n), dtype=particle_dtype)
    
    parts['x'] = numpy.random.random(size=n) * domain
    parts['y'] = numpy.random.random(size=n) * domain
    parts['z'] = numpy.random.random(size=n) * domain
    parts['m'] = m
    parts['phi'] = 0.0

    return parts
parts = create_n_random_particles(1000, .001, 1)

接下来我们把上述的结构体定义实例化,

parts = create_n_random_particles(1000, .001, 1)
print(parts[:5])

python结构体_numba从入门到精通(4)—numba和numpy的结构体_第2张图片

可以看到一个很有意思的结果,这就是通过numpy定义的“结构体”的实例。现在我们才可以不出错的用numba来编译刚刚的代码

@njit
def direct_sum(particles):
    for i, target in enumerate(particles):
        for j in range(particles.shape[0]):
            if i == j:
                continue
            source = particles[j]
            r = distance(target, source)
            target['phi'] += source['m'] / r
%timeit direct_sum(parts)

7e5f1fc24a9deb80b595a2a9f125f266.png

比较一下二者的运行速度:

orig_time.best / numba_time.best

e288e8eca81787e9cc75faa355598dd7.png

速度提高了300多倍。

正常,本身python中的class用来定义结构体就是又臭又长,一大堆的属性,一大堆的检查,和numpy这种轻量型的结构体没法比。

你可能感兴趣的:(python结构体)