2017-07-06 81 views
0

假设我有这样的修改张量

[ 
    [ 6 -2 -2 -2 -1 -2 -3 -3 -6 -6] 
    [ 1 -6 -7 -7 -7 -7 -7 -6 -6 -6] 
    [ 5 -3 -3 -4 -4 -4 -4 -3 -3 -3] 
] 

我必须在每一行上执行以下操作的张量。 如果第一个元素是该行中的最大(值)元素,但其值小于4,则交换该行的第一个和第二个元素。结果张量将是

[ 
    [ 6 -2 -2 -2 -1 -2 -3 -3 -6 -6] 
    [ -6 1 -7 -7 -7 -7 -7 -6 -6 -6] #elements swapped 
    [ 5 -3 -3 -4 -4 -4 -4 -3 -3 -3] 
] 

我正在使用tensorflow模块的蟒蛇工作。请帮忙。

+0

你有没有尝试过任何东西?你应该在你的问题中包含它。 – gobrewers14

回答

6

这样的问题的一般方法是使用tf.map_fn()通过对每一行应用函数来创建具有适当值的新张量。让我们如何表达对单行的条件开始:

row = tf.placeholder(tf.int32, shape=[10]) 

condition = tf.logical_and(
    tf.equal(row[0], tf.reduce_max(row)), 
    tf.less(row[0], 4)) 

sess = tf.Session() 

print sess.run(condition, feed_dict={row: [6, -2, -2, -2, -1, -2, -3, -3, -6, -6]}) 
print sess.run(condition, feed_dict={row: [1, -6, -7, -7, -7, -7, -7, -6, -6, -6]}) 
print sess.run(condition, feed_dict={row: [5, -3, -3, -4, -4, -4, -4, -3, -3, -3]}) 

# Prints the following: 
# False 
# True 
# False 

现在我们有一个条件,我们可以使用tf.cond()建立一个条件表达式交换前两个元素,如果条件为真:

def swap_first_two(x): 
    swapped_first_two = tf.stack([x[1], x[0]]) 
    rest = x[2:] 
    return tf.concat([swapped_first_two, rest], 0) 

maybe_swapped = tf.cond(condition, lambda: swap_first_two(row), lambda: row) 

print sess.run(maybe_swapped, feed_dict={row: [6, -2, -2, -2, -1, -2, -3, -3, -6, -6]}) 
print sess.run(maybe_swapped, feed_dict={row: [1, -6, -7, -7, -7, -7, -7, -6, -6, -6]}) 
print sess.run(maybe_swapped, feed_dict={row: [5, -3, -3, -4, -4, -4, -4, -3, -3, -3]}) 

# Prints the following: 
# [ 6 -2 -2 -2 -1 -2 -3 -3 -6 -6] 
# [-6 1 -7 -7 -7 -7 -7 -6 -6 -6] 
# [ 5 -3 -3 -4 -4 -4 -4 -3 -3 -3] 

最后,我们把它放在一起通过包装的maybe_swapped计算的功能,并将它传递给tf.map_fn()

matrix = tf.constant([ 
    [6, -2, -2, -2, -1, -2, -3, -3, -6, -6], 
    [1, -6, -7, -7, -7, -7, -7, -6, -6, -6], 
    [5, -3, -3, -4, -4, -4, -4, -3, -3, -3], 
]) 

def row_function(row): 
    condition = tf.logical_and(
     tf.equal(row[0], tf.reduce_max(row)), 
     tf.less(row[0], 4)) 

    def swap_first_two(x): 
    swapped_first_two = tf.stack([x[1], x[0]]) 
    rest = x[2:] 
    return tf.concat([swapped_first_two, rest], 0) 

    maybe_swapped = tf.cond(condition, lambda: swap_first_two(row), lambda: row) 

    return maybe_swapped 

result = tf.map_fn(row_function, matrix) 

print sess.run(result) 

# Prints the following: 
# [[ 6 -2 -2 -2 -1 -2 -3 -3 -6 -6] 
# [-6 1 -7 -7 -7 -7 -7 -6 -6 -6] 
# [ 5 -3 -3 -4 -4 -4 -4 -3 -3 -3]]