从TensorFlow文档中,可以执行以下操作以使用固有OP构建图形
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"
int main() {
using namespace tensorflow;
using namespace tensorflow::ops;
Scope root = Scope::NewRootScope();
// Matrix A = [3 2; -1 0]
auto A = Const(root, { {3.f, 2.f}, {-1.f, 0.f} });
// Vector b = [3 5]
auto b = Const(root, { {3.f, 5.f} });
// v = Ab^T
auto v = MatMul(root.WithOpName("v"), A, b, MatMul::TransposeB(true));
std::vector<Tensor> outputs;
ClientSession session(root);
// Run and fetch v
TF_CHECK_OK(session.Run({v}, &outputs));
// Expect outputs[0] == [19; -3]
LOG(INFO) << outputs[0].matrix<float>();
return 0;
}
似乎
MatMul
类是自动生成的,因为github源代码中没有tensorflow/cc/ops/math_ops.h
。如何为自定义操作(例如here的ZeroOut OP)执行相同操作
最佳答案
以here中的ZeroOut
为例,您必须执行以下操作
class ZeroOut {
public:
ZeroOut(const ::tensorflow::Scope& scope, ::tensorflow::Input x);
operator ::tensorflow::Output() const { return y; }
operator ::tensorflow::Input() const { return y; }
::tensorflow::Node* node() const { return y.node(); }
::tensorflow::Output y;
};
ZeroOut::ZeroOut(const ::tensorflow::Scope& scope, ::tensorflow::Input x) {
if (!scope.ok()) return;
auto _x = ::tensorflow::ops::AsNodeOut(scope, x);
if (!scope.ok()) return;
::tensorflow::Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("ZeroOut");
auto builder = ::tensorflow::NodeBuilder(unique_name, "ZeroOut")
.Input(_x)
;
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return;
scope.UpdateStatus(scope.DoShapeInference(ret));
this->y = Output(ret, 0);
}
然后你可以用它来建立图形
Scope root = Scope::NewRootScope();
// Matrix A = [3 2; -1 0]
auto A = Const(root, { {3, 2}, {-1, 0} });
auto v = ZeroOut(root.WithOpName("v"), A);
std::vector<Tensor> outputs;
ClientSession session(root);
// Run and fetch v
TF_CHECK_OK(session.Run({v}, &outputs));
LOG(INFO) << outputs[0].matrix<int>();
注意:对于TensorFlow固有的OP,诸如
ZeroOut class
之类的代码由bazel规则自动生成。如果我们只有几个自定义OP,我们可以模仿那些代码(例如tensorflow/cc/ops/math_ops.h
)来手写我们自己的类。关于c++ - 如何使用自定义OP在C++中构建TensorFlow图?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/53384454/