cupy系列(一)——自定义elementwise核函数

文章目录

    • 定义elementwise核函数
    • Type-generic 核函数
    • 手动索引

定义elementwise核函数

Type-generic 核函数

>>> squared_diff_generic = cp.ElementwiseKernel(
...     'T x, T y',
...     'T z',
...     'z = (x - y) * (x - y)',
...     'squared_diff_generic')

解释

  • 采用T作为类型占位符,T具体指代的类型依次由:输出,输入。及首先看输出的类型,输出没有制定,则用输入的数据类型。

例子一

import cupy as cp
import numpy as np

squared_diff_generic = cp.ElementwiseKernel(
'T x, T y',
'T z',
'''
T diff = x - y;
z = diff * diff;
''',
'squared_diff_generic')

x = cp.arange(10, dtype=np.float32).reshape(2,5)
y = cp.arange(10, dtype=np.float32).reshape(2,5)
squared_diff_generic(x, y, z)
>>>array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)

这里的z没有指定类型,所以用的是输入的float32类型

例子二 多个类型占位符

import cupy as cp
import numpy as np

squared_diff_super_generic = cp.ElementwiseKernel(
'X x, Y y',
'Z z',
'z = (x - y) * (x - y)',
'squared_diff_super_generic')

x = cp.arange(10, dtype=np.float32).reshape(2,5)
y = cp.arange(10, dtype=np.float32).reshape(2,5)
squared_diff_super_generic(x, y, z)

>>> NameError: name 'z' is not defined

修改

import cupy as cp
import numpy as np

squared_diff_super_generic = cp.ElementwiseKernel(
'X x, Y y',
'Z z',
'z = (x - y) * (x - y)',
'squared_diff_super_generic')

x = cp.arange(10, dtype=np.float32).reshape(2,5)
y = cp.arange(10, dtype=np.float32).reshape(2,5)
z = cp.empty(10, dtype=np.int16).reshape(2,5)
squared_diff_super_generic(x, y, z)

>>>array([[0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0]], dtype=int16)

手动索引

ElementwiseKernel类使用广播自动执行索引,这对于定义大多数elementwise计算非常有用。另一方面,我们有时想为一些参数编写一个手动索引的内核。我们可以告诉ElementwiseKernel类使用手动索引,方法是在类型说明符之前添加raw关键字。
我们可以使用特殊的变量i和方法_ind.size()进行手动索引。i表示循环中的索引。size()表示应用elementwise操作的元素总数。注意,它表示广播操作之后的大小。

例子一 向量相加,其中一个反序

import cupy as cp
import numpy as np

add_reverse = cp.ElementwiseKernel(
'T x, raw T y', 'T z',
'z = x + y[_ind.size() - i - 1]',
'add_reverse')

x = cp.arange(10, dtype=np.float32)
y = cp.arange(10, dtype=np.float32)
z = cp.empty(10, dtype=np.float32)
add_reverse(x,y,z)

>>> array([9., 9., 9., 9., 9., 9., 9., 9., 9., 9.], dtype=float32)

关键

  • i:循环中的索引
  • _ind.size():元素的总数

你可能感兴趣的:(深度学习,cupy)