CK草稿本

调用流程

    1. 获得op_ptr,ck有个工厂模式:
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    
    1. 设置参数,这些参数包括输入输出,以及其他必要的配置
    auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
                            b_device_buf.GetDeviceBuffer(),
                            c_device_buf.GetDeviceBuffer(),
                            M,
                            N,
                            K,
                            StrideA,
                            StrideB,
                            StrideC,
                            a_element_op,
                            b_element_op,
                            c_element_op);
    
    1. 获得invoker_ptr:auto invoker_ptr = op_ptr->MakeInvokerPointer();
    1. run:float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
    1. 结果后处理

Invoker

  • 有一个基类BaseInvoker,定义了赋值拷贝,和Run函数(用于算子运行),以及一个虚析构
    • 地址:include/ck/tensor_operation/gpu/device/device_base.hpp
  • 然后每个算子里面会实现一个Invoker,来实现run的操作
    struct BaseInvoker
    {
        BaseInvoker()                   = default;
        BaseInvoker(const BaseInvoker&) = default;
        BaseInvoker& operator=(const BaseInvoker&) = default;
    
        virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
        {
            return float{0};
        }
    
        virtual ~BaseInvoker() {}
    };
    
    
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            // run kernel ....
            // cost time ....
        };
    
        float Run(const BaseArgument* p_arg,
                    const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast(p_arg), stream_config);
        };
    };
    

Argument

  • 同样有个基类BaseArgument,有一个p_workspace_的void指针参数,暂不清楚做啥的
    • 地址:include/ck/tensor_operation/gpu/device/device_base.hpp
  • 而每个Operator中都会定义一个Argument子类,里面存一些输入输出,配置等参数
    struct BaseArgument
    {
        BaseArgument()                    = default;
        BaseArgument(const BaseArgument&) = default;
        BaseArgument& operator=(const BaseArgument&) = default;
    
        virtual ~BaseArgument() {}
    
        void* p_workspace_ = nullptr;
    };
    
    struct Argument : public ck::tensor_operation::device::BaseArgument
    {
        Argument(const Tensor& a_gs_ms_ks,
                    const Tensor& b_gs_ns_ks,
                    Tensor& e_gs_ms_ns,
                    AElementwiseOperation a_element_op,
                    BElementwiseOperation b_element_op,
                    CDEElementwiseOperation cde_element_op)
            : a_gs_ms_ks_{a_gs_ms_ks},
                b_gs_ns_ks_{b_gs_ns_ks},
                e_gs_ms_ns_{e_gs_ms_ns},
                a_element_op_{a_element_op},
                b_element_op_{b_element_op},
                cde_element_op_{cde_element_op}
        {
        }
    
        const Tensor& a_gs_ms_ks_;
        const Tensor& b_gs_ns_ks_;
        Tensor& e_gs_ms_ns_;
    
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
    };
    

Operator

  • 基类叫BaseOperator,定义如下函数 都是一些比较通用的基础属性:
    • IsSupportedArgument
    • GetTypeString
    • GetTypeIdName
    • GetTypeIdHashCode
    • GetWorkSpaceSize
    • SetWorkSpacePointer
  • 通常子类中需要有定义:
    • struct Argument/MakeArgumentPointer
    • struct Invoke/MakeInvokerPointer
    struct BaseOperator
    {
        BaseOperator()                    = default;
        BaseOperator(const BaseOperator&) = default;
        BaseOperator& operator=(const BaseOperator&) = default;
    
        virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
        virtual std::string GetTypeString() const { return ""; }
    
        virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
    
        virtual std::string GetTypeIdHashCode() const
        {
            std::ostringstream oss;
    
            oss << std::hex << typeid(*this).hash_code();
    
            return oss.str();
        };
    
        virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
    
        virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
        {
            assert(p_arg);
            p_arg->p_workspace_ = p_workspace;
        }
    
        virtual ~BaseOperator() {}
    };
    

DeviceOperationInstanceFactory

  • library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
    • 在这个文件中声明了工厂,也就是:
        template 
        struct DeviceOperationInstanceFactory;
    
  • library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp
    • 这里面有个add_device_operation_instances方法,定义了将op实现加入到vector(instance)中
  • 在这之上,有一些函数是用于添加这些instance的,比如device_gemm_dl_f16_f16_f16_km_kn_mn_instances
    • 位于library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp
    • 原理就是把tuple中的元素在add_device_operation_instances中全部加到vector中去
    using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple<
            // MPerBlock=8, NPerBlock=8
            DeviceGemmDl<.....>,
            DeviceGemmDl<.....>,
            DeviceGemmDl<.....>,
            .....
        >;
    
    void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
        std::vector>>&
            instances)
    {
        add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_kn_mn_instances{});
    }
    
  • 然后这个函数会在DeviceOperationInstanceFactory中的GetInstances中被调用到,于是就得到了一个vector数组,里面装满了invoke_ptr实现
    • 对于上面这个例子,在这个文件中被调用到:library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp

案例

  • client_example/01_gemm/gemm.cpp
  • 在这个example中有这样一句代码:
    • 很显然,这是通过工厂类拿到算子实例集合
    using DeviceOp =
        ck::tensor_operation::device::DeviceGemm;
    
    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        DeviceOp>::GetInstances();
    
  • DeviceGemm这个operator长这样,当然这也是个虚基类,真正的实现实在Impl文件夹中定义的:
    template 
    struct DeviceGemm : public BaseOperator
    {
        virtual std::unique_ptr
        MakeArgumentPointer(const void* p_a,
                            const void* p_b,
                            void* p_c,
                            ck::index_t M,
                            ck::index_t N,
                            ck::index_t K,
                            ck::index_t StrideA,
                            ck::index_t StrideB,
                            ck::index_t StrideC,
                            AElementwiseOperation a_element_op,
                            BElementwiseOperation b_element_op,
                            CElementwiseOperation c_element_op) = 0;
    
        virtual std::unique_ptr MakeInvokerPointer() = 0;
    };
    
  • 然后会在下一级子类中真正实现:
    struct DeviceGemm_Xdl_CShuffle : public DeviceGemm
    ........
    
  • 然后通过工厂类的GetInstances拿到op_ptrs,接下来就是遍历,在for的过程中需要经过:
    • auto argument_ptr = op_ptr->MakeArgumentPointer
    • auto invoker_ptr = op_ptr->MakeInvokerPointer
    • invoker_ptr->Run
  • 这就是这个example干的事儿,实际上在调用的过程中factory应该可以不用,而直接使用实例化的op_ptr

特有名词

  • 在阅读demo(如gemm.cc)的时候会发现一些特有的名词,如:

    • using F16 = ck::half_t;
    • using Row = ck::tensor_layout::gemm::RowMajor;
    • using Col = ck::tensor_layout::gemm::ColumnMajor;
    • using PassThrough = ck::tensor_operation::element_wise::PassThrough;
  • 有一些比较好理解,如:半精度之类

  • 有一些可以勉强看出来,如layerout是列优先还是行优先(RowMajor/ColumnMajor)

  • 有一些比较抽象,如PassThrough

以PassThrough为例

  • 这是一个传值操作,代码实现位于:include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
  • 下面展示了一部分可以看到,函数的作用是传值
struct PassThrough
{
    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const;

    template <>
    __host__ __device__ void operator()<double, double>(double& y, const double& x) const
    {
        y = x;
    }

    template <>
    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
    {
        y = x;
    }
    ....
};

你可能感兴趣的:(c++)