带错误检测的CUDA资源管理

        项目中发现很多cuda代码很冗余,主要是有一些cuda内存相关的操作,比如cudaMemcpy之后,要进行错误检测,所以修改了一版,实现接口内部自己检测,这样代码看起来不会那么乱。

        为什么使用宏定义,而不是函数定义,是因为函数定义的话,如果希望在出错误的时候直接输出出错位置的文件和行号,就需要在使用接口的时候,将__FILE__和__LINE__传入进去,有点麻烦,因此使用了宏定义的方式

下面是具体代码:

        

#pragma once
#include 
#include 
#include 

// 向上取整
#define iDIV_UP(a, b) ((a + b - 1) / b)
#define ALG_MAX(x,y) ((x)>(y)?(x):(y))
#define ALG_MIN(x,y) ((x)<(y)?(x):(y))
#define FLOAT_EPS  1e-6
#define FLOAT_EQUAL(v1, v2) ((fabs((v1)-(v2))) < (FLOAT_EPS))

// cblas错误检查
#define CheckCBlasError(ErrorID)        \
{                                       \
    if(CUBLAS_STATUS_SUCCESS != ErrorID)\
    {                                   \
        printf("=====Imaging Error CBlas: %s, line: %d of file: %s\n", cublasGetStatusString(ErrorID), __LINE__, __FILE__);\
        assert(false);                  \
    }                                   \
}

#define CheckCudaError(ErrorId){\
    if (cudaSuccess != ErrorId)\
    { \
        printf("=====Imaging Error Cuda: %s, file: %s : %d\n", cudaGetErrorString(ErrorId), __FILE__, __LINE__);assert(false);\
    }\
}

// Cuda显存释放
#define Cuda_Free(pData){ \
    if (nullptr != pData){\
        cudaError_t error_id = cudaFree(pData);\
        pData = nullptr;\
        CheckCudaError(error_id);\
    }\
}

// Cuda显存设置值
#define Cuda_Memset(devPtr, iValue, iSize){\
    void** ptr = (void**)&devPtr;\
    if (ptr != nullptr){\
        auto error_id = cudaMemset(*ptr, iValue, iSize);\
        CheckCudaError(error_id);\
    }\
}

// 显存申请(这个宏有问题,应该使用后面那个)
// 该宏的问题是,如果外面传进来的pData是个空指针的话,那这个位置访问空指针,肯定是不对的
#define Cuda_Malloc(pData, iSize){\
    Cuda_Free(*pData);\
    cudaError_t error_id = cudaMalloc(pData, iSize);\
    CheckCudaError(error_id);\
}



// 显存申请
// 普通显存申请是cudMalloc(void**)&ptr, size);
// 使用下面宏的话,只需要Cuda_Malloc(ptr, size)即可
#define Cuda_Malloc(pData, iSize){\
    Cuda_Free(pData);\
    auto p_cu_malloc_data = (void**)&pData;\
    cudaError_t error_id = cudaMalloc(p_cu_malloc_data, iSize);\
    CheckCudaError(error_id);\
}


#define Cuda_Memcpy(pDst, pSrc, iSize, cpyKind){\
    auto error_id = cudaMemcpy(pDst, pSrc, iSize, cpyKind);\
    CheckCudaError(error_id);\
}

你可能感兴趣的:(Cuda,开发语言,c++)