我在用卷积定理做卷积时遇到的一个问题就两矩阵做完FFT之后都是为复数怎么用pycuda做矩阵相乘,在给GPU传递参数的时候总是有问题。
通过几天的摸索给出下面代码可完成矩阵的相乘,数据类型为complex。其中值得注意的地方就是对于调用了from jinja2 import Template
这个编译器,不然用 sourceModule总是会有错误,直接在用#include “complex.h”中的complex 去定义 complex * a 是不可以。
由于C++ 本人没学过,不好说,pycuda-complex.hpp这个应该是C++的头文件。 Jinja2 模板中执行任意代码 ,可以看看这个代码利用 Python 特性在 Jinja2 模板中执行任意代码有兴趣可以看看。
#-*- coding: utf-8 -*-
import pycuda.autoinit
import pycuda.driver as drv
from pycuda.compiler import SourceModule
from jinja2 import Template
import numpy as np
KERNEL = Template("""
#include <stdio.h>
#include <pycuda-complex.hpp>
typedef pycuda::complex<float> scmplx;
typedef pycuda::complex<double> dcmplx;
__global__ void complex_mat_mul(const {{complex_type}} *a, const {{complex_type}} *b, {{complex_type}} *res)
{
int row = threadIdx.y;
int col = threadIdx.x;
int mat_id = blockIdx.x * gridDim.x + blockIdx.y;
{{complex_type}} entry = 0;
for (int e = 0; e < {{mat_dim}}; ++e) {
entry += a[mat_id*{{mat_dim}}*{{mat_dim}} + row * {{mat_dim}} + e] * b[mat_id*{{mat_dim}}*{{mat_dim}} + e * {{mat_dim}} + col];
}
res[mat_id*{{mat_dim}}*{{mat_dim}} + row * {{mat_dim}} + col] = entry;
}
""")
data_types = {
'scmplx': np.complex64,
'dcmplx': np.complex128,
'float': np.float32,
'double': np.float64
}
def render_kernel(complex_type, real_type, mat_dim, block, gird):
templ = KERNEL.render(
complex_type=complex_type,
real_type=real_type,
mat_dim=mat_dim,
blockDim_x=block[0],
blockDim_y=block[1]
)
# print(templ)
return templ
complex_type = 'dcmplx'
real_type = 'double'
mat_dim = 4
block = (mat_dim,mat_dim,1)
grid = (1,1)
program = SourceModule(render_kernel(complex_type, real_type, mat_dim, block, grid))
complex_mat_mul = program.get_function("complex_mat_mul")
mats_1 = np.array((
[[1,1,1,0],
[0,1,1,1],
[0,0,1,1],
[0,0,1,1]
]), dtype=np.complex128)
mats_2 = np.array((
[[1,1,1,0],
[0,1,1,1],
[0,0,1,1],
[0,0,1,1]
]), dtype=np.complex128)
result = mats_1.copy()
result[:] = np.nan
a = drv.In(mats_1)
b = drv.In(mats_2)
c = drv.Out(result)
start = time.time()
complex_mat_mul(a, b, c,
block=block,
grid=grid
)
print(result.real)
给出代码的执行结果: