2016-09-06 886 views

Answers

22

To get the mean and variance, just use tf.nn.moments:

mean, var = tf.nn.moments(x, axes=[1]) 

For more on the tf.nn.moments parameters, see the docs.
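As a rough cross-check of what that call returns, here is a NumPy stand-in (not TensorFlow itself). It assumes tf.nn.moments reduces over the listed axes and returns the population variance (dividing by N, not N-1); the helper name `moments_np` is hypothetical and only for illustration.

```python
import numpy as np

def moments_np(x, axes):
    """Hypothetical NumPy stand-in: mean and population variance over `axes`."""
    axes = tuple(axes)
    mean = x.mean(axis=axes)
    var = x.var(axis=axes)  # ddof=0 by default, i.e. divide by N
    return mean, var

x = np.arange(10, dtype=np.float32).reshape(2, 5)
mean, var = moments_np(x, axes=[1])
print(mean)  # one mean per row: [2. 7.]
print(var)   # one variance per row: [2. 2.]
```

With axes=[1] on a (2, 5) input, both results have shape (2,): the reduced axis is dropped.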

+0

How can I do this in the C++ API? –

+0

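I only see documentation for the mean in the C++ API: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/mean. I think you have to compute the variance yourself, i.e. sum[(x - u)^2] divided by the number of elements. You may be able to read through the Python source to see how it calls the backend and learn how to compute the variance more efficiently. – Steven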
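The formula in that comment can be sketched using nothing but a mean reduction, subtraction, and squaring. This is NumPy rather than the C++ API, and `variance_from_mean` is a hypothetical helper name; the same three steps would have to be expressed with whatever ops the C++ API actually provides.

```python
import numpy as np

def variance_from_mean(x, axis=None):
    """Population variance built only from mean, subtract and square."""
    # Keep the reduced axis so that u broadcasts against x.
    u = np.mean(x, axis=axis, keepdims=True)
    squared_dev = np.square(x - u)
    # Averaging the squared deviations is sum[(x - u)^2] / N.
    return np.mean(squared_dev, axis=axis)

x = np.arange(10, dtype=np.float64).reshape(2, 5)
print(variance_from_mean(x))          # over all elements: 8.25
print(variance_from_mean(x, axis=1))  # per row: [2. 2.]
```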

3

You can also use reduce_std from the code below, which is adapted from Keras:

#coding=utf-8 
import numpy as np 
import tensorflow as tf 

def reduce_var(x, axis=None, keepdims=False):
    """Variance of a tensor, along the specified axis.

    # Arguments
        x: A tensor or variable.
        axis: An integer, the axis to compute the variance.
        keepdims: A boolean, whether to keep the dimensions or not.
            If `keepdims` is `False`, the rank of the tensor is reduced
            by 1. If `keepdims` is `True`,
            the reduced dimension is retained with length 1.

    # Returns
        A tensor with the variance of elements of `x`.
    """
    # `keep_dims` is the older TF 1.x spelling; later releases rename it to `keepdims`.
    # The mean keeps the reduced axis so that it broadcasts against x.
    m = tf.reduce_mean(x, axis=axis, keep_dims=True)
    devs_squared = tf.square(x - m)
    return tf.reduce_mean(devs_squared, axis=axis, keep_dims=keepdims)

def reduce_std(x, axis=None, keepdims=False):
    """Standard deviation of a tensor, along the specified axis.

    # Arguments
        x: A tensor or variable.
        axis: An integer, the axis to compute the standard deviation.
        keepdims: A boolean, whether to keep the dimensions or not.
            If `keepdims` is `False`, the rank of the tensor is reduced
            by 1. If `keepdims` is `True`,
            the reduced dimension is retained with length 1.

    # Returns
        A tensor with the standard deviation of elements of `x`.
    """
    return tf.sqrt(reduce_var(x, axis=axis, keepdims=keepdims))

if __name__ == '__main__':
    x_np = np.arange(10).reshape(2, 5).astype(np.float32)
    x_tf = tf.constant(x_np)
    # TensorFlow results (TF 1.x session API).
    with tf.Session() as sess:
        print(sess.run(reduce_std(x_tf, keepdims=True)))
        print(sess.run(reduce_std(x_tf, axis=0, keepdims=True)))
        print(sess.run(reduce_std(x_tf, axis=1, keepdims=True)))
    # NumPy reference values for comparison.
    print(np.std(x_np, keepdims=True))
    print(np.std(x_np, axis=0, keepdims=True))
    print(np.std(x_np, axis=1, keepdims=True))
+0

I'm using TF 1.4, and for some reason tf.nn.moments doesn't give me a correct result... I tried your version and it worked on the first try :) +1 –
