2017-07-20 292 views
0

我正在创建一个带有稀疏列的DNNclassifier。训练数据看起来是这样的,什么TensorFlow hash_bucket_size很重要

samples  col1       col2   price label 
    eg1 [[0,1,0,0,0,2,0,1,0,3,...] [[0,0,4,5,0,...] 5.2 0 
    eg2 [0,0,...]      [0,0,...]   0  1 
    eg3 [0,0,...]]     [0,0,...]   0  1 

下面的代码片段可以成功运行,

import tensorflow as tf 

sparse_feature_a = tf.contrib.layers.sparse_column_with_hash_bucket('col1', 3, dtype=tf.int32) 
sparse_feature_b = tf.contrib.layers.sparse_column_with_hash_bucket('col2', 1000, dtype=tf.int32) 

sparse_feature_a_emb = tf.contrib.layers.embedding_column(sparse_id_column=sparse_feature_a, dimension=2) 
sparse_feature_b_emb = tf.contrib.layers.embedding_column(sparse_id_column=sparse_feature_b, dimension=2) 
feature_c = tf.contrib.layers.real_valued_column('price') 

estimator = tf.contrib.learn.DNNClassifier(
    feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb, feature_c], 
    hidden_units=[5, 3], 
    n_classes=2, 
    model_dir='./tfTmp/tfTmp0') 

# Input builders 
def input_fn_train(): # returns x, y (where y represents label's class index). 
    features = {'col1': tf.SparseTensor(indices=[[0, 1], [0, 5], [0, 7], [0, 9]], 
            values=[1, 2, 1, 3], 
            dense_shape=[3, int(250e6)]), 
       'col2': tf.SparseTensor(indices=[[0, 2], [0, 3]], 
            values=[4, 5], 
            dense_shape=[3, int(100e6)]), 
         'price': tf.constant([5.2, 0, 0])} 
    labels = tf.constant([0, 1, 1]) 
    return features, labels 

estimator.fit(input_fn=input_fn_train, steps=100) 

不过,我有这句话的一个问题,

sparse_feature_a = tf.contrib.layers.sparse_column_with_hash_bucket('col1', 3, dtype=tf.int32) 

其中3指hash_bucket_size = 3,但是这个稀疏张量包括4个非零值,

'col1': tf.SparseTensor(indices=[[0, 1], [0, 5], [0, 7], [0, 9]], 
           values=[1, 2, 1, 3], 
           dense_shape=[3, int(250e6)]) 

看来has_bucket_size在这里什么都不做。无论稀疏张量中有多少非零值,只需将其设置为大于1的整数即可正常工作。

我知道我的理解可能不对。任何人都可以解释如何has_bucket_size的作品?非常感谢!

回答

1

hash_bucket_size通过获取原始索引,将它们散列到指定大小的空间中,并使用散列索引作为特征来工作。

这意味着您可以在知道可能指数的全部范围之前指定您的模型,但代价可能是某些指数可能发生碰撞。

+0

谢谢!现在我想如果能猜出一个涵盖整个真实数据范围的数字,那么它是完美的。如果猜测数字小于预期,它也可以工作,但可能会导致不确定的行为。 –

相关问题