在tensorflow的kernels添加一个op
在tensorflow/core/kernels/目录下新建一个名为zeroout_op.cc的文件,内容如下:
#include <cstdint>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {

// CPU kernel for the "ZeroOut" op.
//
// Produces an output tensor with the same shape as input 0 in which the first
// element (in flattened order) is copied from the input and every other
// element is set to 0. Both tensors are accessed as flat int32 buffers.
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  // Computes the zeroed output for a single invocation of the op.
  // If allocating the output fails, OP_REQUIRES_OK records the error on
  // `context` and returns early.
  void Compute(OpKernelContext* context) override {
    // Trace message emitted on every invocation.
    LOG(WARNING) << "ZeroOutOp--running";

    // Grab the input tensor and view it as a flat int32 array.
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create an output tensor with the same shape as the input.
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Use a 64-bit count and index: narrowing size() to `int` would truncate
    // for tensors holding more than INT_MAX elements.
    const int64_t num_elements = input.size();

    // Set all but the first element of the output tensor to 0.
    for (int64_t i = 1; i < num_elements; ++i) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible (i.e. the tensor is
    // non-empty).
    if (num_elements > 0) {
      output_flat(0) = input(0);
    }
  }
};

// Makes ZeroOutOp the CPU implementation of the "ZeroOut" op.
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

}  // namespace tensorflow
打开tensorflow/core/kernels/下的BUILD文件,添加bazel编译规则。首先在名为math的cc_library的deps列表中加入":zeroout_op"依赖:
# Aggregate library for the math kernels; each entry in deps is a kernel
# target in this package.
cc_library(
    name = "math",
    deps = [
        ":aggregate_ops",
        ":argmax_op",
        ":batch_matmul_op",
        ":betainc_op",
        ":bincount_op",
        ":bucketize_op",
        ":cast_op",
        ":check_numerics_op",
        ":compare_and_bitpack_op",
        ":cross_op",
        ":cwise_op",
        ":fft_ops",
        ":histogram_op",
        ":matmul_op",
        ":population_count_op",
        ":reduction_ops",
        ":scan_ops",
        ":segment_reduction_ops",
        ":sequence_ops",
        ":zeroout_op",  # Newly added: the ZeroOut kernel target.
    ],
)
然后在同一BUILD文件的约第3000行处,添加如下编译规则:
# Kernel library target for the ZeroOut op, built from zeroout_op.cc.
# MATH_DEPS is not defined in this snippet; presumably it is declared
# elsewhere in the same BUILD file.
tf_kernel_library(
    name = "zeroout_op",
    srcs = ["zeroout_op.cc"],
    deps = MATH_DEPS,
)
在tensorflow/core/ops/ops.pbtxt里添加op
# OpDef entry for "ZeroOut". The input and output are fixed int32 tensors,
# matching .Input("to_zero: int32") / .Output("zeroed: int32") in the
# REGISTER_OP call; "T" is a standalone type attr.
op {
  name: "ZeroOut"
  input_arg {
    name: "to_zero"
    type: DT_INT32
  }
  output_arg {
    name: "zeroed"
    type: DT_INT32
  }
  attr {
    name: "T"
    type: "type"
    allowed_values {
      list {
        type: DT_BFLOAT16
        type: DT_HALF
        type: DT_FLOAT
        type: DT_DOUBLE
        type: DT_INT32
        type: DT_INT64
      }
    }
  }
}
注册ZeroOut,打开tensorflow/core/ops/math_ops.cc,添加如下代码:
// Declares the interface of the "ZeroOut" op: one int32 input ("to_zero"),
// one int32 output ("zeroed"), and a type attr "T".
// NOTE(review): "T" is not referenced by the input or output spec here, so
// it does not constrain their dtype — confirm that this is intended.
REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .Attr("T: {bfloat16, half, float, double, int32, int64}")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* ctx) {
      // Shape inference: output 0 takes the shape of input 0 unchanged.
      const auto input_shape = ctx->input(0);
      ctx->set_output(0, input_shape);
      return Status::OK();
    });
使用bazel重新编译并安装TensorFlow后,运行示例如下:
import tensorflow as tf

# Example: build a graph containing the custom ZeroOut op and run it once.

# a = tf.constant([1, 2, 3, 4, 5, 6], name='a')

# NOTE(review): '/spu:0' is not a stock TensorFlow device name, and the
# ZeroOut kernel in this tutorial is registered for DEVICE_CPU only.
# Confirm the device exists in your build, or use '/cpu:0'.
with tf.device('/spu:0'):
    # The second argument supplies the op's "T" attr.
    c = tf.zero_out([1, 2, 3, 4, 5], "int32")

# Alternative session config that lets TensorFlow fall back to another device:
# sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

# log_device_placement=True logs the device each op is assigned to.
# The `with` block closes the session when done.
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
    # Optional: collect and print the partitioned graphs for this run.
    # options = tf.RunOptions(output_partition_graphs=True)
    # metadata = tf.RunMetadata()
    # c_val = sess.run(c, options=options, run_metadata=metadata)
    # print(metadata.partition_graphs)

    # Runs the op. print(...) with one argument works on Python 2 and 3.
    print(sess.run(c))