TensorFlow实现自定义Op方式

今天小编就为大家分享一篇TensorFlow实现自定义Op方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

『写在前面』

以CTC Beam search decoder为例,简单整理一下TensorFlow实现自定义Op的操作流程。

基本的流程

1. 定义Op接口

 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"

 // Declares the interface of the "Custom" op: one int32 input and one int32
 // output. The shape function states that the output has the same shape as
 // the input, which is what the CPU kernel below allocates.
 REGISTER_OP("Custom")
     .Input("custom_input: int32")
     .Output("custom_output: int32")
     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
       c->set_output(0, c->input(0));
       return ::tensorflow::Status::OK();
     });

2. 为Op实现kernel:继承OpKernel并重写Compute方法(CPU与GPU各自需要一个kernel实现,GPU版本通常在Compute中启动CUDA kernel)

 #include "tensorflow/core/framework/op_kernel.h"

 using namespace tensorflow;

 // Example CPU kernel for the "Custom" op. The op is declared with an int32
 // input and an int32 output, so both tensors are accessed as flat int32 views.
 class CustomOp : public OpKernel {
  public:
   explicit CustomOp(OpKernelConstruction* context) : OpKernel(context) {}

   void Compute(OpKernelContext* context) override {
     // Fetch the input tensor.
     const Tensor& input_tensor = context->input(0);
     auto input = input_tensor.flat<int32>();

     // Allocate the output tensor with the same shape as the input.
     Tensor* output_tensor = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                      &output_tensor));
     auto output = output_tensor->flat<int32>();

     // Perform the actual computation here: read from `input`, write every
     // element of `output` (allocate_output does not initialize it).
     // ...
   }
 };

3. 将实现的kernel注册到TensorFlow系统中

// Bind the "Custom" op, when placed on a CPU device, to the CustomOp kernel.
REGISTER_KERNEL_BUILDER(
    Name("Custom").Device(DEVICE_CPU),
    CustomOp);

CTCBeamSearchDecoder自定义

该Op在TensorFlow源码中对应的文件如下:

Op接口的定义:

tensorflow-master/tensorflow/core/ops/ctc_ops.cc

CTCBeamSearchDecoder本身的定义:

tensorflow-master/tensorflow/core/util/ctc/ctc_beam_search.cc

Op-Class的封装与Op注册:

tensorflow-master/tensorflow/core/kernels/ctc_decoder_ops.cc

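基于上述源码修改后的Op(在原有CTCBeamSearchDecoder的基础上新增label_selection_size和label_selection_margin两个属性,并在解码前通过SetLabelSelectionParameters设置给beam search解码器):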

 #include  #include  #include  #include "tensorflow/core/util/ctc/ctc_beam_search.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/kernels/bounds_check.h" namespace tf = tensorflow; using tf::shape_inference::DimensionHandle; using tf::shape_inference::InferenceContext; using tf::shape_inference::ShapeHandle; using namespace tensorflow; REGISTER_OP("CTCBeamSearchDecoderWithParam") .Input("inputs: float") .Input("sequence_length: int32") .Attr("beam_width: int >= 1") .Attr("top_paths: int >= 1") .Attr("merge_repeated: bool = true") //新添加了两个参数 .Attr("label_selection_size: int >= 0 = 0") .Attr("label_selection_margin: float") .Output("decoded_indices: top_paths * int64") .Output("decoded_values: top_paths * int64") .Output("decoded_shape: top_paths * int64") .Output("log_probability: float") .SetShapeFn([](InferenceContext* c) { ShapeHandle inputs; ShapeHandle sequence_length; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs)); TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length)); // Get batch size from inputs and sequence_length. DimensionHandle batch_size; TF_RETURN_IF_ERROR( c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size)); int32 top_paths; TF_RETURN_IF_ERROR(c->GetAttr("top_paths", &top_paths)); // Outputs. 
int out_idx = 0; for (int i = 0; i set_output(out_idx++, c->Matrix(InferenceContext::kUnknownDim, 2)); } for (int i = 0; i set_output(out_idx++, c->Vector(InferenceContext::kUnknownDim)); } ShapeHandle shape_v = c->Vector(2); for (int i = 0; i set_output(out_idx++, shape_v); } c->set_output(out_idx++, c->Matrix(batch_size, top_paths)); return Status::OK(); }); typedef Eigen::ThreadPoolDevice CPUDevice; inline float RowMax(const TTypes::UnalignedConstMatrix& m, int r, int* c) { *c = 0; CHECK_LT(0, m.dimension(1)); float p = m(r, 0); for (int i = 1; i  p) { p = m(r, i); *c = i; } } return p; } class CTCDecodeHelper { public: CTCDecodeHelper() : top_paths_(1) {} inline int GetTopPaths() const { return top_paths_; } void SetTopPaths(int tp) { top_paths_ = tp; } Status ValidateInputsGenerateOutputs( OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len, Tensor** log_prob, OpOutputList* decoded_indices, OpOutputList* decoded_values, OpOutputList* decoded_shape) const { Status status = ctx->input("inputs", inputs); if (!status.ok()) return status; status = ctx->input("sequence_length", seq_len); if (!status.ok()) return status; const TensorShape& inputs_shape = (*inputs)->shape(); if (inputs_shape.dims() != 3) { return errors::InvalidArgument("inputs is not a 3-Tensor"); } const int64 max_time = inputs_shape.dim_size(0); const int64 batch_size = inputs_shape.dim_size(1); if (max_time == 0) { return errors::InvalidArgument("max_time is 0"); } if (!TensorShapeUtils::IsVector((*seq_len)->shape())) { return errors::InvalidArgument("sequence_length is not a vector"); } if (!(batch_size == (*seq_len)->dim_size(0))) { return errors::FailedPrecondition( "len(sequence_length) != batch_size. 
", "len(sequence_length): ", (*seq_len)->dim_size(0), " batch_size: ", batch_size); } auto seq_len_t = (*seq_len)->vec(); for (int b = 0; b output_list("decoded_indices", decoded_indices); if (!s.ok()) return s; s = ctx->output_list("decoded_values", decoded_values); if (!s.ok()) return s; s = ctx->output_list("decoded_shape", decoded_shape); if (!s.ok()) return s; return Status::OK(); } // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b". Status StoreAllDecodedSequences( const std::vector > >& sequences, OpOutputList* decoded_indices, OpOutputList* decoded_values, OpOutputList* decoded_shape) const { // Calculate the total number of entries for each path const int64 batch_size = sequences.size(); std::vector num_entries(top_paths_, 0); // Calculate num_entries per path for (const auto& batch_s : sequences) { CHECK_EQ(batch_s.size(), top_paths_); for (int p = 0; p allocate(p, TensorShape({p_num, 2}), &p_indices); if (!s.ok()) return s; s = decoded_values->allocate(p, TensorShape({p_num}), &p_values); if (!s.ok()) return s; s = decoded_shape->allocate(p, TensorShape({2}), &p_shape); if (!s.ok()) return s; auto indices_t = p_indices->matrix(); auto values_t = p_values->vec(); auto shape_t = p_shape->vec(); int64 max_decoded = 0; int64 offset = 0; for (int64 b = 0; b GetAttr("merge_repeated", &merge_repeated_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_)); //从参数列表中读取新添的两个参数 OP_REQUIRES_OK(ctx, ctx->GetAttr("label_selection_size", &label_selection_size)); OP_REQUIRES_OK(ctx, ctx->GetAttr("label_selection_margin", &label_selection_margin)); int top_paths; OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths)); decode_helper_.SetTopPaths(top_paths); } void Compute(OpKernelContext* ctx) override { const Tensor* inputs; const Tensor* seq_len; Tensor* log_prob = nullptr; OpOutputList decoded_indices; OpOutputList decoded_values; OpOutputList decoded_shape; OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs( ctx, 
&inputs, &seq_len, &log_prob, &decoded_indices, &decoded_values, &decoded_shape)); auto inputs_t = inputs->tensor(); auto seq_len_t = seq_len->vec(); auto log_prob_t = log_prob->matrix(); const TensorShape& inputs_shape = inputs->shape(); const int64 max_time = inputs_shape.dim_size(0); const int64 batch_size = inputs_shape.dim_size(1); const int64 num_classes_raw = inputs_shape.dim_size(2); OP_REQUIRES( ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits::max()), errors::InvalidArgument("num_classes cannot exceed max int")); const int num_classes = static_cast(num_classes_raw); log_prob_t.setZero(); std::vector::UnalignedConstMatrix> input_list_t; for (std::size_t t = 0; t  beam_search(num_classes, beam_width_, &beam_scorer_, 1 /* batch_size */, merge_repeated_); //使用传入的两个参数进行Set beam_search.SetLabelSelectionParameters(label_selection_size, label_selection_margin); Tensor input_chip(DT_FLOAT, TensorShape({num_classes})); auto input_chip_t = input_chip.flat(); std::vector > > best_paths(batch_size); std::vector log_probs; // Assumption: the blank index is num_classes - 1 for (int b = 0; b (input_chip_t.data(), num_classes); beam_search.Step(input_bi); } OP_REQUIRES_OK( ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b, &log_probs, merge_repeated_)); beam_search.Reset(); for (int bp = 0; bp ::DefaultBeamScorer beam_scorer_; bool merge_repeated_; int beam_width_; //新添两个数据成员,用于存储新加的参数 int label_selection_size; float label_selection_margin; TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderWithParamOp); }; REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoderWithParam").Device(DEVICE_CPU), CTCBeamSearchDecoderWithParamOp); 

将自定义的Op编译成.so文件

在tensorflow-master目录下新建一个文件夹custom_op,并将上面的源码保存为custom_op/new_beamsearch.cc(下面BUILD文件中srcs引用的就是这个文件名)

cd custom_op

新建一个BUILD文件,并在其中添加如下代码:

 # Bazel target that compiles new_beamsearch.cc (the CTCBeamSearchDecoderWithParam
 # op and its CPU kernel) into a C++ library named ctc_decoder_with_param.
 cc_library(
     name = "ctc_decoder_with_param",
     srcs = [
         "new_beamsearch.cc",
     ] + glob(["boost_locale/**/*.hpp"]),
     # NOTE(review): the boost_locale headers and include path are not referenced
     # by the op source shown in this article; they look like leftovers from
     # another project -- confirm before relying on them.
     includes = ["boost_locale"],
     copts = ["-std=c++11"],
     deps = [
         "//tensorflow/core:core",
         "//tensorflow/core/util/ctc",
         "//third_party/eigen3",
     ],
 )

编译过程:

1. cd 到 tensorflow-master 目录下

2. bazel build -c opt --copt=-O3 //tensorflow:libtensorflow_cc.so //custom_op:ctc_decoder_with_param

3. 编译成功后,会在 bazel-bin/custom_op 目录下生成 libctc_decoder_with_param.so

在训练(预测)程序中使用自定义的Op

在程序中定义如下的方法:

 # Load the compiled custom-op library once, at import time.
 # NOTE(review): tf.load_op_library passes this name to the dynamic loader; a bare
 # file name is looked up on the loader's search path (e.g. LD_LIBRARY_PATH), not
 # necessarily the current directory. Use a full path such as
 # bazel-bin/custom_op/libctc_decoder_with_param.so if loading fails.
 decode_param_op_module = tf.load_op_library('libctc_decoder_with_param.so')


 def decode_with_param(inputs, sequence_length, beam_width=100, top_paths=1,
                       merge_repeated=True, label_selection_size=40,
                       label_selection_margin=0.99):
     """CTC beam-search decoding through the CTCBeamSearchDecoderWithParam op.

     Returns results in the same structure as tf.nn.ctc_beam_search_decoder.
     The two label selection settings are keyword arguments; their defaults
     keep the values (40 and 0.99) that were previously hard-coded here.

     Args:
       inputs: float tensor; the op requires rank 3, with max_time on axis 0
         and batch_size on axis 1.
       sequence_length: int32 vector of length batch_size.
       beam_width: forwarded to the op's beam_width attr (must be >= 1).
       top_paths: number of decoded paths to return (must be >= 1).
       merge_repeated: forwarded to the op's merge_repeated attr.
       label_selection_size: forwarded to the op's label_selection_size attr.
       label_selection_margin: forwarded to the op's label_selection_margin attr.

     Returns:
       A pair (decoded, log_probabilities). `decoded` is a list of top_paths
       tf.SparseTensor objects, one per path. `log_probabilities` is the op's
       log_probability output, of shape [batch_size, top_paths].
     """
     decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
         decode_param_op_module.ctc_beam_search_decoder_with_param(
             inputs, sequence_length,
             beam_width=beam_width,
             top_paths=top_paths,
             merge_repeated=merge_repeated,
             label_selection_size=label_selection_size,
             label_selection_margin=label_selection_margin))
     decoded = [
         tf.SparseTensor(ix, val, shape)
         for (ix, val, shape) in zip(decoded_ixs, decoded_vals, decoded_shapes)
     ]
     return (decoded, log_probabilities)

然后就可以像使用tf.nn.ctc_beam_search_decoder一样使用该Op了。注意:编译该.so所用的TensorFlow源码版本应与运行时安装的TensorFlow版本保持一致,否则tf.load_op_library可能因ABI不兼容而加载失败。

以上就是TensorFlow实现自定义Op方式的详细内容,更多请关注0133技术站其它相关文章!

赞(0) 打赏
未经允许不得转载:0133技术站首页 » python