2017-06-21 591 views
4

我正在解决文本分类问题。我用我自己的model_fn使用Estimator类定义了我的分类器。我想使用Google的预先训练好的word2vec嵌入作为初始值,然后针对当前的任务对其进行进一步优化。加载预先训练好的word2vec在Estimator中初始化embedding_lookup model_fn

我看到这篇文章:Using a pre-trained word embedding (word2vec or Glove) in TensorFlow
它解释了如何在'原始'TensorFlow代码中去解决它。但是,我真的很想使用Estimator类。

作为一个扩展,我想在Cloud ML引擎上训练这个代码,是否有一种很好的方式来传递具有初始值的相当大的文件?

比方说,我们有这样的事:

def build_model_fn(): 
    def _model_fn(features, labels, mode, params): 
     input_layer = features['feat'] #shape=[-1, params["sequence_length"]] 
     #... what goes here to initialize W 

     embedded = tf.nn.embedding_lookup(W, input_layer) 
     ... 
     return predictions 

estimator = tf.contrib.learn.Estimator(
    model_fn=build_model_fn(), 
    model_dir=MODEL_DIR, 
    params=params) 
estimator.fit(input_fn=read_data, max_steps=2500) 

回答

7

曲面嵌入足够,唯一可行的办法是用它们来初始化在图中的tf.Variable一般较大。这将允许你利用分布式的参数服务器等。

对于这个(和其他任何东西),我建议你使用新的“核心”估计器,因为这会使事情变得更容易。

从您提供的链接的答案,知道我们想要一个变量不是一个常量,我们可以采取的方法:

(2)使用饲料字典初始化变量,或 (3)从检查点加载可变


我将选择(3)第一,因为它更容易,更好地:

在你model_fn,只需初始化使用Tensor返回的变量由tf.contrib.framework.load_variable打电话。这就要求:

  1. ,你必须与你的嵌入
  2. 你知道检查站内的嵌入变量的完全合格名有效TF检查点。

的代码非常简单:

def model_fn(mode, features, labels, hparams): 
    embeddings = tf.Variable(tf.contrib.framework.load_variable(
     'gs://my-bucket/word2vec_checkpoints/', 
     'a/fully/qualified/scope/embeddings' 
)) 
    .... 
    return tf.estimator.EstimatorSpec(...) 

但是这种方法不会为你工作,如果你的嵌入不是由另一TF模型产生的,因此选项(2)。


对于(2),我们需要使用tf.train.Scaffold其基本上保持所有的选项用于开始tf.Session(其估计有意隐藏有许多原因)的配置对象。

您可以在model_fn返回的tf.train.EstimatorSpec中指定Scaffold

我们在我们的model_fn中创建一个占位符,并将其设为 对我们的嵌入变量进行初始化操作,然后通过Scaffold传递init_feed_dict。例如

def model_fn(mode, features, labels, hparams): 
    embed_ph = tf.placeholder(
     shape=[hparams.vocab_size, hparams.embedding_size], 
     dtype=tf.float32) 
    embeddings = tf.Variable(embed_ph) 
    # Define your model 
    return tf.estimator.EstimatorSpec(
     ..., # normal EstimatorSpec args 
     scaffold=tf.train.Scaffold(init_feed_dict={embed_ph: my_embedding_numpy_array}) 
) 

这里发生的是init_feed_dict将填充embed_ph占位符在运行时的值,那么这将使embeddings.initialization_op(占位符的分配),以运行。


+0

谢谢,只是一个很小的事情:它应该是'tf.estimator.EstimatorSpec(...,支架= tf.train.Scaffold(ini​​t_feed_dict = {embed_ph:my_embedding_numpy_array})' – Tristan

+0

感谢特里斯坦净距那语法,即使我有解释大声笑。 –

相关问题