2017-08-24 59 views
1

(编辑原始问题) 在Tensorflow中,我们经常需要定义包含变量的函数,以在中间的网络层上实现。有没有评估此例如输出的方式:Tensorflow中的函数

import tensorflow as tf 
def Mult(mult): 
    A = tf.get_variable([2,2], initializer = tf.zeros_initializer) 
    B = tf.get_variable([2,2], initializer = tf.zeros_initializer) 
    return mult*tf.matmul(A,B) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    # Here we alter the variables A,B in some fashion e.g. in an optimisation algorithm 
    print(Mult(A,B)) 

产生一个错误

回答

0

有三个错误在代码:

  1. 如果你创建一个函数,接受

    , B作为参数,在函数中创建具有相同名称的变量没有多大意义。因此,要么删除AB参数,要么不要在函数内部创建变量。

  2. Mult(A,B)返回张量。要检索张量的值,您需要在会话中对其进行评估。会话会跟踪参数AB的值,从中可以计算出Mult(A,B)的值。

  3. tf.get_variable函数需要一个name参数。下面

代码修复你的错误:

import tensorflow as tf 

def Mult(): 
    A = tf.get_variable('A', shape=[2,2], initializer = tf.zeros_initializer) 
    B = tf.get_variable('B', shape=[2,2], initializer = tf.zeros_initializer) 
    return tf.matmul(A,B) 

with tf.Session() as sess: 
    result = Mult() 
    sess.run(tf.global_variables_initializer()) 
    print(sess.run(result)) 
+0

为什么它产生一个错误,如果行:结果= ......,和sess.run(TF .....切换? – sbb

+0

'sess.run(tf.global_variables_initializer())'初始化图形中的所有全局变量。如果在调用'Mult()'之前调用它,图形中没有变量,而您最终得到两个单位变量('A'和'B')。 – GeertH

0

没有与代码的几个问题。首先,由于您将变量AB传递给该函数,因此您无需在函数内对它们进行初始化。因此,功能应该是这样的:

def Mult(A,B): 
    return tf.matmul(A,B) 

接下来,你需要初始化之前定义的变量,因此该行所定义的变量后

sess.run(tf.global_variables_initializer()) 

应该会出现。

另外,如果你想获得的Mult(A,B)数值,应打印

sess.run(Mult(A,B)) 

并不仅仅是Mult(A,B)(因为后者只会给你一个张量对象)。

最后,您需要为您定义的变量提供一个名称。 [2,2]是形状(函数的第二个参数)。第一个参数应该是名称。

这里是校正后的代码:

import tensorflow as tf 
def Mult(A,B): 
    return tf.matmul(A,B) 

A = tf.get_variable('A',shape=[2,2], initializer = tf.zeros_initializer) 
B = tf.get_variable('B',shape=[2,2], initializer = tf.zeros_initializer) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    print(sess.run(Mult(A,B))) 

这将打印

[[ 0. 0.] 
[ 0. 0.]] 
相关问题