2017-08-25 53 views
1

我需要创建一个变量epsilon_n,该变量基于当前的step更改定义(和值)。由于我有两个以上的情况,似乎我不能使用tf.cond。我试图用tf.case如下:Tesnorflow:无法使用带输入参数的tf.case

import tensorflow as tf 

#### 
EPSILON_DELTA_PHASE1 = 33e-4 
EPSILON_DELTA_PHASE2 = 2.5 
#### 
step = tf.placeholder(dtype=tf.float32, shape=None) 


def fn1(step): 
    return tf.constant([1.]) 

def fn2(step): 
    return tf.constant([1.+step*EPSILON_DELTA_PHASE1]) 

def fn3(step): 
    return tf.constant([1.+step*EPSILON_DELTA_PHASE2]) 

epsilon_n = tf.case(
     pred_fn_pairs=[ 
      (tf.less(step, 3e4), lambda step: fn1(step)), 
      (tf.less(step, 6e4), lambda step: fn2(step)), 
      (tf.less(step, 1e5), lambda step: fn3(step))], 
      default=lambda: tf.constant([1e5]), 
     exclusive=False) 

不过,我不断收到此错误信息:

TypeError: <lambda>() missing 1 required positional argument: 'step' 

我试过如下:

epsilon_n = tf.case(
     pred_fn_pairs=[ 
      (tf.less(step, 3e4), fn1), 
      (tf.less(step, 6e4), fn2), 
      (tf.less(step, 1e5), fn3)], 
      default=lambda: tf.constant([1e5]), 
     exclusive=False) 

我仍相同的错误。 Tensorflow文档中的示例重点讨论了没有将输入参数传递给可调用函数的情况。我无法在互联网上找到关于tf.case的足够信息!请帮忙吗?

回答

2

这里有几个你需要做的改变。 为了保持一致性,您可以将所有返回值设置为变量。

# Since step is a scalar, scalar shape [() or [], not None] much be provided 
step = tf.placeholder(dtype=tf.float32, shape=()) 


def fn1(step): 
    return tf.constant([1.]) 

# Here you need to use Variable not constant, since you are modifying the value using placeholder 
def fn2(step): 
    return tf.Variable([1.+step*EPSILON_DELTA_PHASE1]) 

def fn3(step): 
    return tf.Variable([1.+step*EPSILON_DELTA_PHASE2]) 

epsilon_n = tf.case(
    pred_fn_pairs=[ 
     (tf.less(step, 3e4), lambda : fn1(step)), 
     (tf.less(step, 6e4), lambda : fn2(step)), 
     (tf.less(step, 1e5), lambda : fn3(step))], 
     default=lambda: tf.constant([1e5]), 
    exclusive=False) 
+0

修正轻微错字 –