2017-07-15 53 views
0

这是我的代码,它使用DynamicPartition操作将矢量[1,2,3,4,5,6]分解为两个矢量[1,2,3]和[4 ,5,6]使用掩模[1,1,1,0,0,0]:DynamicPartition返回单个输出而不是多个

@Test 
public void dynamicPartition2() { 
    Graph graph = new Graph(); 

    Output a = graph.opBuilder("Const", "a") 
      .setAttr("dtype", DataType.INT64) 
      .setAttr("value", Tensor.create(new long[]{6}, LongBuffer.wrap(new long[] {1, 2, 3, 4, 5, 6}))) 
      .build().output(0); 

    Output partitions = graph.opBuilder("Const", "partitions") 
      .setAttr("dtype", DataType.INT32) 
      .setAttr("value", Tensor.create(new long[]{6}, IntBuffer.wrap(new int[] {1, 1, 1, 0, 0, 0}))) 
      .build().output(0); 

    graph.opBuilder("DynamicPartition", "result") 
      .addInput(a) 
      .addInput(partitions) 
      .setAttr("num_partitions", 2) 
      .build().output(0); 

    try (Session s = new Session(graph)) { 
     List<Tensor> outputs = s.runner().fetch("result").run(); 

     try (Tensor output = outputs.get(0)) { 
      LongBuffer result = LongBuffer.allocate(3); 
      output.writeTo(result); 

      assertArrayEquals("Shape", new long[]{3}, output.shape()); 
      assertArrayEquals("Values", new long[]{4, 5, 6}, result.array()); 
     } 

     //Test will fail here 
     try (Tensor output = outputs.get(1)) { 
      LongBuffer result = LongBuffer.allocate(3); 
      output.writeTo(result); 

      assertArrayEquals("Shape", new long[]{3}, output.shape()); 
      assertArrayEquals("Values", new long[]{1, 2, 3}, result.array()); 
     } 
    } 
} 

主叫长度为1的s.runner().fetch("result").run()列表之后返回到与值[4,5,6]。看来我的图形只产生一个输出。

如何获得分裂矢量的其余部分?

+0

你只需要它的Java或将Python的答案是足够的? –

+0

欢迎任何答案 – Aeteros

+0

我的答案解释了什么? –

回答

1

DynamicPartition操作返回多个输出(每个分区一个输出),但Session.Runner.fetch调用仅请求第0个输出。

Java API缺少一堆Python API的便利糖,但您可以通过明确请求所有输出来做你想做的事。换句话说,改变从:

List<Tensor> outputs = s.runner().fetch("result").run(); 

List<Tensor> outputs = s.runner().fetch("result", 0).fetch("result", 1).run(); 

希望有所帮助。

+0

提交一个错误谢谢,这是解决方案 – Aeteros

0

不确定Java(我不知道它,没有环境调查),但在Python中一切正常。例如,这

import tensorflow as tf 
a = tf.constant([1, 2, 3, 4, 5, 6]) 
b = tf.constant([1, 1, 1, 0, 0, 0]) 
c = tf.dynamic_partition(a, b, 2) 
with tf.Session() as sess: 
    v1, v2 = sess.run(c) 
    print v1 
    print v2 

返回正确的分区。

相关问题