这样的问题的一般方法是使用tf.map_fn()
通过对每一行应用函数来创建具有适当值的新张量。让我们如何表达对单行的条件开始:
row = tf.placeholder(tf.int32, shape=[10])
condition = tf.logical_and(
tf.equal(row[0], tf.reduce_max(row)),
tf.less(row[0], 4))
sess = tf.Session()
print sess.run(condition, feed_dict={row: [6, -2, -2, -2, -1, -2, -3, -3, -6, -6]})
print sess.run(condition, feed_dict={row: [1, -6, -7, -7, -7, -7, -7, -6, -6, -6]})
print sess.run(condition, feed_dict={row: [5, -3, -3, -4, -4, -4, -4, -3, -3, -3]})
# Prints the following:
# False
# True
# False
现在我们有一个条件,我们可以使用tf.cond()
建立一个条件表达式交换前两个元素,如果条件为真:
def swap_first_two(x):
swapped_first_two = tf.stack([x[1], x[0]])
rest = x[2:]
return tf.concat([swapped_first_two, rest], 0)
maybe_swapped = tf.cond(condition, lambda: swap_first_two(row), lambda: row)
print sess.run(maybe_swapped, feed_dict={row: [6, -2, -2, -2, -1, -2, -3, -3, -6, -6]})
print sess.run(maybe_swapped, feed_dict={row: [1, -6, -7, -7, -7, -7, -7, -6, -6, -6]})
print sess.run(maybe_swapped, feed_dict={row: [5, -3, -3, -4, -4, -4, -4, -3, -3, -3]})
# Prints the following:
# [ 6 -2 -2 -2 -1 -2 -3 -3 -6 -6]
# [-6 1 -7 -7 -7 -7 -7 -6 -6 -6]
# [ 5 -3 -3 -4 -4 -4 -4 -3 -3 -3]
最后,我们把它放在一起通过包装的maybe_swapped
计算的功能,并将它传递给tf.map_fn()
:
matrix = tf.constant([
[6, -2, -2, -2, -1, -2, -3, -3, -6, -6],
[1, -6, -7, -7, -7, -7, -7, -6, -6, -6],
[5, -3, -3, -4, -4, -4, -4, -3, -3, -3],
])
def row_function(row):
condition = tf.logical_and(
tf.equal(row[0], tf.reduce_max(row)),
tf.less(row[0], 4))
def swap_first_two(x):
swapped_first_two = tf.stack([x[1], x[0]])
rest = x[2:]
return tf.concat([swapped_first_two, rest], 0)
maybe_swapped = tf.cond(condition, lambda: swap_first_two(row), lambda: row)
return maybe_swapped
result = tf.map_fn(row_function, matrix)
print sess.run(result)
# Prints the following:
# [[ 6 -2 -2 -2 -1 -2 -3 -3 -6 -6]
# [-6 1 -7 -7 -7 -7 -7 -6 -6 -6]
# [ 5 -3 -3 -4 -4 -4 -4 -3 -3 -3]]
你有没有尝试过任何东西?你应该在你的问题中包含它。 – gobrewers14