TensorFlow源码阅读:添加一个Op

在tensorflow的kernels添加一个op
在tensorflow/core/kernels/目录下新建一个zeroout_op.cc的文件

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

class ZeroOutOp : public OpKernel {
 public:
    explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

    void Compute(OpKernelContext* context) override {
     LOG(WARNING)<< "ZeroOutOp--running";
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.
    if (N > 0) output_flat(0) = input(0);
  }
};
// Register ZeroOutOp as the kernel for the "ZeroOut" op on DEVICE_CPU.
// No type constraint is attached to this registration.
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
}

打开tensorflow/core/kernels/下的BUILD文件,添加bazel编译规则:

# Aggregate target that pulls in the math kernel libraries.
cc_library(
    name = "math",
    deps = [
        ":aggregate_ops",
        ":argmax_op",
        ":batch_matmul_op",
        ":betainc_op",
        ":bincount_op",
        ":bucketize_op",
        ":cast_op",
        ":check_numerics_op",
        ":compare_and_bitpack_op",
        ":cross_op",
        ":cwise_op",
        ":fft_ops",
        ":histogram_op",
        ":matmul_op",
        ":zeroout_op",  # Newly added: the ZeroOut kernel library.
        ":population_count_op",
        ":reduction_ops",
        ":scan_ops",
        ":segment_reduction_ops",
        ":sequence_ops",
    ],
)

在BUILD文件约第3000行处,添加zeroout_op的编译规则:
# Kernel library for the ZeroOut op, built from zeroout_op.cc.
tf_kernel_library(
    name = "zeroout_op",
    srcs = ["zeroout_op.cc"],
    deps = MATH_DEPS,
)

在tensorflow/core/ops/ops.pbtxt里添加op

# Definition of the "ZeroOut" op.
# The input and output are declared as int32 so that this entry matches
# REGISTER_OP("ZeroOut") (".Input("to_zero: int32")" / ".Output("zeroed: int32")").
op {
  name: "ZeroOut"
  input_arg {
    name: "to_zero"
    type: DT_INT32
  }
  output_arg {
    name: "zeroed"
    type: DT_INT32
  }
  # "T" is declared by the registration but is not referenced by the
  # input or output argument.
  attr {
    name: "T"
    type: "type"
    allowed_values {
      list {
        type: DT_BFLOAT16
        type: DT_HALF
        type: DT_FLOAT
        type: DT_DOUBLE
        type: DT_INT32
        type: DT_INT64
      }
    }
  }
}

注册ZeroOut,打开tensorflow/core/ops/math_ops.cc,添加如下代码:

// Registers the "ZeroOut" op: one int32 input ("to_zero") and one int32
// output ("zeroed"). The "T" attr is declared here but is not referenced by
// the input or output signature.
REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .Attr("T: {bfloat16, half, float, double, int32, int64}")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* ctx) {
      // Shape inference: output 0 takes the shape of input 0.
      const auto input_shape = ctx->input(0);
      ctx->set_output(0, input_shape);
      return Status::OK();
    });

用bazel重新编译并安装后,使用示例如下:

import tensorflow as tf

# Example: build a graph containing the custom ZeroOut op and run it once.

#a = tf.constant([1, 2, 3, 4, 5, 6], name='a')
# The ZeroOut kernel is registered for DEVICE_CPU only, so pin the op to the CPU.
with tf.device('/cpu:0'):
    c = tf.zero_out([1, 2, 3, 4, 5], "int32")

# log_device_placement=True makes the session log the device of every op.
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
#sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

#options = tf.RunOptions(output_partition_graphs=True)
#metadata = tf.RunMetadata()
#c_val = sess.run(c,options=options,run_metadata=metadata)
#print(metadata.partition_graphs)


# Runs the op. The function-call form of print works on Python 2 and 3.
print(sess.run(c))

你可能感兴趣的:(c++)