2017-02-10

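How to use tf.cond for batch processing

I want to use tf.cond(pred, fn1, fn2, name=None) for conditional branching. Suppose I have two tensors, x and y. Each is a batch of 0/1 values, and I want to use the comparison x < y as the source for the pred argument of tf.cond: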

pred: A scalar determining whether to return the result of fn1 or fn2.

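But if I work with batches, it looks like I need to iterate over the source tensor inside the graph, slice out each batch item, and apply tf.cond to each item separately. That looks suspicious to me. Why does tf.cond accept only a scalar and not a batch? Can you suggest the right way to use it with batches?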

Answer

tf.where sounds like what you want: vectorized, element-wise selection between tensors.
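To make the difference concrete, here is a minimal NumPy sketch. It uses NumPy's `np.where` as a stand-in for `tf.where` (both select element-wise from two arrays), and the values of `x`, `y`, `a` and `b` are invented for illustration. It compares the per-item loop described in the question with a single vectorized selection:

```python
import numpy as np

# Hypothetical batch of 0/1 values, as described in the question.
x = np.array([0, 1, 0, 1, 1])
y = np.array([1, 1, 1, 0, 0])

# Per-item results of two "branches" (arbitrary placeholder computations).
a = x + 10.0  # used where x < y
b = y - 10.0  # used elsewhere

pred = x < y  # boolean vector: one predicate per batch item

# Per-item loop: roughly what applying a scalar cond to each slice amounts to.
looped = np.array([a[i] if pred[i] else b[i] for i in range(len(pred))])

# Vectorized selection: a single call picks from a or b element-wise.
selected = np.where(pred, a, b)

print(selected)
```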

tf.cond is a control-flow construct: it determines which ops are executed at all, so it is hard to think of useful batch semantics for it.
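The difference in execution can be sketched in plain Python with NumPy. `cond_like` is a hypothetical helper that only mimics the contract of `tf.cond` (it is not the TensorFlow implementation), and the call counters record which branch functions actually ran: a scalar predicate runs exactly one branch, whereas element-wise selection needs the results of both.

```python
import numpy as np

calls = {"fn1": 0, "fn2": 0}

def fn1():
    calls["fn1"] += 1
    return np.array([1.0, 2.0, 3.0])

def fn2():
    calls["fn2"] += 1
    return np.array([-1.0, -2.0, -3.0])

def cond_like(pred, true_fn, false_fn):
    """Hypothetical helper mimicking tf.cond's contract: one scalar
    predicate decides which of the two functions runs at all."""
    return true_fn() if pred else false_fn()

out = cond_like(True, fn1, fn2)
cond_calls = dict(calls)  # only fn1 has run

# Element-wise selection needs both results up front, so both functions run.
calls.update(fn1=0, fn2=0)
mask = np.array([True, False, True])
mixed = np.where(mask, fn1(), fn2())
where_calls = dict(calls)

print(out, cond_calls)
print(mixed, where_calls)
```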

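We can also combine the two approaches: slice the inputs based on the condition and pass each slice to the ops of the corresponding branch, so that every batch item is processed by only one branch.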

import tensorflow as tf
from tensorflow.python.util import nest

def slicing_where(condition, full_input, true_branch, false_branch):
    """Split `full_input` between `true_branch` and `false_branch` on `condition`.

    Args:
      condition: A boolean Tensor with shape [B_1, ..., B_N].
      full_input: A Tensor or nested tuple of Tensors of any dtype, each with
        shape [B_1, ..., B_N, ...], to be split between `true_branch` and
        `false_branch` based on `condition`.
      true_branch: A function taking a single argument with the same structure
        as `full_input`, but whose Tensors have the N batch dimensions flattened
        into one leading dimension. Receives the slices of `full_input`
        corresponding to the True entries of `condition`. Returns a Tensor or
        nested tuple of Tensors, each with a leading dimension matching its inputs.
      false_branch: Like `true_branch`, but receives inputs corresponding to the
        False entries of `condition`. Returns a Tensor or nested tuple of Tensors
        with the same structure as the return value of `true_branch`, each with
        a leading dimension matching its inputs.
    Returns:
      Interleaved outputs from `true_branch` and `false_branch`, each Tensor
      having shape [B_1, ..., B_N, ...].
    """
    full_input_flat = nest.flatten(full_input)
    # One-argument tf.where returns the [K, N] coordinates of the True entries.
    true_indices = tf.where(condition)
    false_indices = tf.where(tf.logical_not(condition))
    # Gather each branch's items; the N batch dims collapse into one leading dim.
    true_branch_inputs = nest.pack_sequence_as(
        structure=full_input,
        flat_sequence=[tf.gather_nd(params=input_tensor, indices=true_indices)
                       for input_tensor in full_input_flat])
    false_branch_inputs = nest.pack_sequence_as(
        structure=full_input,
        flat_sequence=[tf.gather_nd(params=input_tensor, indices=false_indices)
                       for input_tensor in full_input_flat])
    # Each branch runs only on its own slice of the batch.
    true_outputs = true_branch(true_branch_inputs)
    false_outputs = false_branch(false_branch_inputs)
    nest.assert_same_structure(true_outputs, false_outputs)

    def scatter_outputs(true_output, false_output):
        batch_shape = tf.shape(condition)
        # Full shape = batch shape of `condition` + per-item shape of the output.
        # Branch outputs have a single leading batch dim, so the per-item shape
        # starts at index tf.rank(batch_shape), which is always 1.
        scattered_shape = tf.concat(
            [batch_shape, tf.shape(true_output)[tf.rank(batch_shape):]],
            0)
        true_scatter = tf.scatter_nd(
            indices=tf.cast(true_indices, tf.int32),
            updates=true_output,
            shape=scattered_shape)
        false_scatter = tf.scatter_nd(
            indices=tf.cast(false_indices, tf.int32),
            updates=false_output,
            shape=scattered_shape)
        # Each position is filled by exactly one branch; the other scatter is 0 there.
        return true_scatter + false_scatter

    result = nest.pack_sequence_as(
        structure=true_outputs,
        flat_sequence=[
            scatter_outputs(true_single_output, false_single_output)
            for true_single_output, false_single_output
            in zip(nest.flatten(true_outputs), nest.flatten(false_outputs))])
    return result
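To check the index bookkeeping of `slicing_where` without building a graph, here is a NumPy analog. It is an illustration, not part of the original answer: `slicing_where_np` is a hypothetical name, and it handles a single array rather than nested tuples. `np.argwhere` plays the role of `tf.where(condition)`, advanced indexing stands in for `tf.gather_nd`, and index assignment stands in for `tf.scatter_nd`. The second call uses a two-dimensional batch to show that the branches see the batch dimensions flattened into one:

```python
import numpy as np

def slicing_where_np(condition, full_input, true_branch, false_branch):
    """NumPy analog of slicing_where for a single input array.

    condition:  bool array of shape [B_1, ..., B_N].
    full_input: array of shape [B_1, ..., B_N, ...].
    Each branch receives an array of shape [K, ...] (the N batch dims
    flattened into one) and must return an array with the same leading K.
    """
    # Like tf.where(condition): [K, N] coordinates of the True / False entries.
    true_idx = np.argwhere(condition)
    false_idx = np.argwhere(~condition)
    # Like tf.gather_nd: pick out each branch's items.
    true_out = true_branch(full_input[tuple(true_idx.T)])
    false_out = false_branch(full_input[tuple(false_idx.T)])
    # Like tf.scatter_nd: write each branch's results back to their positions.
    result = np.zeros(condition.shape + true_out.shape[1:], dtype=true_out.dtype)
    result[tuple(true_idx.T)] = true_out
    result[tuple(false_idx.T)] = false_out
    return result

# 1-D batch, mirroring vector_test.
vec = slicing_where_np(
    condition=(np.arange(10) % 2 == 0),
    full_input=np.arange(10, dtype=np.float32),
    true_branch=lambda t: 0.2 + t,
    false_branch=lambda t: 0.1 + t)

# 2-D batch of shape [2, 3] with a trailing feature dimension of size 2.
cond2 = np.array([[True, False, True], [False, False, True]])
inp2 = np.arange(12, dtype=np.float64).reshape(2, 3, 2)
mat = slicing_where_np(cond2, inp2, lambda t: -t, lambda t: t + 100.0)

print(vec)
print(mat)
```

Instead of summing two zero-filled scatters as the TensorFlow version does, this sketch writes both results into one preallocated array; the outcome is the same because each batch position belongs to exactly one branch.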

Some examples:

vector_test = slicing_where(
    condition=tf.equal(tf.range(10) % 2, 0), 
    full_input=tf.range(10, dtype=tf.float32), 
    true_branch=lambda x: 0.2 + x, 
    false_branch=lambda x: 0.1 + x) 

cross_range = (tf.range(10, dtype=tf.float32)[:, None] 
       * tf.range(10, dtype=tf.float32)[None, :]) 
matrix_test = slicing_where(
    condition=tf.equal(tf.range(10) % 3, 0), 
    full_input=cross_range, 
    true_branch=lambda x: -x, 
    false_branch=lambda x: x + 0.1) 

with tf.Session(): 
    print(vector_test.eval()) 
    print(matrix_test.eval()) 

Prints:

[ 0.2   1.10000002 2.20000005 3.0999999 4.19999981 5.0999999 
    6.19999981 7.0999999 8.19999981 9.10000038] 
[[ 0.   0.   0.   0.   0.   0. 
    0.   0.   0.   0.  ] 
[ 0.1   1.10000002 2.0999999 3.0999999 4.0999999 
    5.0999999 6.0999999 7.0999999 8.10000038 9.10000038] 
[ 0.1   2.0999999 4.0999999 6.0999999 8.10000038 
    10.10000038 12.10000038 14.10000038 16.10000038 18.10000038] 
[ 0.   -3.   -6.   -9.   -12.   -15. 
    -18.   -21.   -24.   -27.  ] 
[ 0.1   4.0999999 8.10000038 12.10000038 16.10000038 
    20.10000038 24.10000038 28.10000038 32.09999847 36.09999847] 
[ 0.1   5.0999999 10.10000038 15.10000038 20.10000038 
    25.10000038 30.10000038 35.09999847 40.09999847 45.09999847] 
[ 0.   -6.   -12.   -18.   -24.   -30. 
    -36.   -42.   -48.   -54.  ] 
[ 0.1   7.0999999 14.10000038 21.10000038 28.10000038 
    35.09999847 42.09999847 49.09999847 56.09999847 63.09999847] 
[ 0.1   8.10000038 16.10000038 24.10000038 32.09999847 
    40.09999847 48.09999847 56.09999847 64.09999847 72.09999847] 
[ 0.   -9.   -18.   -27.   -36.   -45. 
    -54.   -63.   -72.   -81.  ]] 

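My goal is "control flow". That's why I need tf.cond. But you are absolutely right that, in the current architecture, "it is hard to think of useful batch semantics". I can only use SGD. Thanks! Now I've realized this.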


Can you elaborate on the problem you're trying to solve? Happy to help brainstorm solutions.


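Let's imagine three control flows before tf.cond: two branches are detectors (d1, d2) and one is a data source (ds). After tf.cond(p1, p2) there are two more branches. If the output of the first detector is greater than or equal to the output of the second detector, the data from the data source (ds) branch should be processed by the p1 branch; otherwise it should be processed by the p2 branch. We should not process both branches at the same time.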
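If the decision is made separately for each example in the batch, this is the gather/scatter pattern from the answer with the condition d1 >= d2. A minimal NumPy sketch under that assumption (the scores, the rows of `ds`, and the bodies of `p1` and `p2` are invented placeholders) shows that each example passes through exactly one branch:

```python
import numpy as np

# Hypothetical per-example detector scores and data-source rows.
d1 = np.array([0.9, 0.2, 0.5, 0.7])
d2 = np.array([0.1, 0.8, 0.5, 0.9])
ds = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]])

seen = {}  # how many rows each branch actually processed

def p1(batch):
    seen["p1"] = len(batch)
    return batch * 10.0  # placeholder computation for branch p1

def p2(batch):
    seen["p2"] = len(batch)
    return batch - 10.0  # placeholder computation for branch p2

use_p1 = d1 >= d2  # per-example routing decision
out = np.empty_like(ds)
out[use_p1] = p1(ds[use_p1])    # p1 only receives the rows routed to it
out[~use_p1] = p2(ds[~use_p1])  # p2 only receives the remaining rows

print(use_p1, seen)
print(out)
```

If instead one decision should cover the whole batch, the comparison would first have to be reduced to a scalar (for example with `tf.reduce_all` or `tf.reduce_any`), and that scalar could then be passed to `tf.cond` directly.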