我目前正在用Python构建模型,并从另一个Java客户端获取结果。

我需要知道如何从具有1个以上维度的TensorProto中获取float[][]List<List<Float>>(类似)。

在Python中,完成这项工作可能非常容易:

from tensorflow.python.framework import tensor_util
.
.
.
print tensor_util.MakeNdarray(tensorProto)


=====更新=======:

如果Java的tensorProto.getFloatValList()由Python的tensor_util.make_tensor_proto(vector)创建,则也不起作用。

以上所有情况都可以通过@Ash的答案解决

最佳答案

正如Allen在评论中提到的那样,这可能是一个很好的功能要求。

但是在此期间,一种变通方法是构造并运行一个图,该图解析编码的protobuf并返回Tensor。它不会特别有效,但是您可以执行以下操作:

import org.tensorflow.*;
import java.util.Arrays;

public final class ProtoToTensor {

  public static Tensor<Float> tensorFromSerializedProto(byte[] serialized) {
    // One may way to cache the Graph and Session as member variables to avoid paying the cost of
    // graph and session construction on each call.
    try (Graph g = buildGraphToParseProto();
        Session sess = new Session(g);
        Tensor<String> input = Tensors.create(serialized)) {
      return sess.runner()
          .feed("input", input)
          .fetch("output")
          .run()
          .get(0)
          .expect(Float.class);
    }
  }

  private static Graph buildGraphToParseProto() {
    Graph g = new Graph();
    // The graph construction process in Java is currently (as of TensorFlow 1.4) very verbose.
    // Once https://github.com/tensorflow/tensorflow/issues/7149 is resolved, this should become
    // *much* more convenient and succint.
    Output<String> in =
        g.opBuilder("Placeholder", "input")
            .setAttr("dtype", DataType.STRING)
            .setAttr("shape", Shape.scalar())
            .build()
            .output(0);
    g.opBuilder("ParseTensor", "output").setAttr("out_type", DataType.FLOAT).addInput(in).build();
    return g;
  }

  public static void main(String[] args) {
    // Let's say you got a byte[] representation of the proto somehow.
    // In this case, I got it from Python from the following program
    // that serializes the 1x1 matrix:
    /*
    import tensorflow as tf
    list(bytearray(tf.make_tensor_proto([[1.]]).SerializeToString()))
    */
    byte[] bytes = {8, 1, 18, 8, 18, 2, 8, 1, 18, 2, 8, 1, 42, 4, 0, 0, (byte)128, 63};
    try (Tensor<Float> t = tensorFromSerializedProto(bytes)) {
      // You can now get an float[][] array using t.copyTo().
      // t.shape() gives shape information.
      System.out.println("Tensor: " + t);
      float[][] f = t.copyTo(new float[1][1]);
      System.out.println("float[][]: " + Arrays.deepToString(f));
    }
  }
}


如您所见,这使用了一些相当低级的API来构造图形和会话。提出一个功能请求,用一行代替所有这些请求是合理的:

Tensor<Float> t = Tensor.createFromProto(serialized);

10-05 18:35