用cupy实现python + cuda编程

1 . 编写核函数进行数据处理

img_cupy.py 文件(核函数源码,供下文 `from img_cupy import kernel_img` 导入)

# CUDA C source of the `medianfilter` kernel, kept as a Python string so it can
# be compiled at runtime. `extern "C"` keeps the symbol name unmangled so it
# can be looked up by the plain name 'medianfilter'.
kernel_img = '''
extern "C"
__global__ void medianfilter(   const float *input,
                                float *output,
                                const int width,
                                const int height)
{
  // Global pixel coordinates handled by this thread.
  const int idx = blockIdx.x*blockDim.x + threadIdx.x;
  const int idy = blockIdx.y*blockDim.y + threadIdx.y;
  // Valid coordinates are 0..width-1 and 0..height-1, so threads at
  // idx == width or idy == height must exit as well.
  if (idx >= width || idy >= height)
    return ;
  // NOTE(review): the actual filtering body is not present in this listing;
  // as written the kernel leaves `output` untouched.
}
'''

2. 编译核函数,并加载为可调用的函数

cupy_test.py 文件

import numpy as np
from cupy.cuda import  function
from img_cupy import kernel_img
from pynvrtc.compiler import Program
import torch
from collections import namedtuple
from torch.autograd import Variable
import cv2

将核函数源码从第一行开始重定向到 .cu 文件中编译

# Compile the CUDA source string to PTX. The first argument is the kernel
# source; "img.cu" is the name given to the program, not a file that is read.
program = Program(kernel_img, "img.cu")
ptx = program.compile()
# On Python 3 bytes(str) raises TypeError, so encode explicitly when the
# compiler hands back text; a bytes-like result is passed through as before.
ptx_bytes = ptx.encode() if isinstance(ptx, str) else bytes(ptx)
m = function.Module()
m.load(ptx_bytes)
# Look the kernel up by the name it is declared with in the CUDA source.
img_process = m.get_function('medianfilter')

3.加载测试图像数据

img_name = "/home/wj/project/hdrnet/sample_data/input.png"
img_data = cv2.imread(img_name)
# cv2.imread signals an unreadable/missing file by returning None rather than
# raising, so fail here with a clear message instead of on .transpose below.
if img_data is None:
    raise FileNotFoundError("cannot read image: " + img_name)
# cv2 loads as H,W,C; transpose to C,H,W
img0 = img_data.transpose(2, 0, 1)
#cv2.imshow("img_init",img0[0])
#cv2.waitKey(0)
# only channel 0 is processed
img0_float = img0[0].astype('float32')
# Two separate GPU copies of the same channel: one is read by the kernel,
# the other is the buffer it writes into.
input = Variable(torch.from_numpy(img0_float).cuda(), requires_grad=True)
output = Variable(torch.from_numpy(img0_float).cuda(), requires_grad=True)

4. 校验数据是否为 CUDA 张量

# `use_cuda` is not defined anywhere in the visible script; derive it from the
# runtime so the check below can execute.
use_cuda = torch.cuda.is_available()
if use_cuda:
    # The kernel is handed raw device pointers, so the tensor must live on the GPU.
    # The message is wrapped in parentheses so the assert statement is valid syntax.
    assert input.is_cuda, (
        'GPU ReLU with fast element-wise CUDA kernel requested but tensors not on GPU'
    )

5.定义cuda 运行的 grid 和 block

# The image array is laid out rows first, so shape[0] is the height and
# shape[1] is the width.
IMG_HEIGHT, IMG_WIDTH = img_data.shape[:2]
# Threads per block along x and y.
block = (32, 32)
# Ceiling division per axis so the blocks cover the whole image even when a
# dimension is not a multiple of the block size.
grid = tuple(
    -(-extent // threads)
    for extent, threads in zip((IMG_WIDTH, IMG_HEIGHT), block)
)

6.运行kernel 函数

# Wrap PyTorch's current CUDA stream handle in an object exposing `.ptr`, so
# the kernel is queued on the same stream PyTorch is using.
Stream = namedtuple('Stream', ['ptr'])
stream = Stream(ptr=torch.cuda.current_stream().cuda_stream)
# The kernel declares `const int width, height` (32-bit). Plain Python ints
# are marshalled as 64-bit values, so pass explicit int32 scalars to match
# the kernel signature exactly.
img_process(grid=grid,
            block=block,
            args=[input.data_ptr(), output.data_ptr(),
                  np.int32(IMG_WIDTH), np.int32(IMG_HEIGHT)],
            stream=stream)

7.处理后的cuda数据拷贝出来

# Bring the result back to host memory and drop it from the autograd graph
# before converting to a NumPy array.
host_tensor = output.cpu().detach()
host_array = host_tensor.numpy()
# Convert to 8-bit for display.
out = host_array.astype(np.uint8)
cv2.imshow("img_output", out)
# Block until a key is pressed so the window stays visible.
cv2.waitKey(0)

你可能感兴趣的:(编程,cupy,python)