这是我的代码,它使用DynamicPartition操作将矢量[1,2,3,4,5,6]分解为两个矢量[1,2,3]和[4 ,5,6]使用掩模[1,1,1,0,0,0]:DynamicPartition返回单个输出而不是多个
@Test
public void dynamicPartition2() {
Graph graph = new Graph();
Output a = graph.opBuilder("Const", "a")
.setAttr("dtype", DataType.INT64)
.setAttr("value", Tensor.create(new long[]{6}, LongBuffer.wrap(new long[] {1, 2, 3, 4, 5, 6})))
.build().output(0);
Output partitions = graph.opBuilder("Const", "partitions")
.setAttr("dtype", DataType.INT32)
.setAttr("value", Tensor.create(new long[]{6}, IntBuffer.wrap(new int[] {1, 1, 1, 0, 0, 0})))
.build().output(0);
graph.opBuilder("DynamicPartition", "result")
.addInput(a)
.addInput(partitions)
.setAttr("num_partitions", 2)
.build().output(0);
try (Session s = new Session(graph)) {
List<Tensor> outputs = s.runner().fetch("result").run();
try (Tensor output = outputs.get(0)) {
LongBuffer result = LongBuffer.allocate(3);
output.writeTo(result);
assertArrayEquals("Shape", new long[]{3}, output.shape());
assertArrayEquals("Values", new long[]{4, 5, 6}, result.array());
}
//Test will fail here
try (Tensor output = outputs.get(1)) {
LongBuffer result = LongBuffer.allocate(3);
output.writeTo(result);
assertArrayEquals("Shape", new long[]{3}, output.shape());
assertArrayEquals("Values", new long[]{1, 2, 3}, result.array());
}
}
}
主叫长度为1的s.runner().fetch("result").run()
列表之后返回到与值[4,5,6]。看来我的图形只产生一个输出。
如何获得分裂矢量的其余部分?
你只需要它的Java或将Python的答案是足够的? –
欢迎任何答案 – Aeteros
我的答案解释了什么? –