mpi4py 中的 Op 对象

在上一篇中我们介绍了 mpi4py 中的数据类型解析,下面我们将介绍 mpi4py 中的 Op 对象。

MPI 内置定义了很多算符,可以用在 Reduce,Allreduce,Reduce_scatter,Scan,Exscan 等操作上,但是这些内置定义的算符通常只能使用 MPI 预定义数据类型。不过 MPI 支持自定义操作算符,可以用自定义的算符完成更通用的操作,或者操作自定义的数据类型。

mpi4py 中的 MPI.Op 类提供了对操作算符的基本抽象,MPI 中的预定义算符,如 MPI.MAX, MPI.SUM, MPI.REPLACE,MPI.NO_OP,MPI.OP_NULL 等都是 MPI.Op 类的对象,MPI.Op 类提供了相应的方法来创建新的操作算符。MPI.OP 类的相关方法接口如下:

方法

MPI.Op.Create(type cls, function, bool commute=False)

由函数 function 创建一个自定义算符,commute 指明该操作是否满足交换律。function 的原型接口为 func(a, b, dt),其中 a,b 是操作的数据,dt 是数据类型。在 mpi4py 中,如果使用的是以大写字母开头的方法,如 Allreduce,则传递进来的 a 和 b 是 MPI.memory 对象,dt 是一个 MPI.Datatype 对象;如果使用的是以小写字母开头的方法,如 allreduce,则传递进来的 a 和 b 都是 Python 对象,而 dt 为 None 。

MPI.Op.Free(self)

释放该算符对象。注意:只能释放自定义算符,否则会抛出 MPI.Exception。

MPI.Op.Is_commutative(self)

该算符是否满足交换律,是返回 True,否则返回 False。也可以通过属性 is_commutative 获取。

MPI.Op.Reduce_local(self, inbuf, inoutbuf)

inbufinoutbuf 中的数据执行本地规约操作,结果存放在 inoutbuf 中。不支持 MPI.IN_PLACE。

MPI.Op.__bool__(self)

返回 True 如果该算符不是 MPI.OP_NULL,否则返回 False。

MPI.Op.__call__(self, x, y)

以参数 xy 调用该算符。

属性

MPI.Op.is_commutative

该算符是否满足交换律,是返回 True,否则返回 False。同调用 MPI.Op.Is_commutative 方法的结果。

MPI.Op.is_predefined

该算符是否为预定义的算符。

例程

下面给出使用例程。

# op.py

"""
Demonstrates the usage of MPI.Op.

Run this with 2 processes like:
$ mpiexec -n 2 python op.py
"""

import numpy as np
from mpi4py import MPI


comm = MPI.COMM_WORLD
rank = comm.rank

def mysum_obj(a, b):
    # sum of pythn objects
    return a + b

def mysum_buf(a, b, dt):
    assert dt == MPI.INT
    assert len(a) == len(b)

    def to_nyarray(a):
        # convert a MPI.memory object to a numpy array
        size = len(a)
        buf = np.array(a, dtype='B', copy=False)
        return np.ndarray(buffer=buf, dtype='i', shape=(size / 4,))

    to_nyarray(b)[:] = mysum_obj(to_nyarray(a), to_nyarray(b))

def mysum(ba, bb, dt):
    if dt is None:
        # ba, bb are python objects
        return mysum_obj(ba, bb)
    else:
        # ba, bb are MPI.memory objects
        return mysum_buf(ba, bb, dt)

commute = True
# create a user-defined operator by using function mysum
myop = MPI.Op.Create(mysum, commute)
print 'myop.is_commutative: %s' % myop.is_commutative
print 'myop.is_predefined: %s' % myop.is_predefined

# call the op on different objects
print 'myop(1, 2) = %s' % myop(1, 2)
print 'myop([1], [2]) = %s' % myop([1], [2])
print 'myop(np.array([1]), np.array([2])) = %s' % myop(np.array([1]), np.array([2]))

a = np.arange(3, dtype='i')
b = np.zeros(3, dtype='i')
# use the user-defined myop in allreduce
comm.Allreduce(a, b, op=myop)
# or
# comm.Allreduce([a, MPI.INT], [b, MPI.INT], op=myop)
print 'Allreduce: b = %s' % b

print 'allreduce 2: %s' % comm.allreduce(2, op=myop)
print 'allreduce [2]: %s' % comm.allreduce([2], op=myop)

inbuf = np.arange(4*rank, 4*(rank+1), dtype ='i')
inoutbuf = np.array([10, 10, 10, 10], dtype='i')
myop.Reduce_local(inbuf, inoutbuf)
print 'Reduce_local with myop: %s' % inoutbuf

# free the user-defined op
myop.Free()

# use the predefined op
print 'isinstance(MPI.MAX, MPI.Op): %s' % isinstance(MPI.MAX, MPI.Op)
print 'MPI.MAX.is_predefined: %s' % MPI.MAX.is_predefined

inbuf = np.array([1, 2], dtype='i')
inoutbuf = np.array([2, 0], dtype='i')
MPI.MAX.Reduce_local(inbuf, inoutbuf)
print 'Reduce_local with MPI.MAX: %s' % inoutbuf

try:
    # try to free the predefined op
    MPI.MAX.Free()
except MPI.Exception as e:
    print e.error_string

运行结果如下:

$ mpiexec -n 2 python op.py
myop.is_commutative: True
myop.is_predefined: False
myop(1, 2) = 3
myop([1], [2]) = [1, 2]
myop(np.array([1]), np.array([2])) = [3]
Allreduce: b = [0 2 4]
allreduce 2: 4
allreduce [2]: [2, 2]
Reduce_local with myop: [10 11 12 13]
isinstance(MPI.MAX, MPI.Op): True
MPI.MAX.is_predefined: True
Reduce_local with MPI.MAX: [2 2]
myop.is_commutative: True
myop.is_predefined: False
myop(1, 2) = 3
myop([1], [2]) = [1, 2]
myop(np.array([1]), np.array([2])) = [3]
Allreduce: b = [0 2 4]
allreduce 2: 4
allreduce [2]: [2, 2]
Reduce_local with myop: [14 15 16 17]
isinstance(MPI.MAX, MPI.Op): True
MPI.MAX.is_predefined: True
Reduce_local with MPI.MAX: [2 2]
MPI_ERR_OP: invalid reduce operation
MPI_ERR_OP: invalid reduce operation

以上介绍了 mpi4py 中的 Op 对象,在下一篇中我们将介绍 mpi4py 中的客户端-服务器编程。

你可能感兴趣的:(mpi4py 中的 Op 对象)