2016-11-14

I am currently trying to use a trained model inside an application. How can I use a model trained with tflearn in an application?

I have been using this code to generate U.S. city names with an LSTM model. The code works fine and I managed to generate city names.

Now I am trying to save the model so that I can load it into a different application without having to train the model again.

Here is the code of my basic application:

from __future__ import absolute_import, division, print_function 

import os 
from six import moves 
import ssl 
import tflearn 
from tflearn.data_utils import * 


path = "US_cities.txt" 
maxlen = 20 
X, Y, char_idx = textfile_to_semi_redundant_sequences(
    path, seq_maxlen=maxlen, redun_step=3) 


# --- Create LSTM model 
g = tflearn.input_data(shape=[None, maxlen, len(char_idx)]) 
g = tflearn.lstm(g, 512, return_seq=True, name="lstm1") 
g = tflearn.dropout(g, 0.5, name='dropout1') 
g = tflearn.lstm(g, 512, name='lstm2') 
g = tflearn.dropout(g, 0.5, name='dropout') 
g = tflearn.fully_connected(g, len(char_idx), activation='softmax', name='fc') 
g = tflearn.regression(g, optimizer='adam', loss='categorical_crossentropy', 
          learning_rate=0.001) 


# --- Initializing model and loading 
model = tflearn.models.generator.SequenceGenerator(g, char_idx) 
model.load('myModel.tfl') 
print("Model is now loaded !") 


# 
# Main Application 
# 

while True: 
    user_choice = input("Do you want to generate a U.S. city names ? [y/n]") 
    if user_choice == 'y': 
        seed = random_sequence_from_textfile(path, 20) 
        print("-- Test with temperature of 1.5 --") 
        model.generate(20, temperature=1.5, seq_seed=seed, display=True) 
    else: 
        exit() 

Here is the output I get:

Do you want to generate a U.S. city names ? [y/n]y 
-- Test with temperature of 1.5 -- 
rk 
Orange Park AcresTraceback (most recent call last): 
    File "App.py", line 46, in <module> 
    model.generate(20, temperature=1.5, seq_seed=seed, display=True) 
    File "/usr/local/lib/python3.5/dist-packages/tflearn/models/generator.py", line 216, in generate 
    preds = self._predict(x)[0] 
    File "/usr/local/lib/python3.5/dist-packages/tflearn/models/generator.py", line 180, in _predict 
    return self.predictor.predict(feed_dict) 
    File "/usr/local/lib/python3.5/dist-packages/tflearn/helpers/evaluator.py", line 69, in predict 
    o_pred = self.session.run(output, feed_dict=feed_dict).tolist() 
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 717, in run 
    run_metadata_ptr) 
    File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 894, in _run 
    % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape()))) 
ValueError: Cannot feed value of shape (1, 25, 61) for Tensor 'InputData/X:0', which has shape '(?, 20, 61)' 

Unfortunately, I don't understand why the shape has changed when generate() is called from my application. Can anyone help me fix this?

Thanks in advance

William


This won't completely solve your problem, but you could try adding `seq_maxlen=20` to `tflearn.models.generator.SequenceGenerator`. I suspect the `25` comes from that constructor argument. – sygi


Hi sygi, thanks for your reply, and sorry for the late response. I changed seq_maxlen and the shape problem is now fixed! However, as you said, this doesn't solve everything... the generated names are not new at all. I tried adding checkpoint_path to the constructor, but that still didn't change anything. –
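The mismatch can be illustrated without tflearn: before prediction, the seed is one-hot encoded into an array of shape (1, seq_maxlen, vocab_size), and SequenceGenerator's seq_maxlen defaults to 25, while the network's input layer was built with maxlen = 20. A minimal sketch (vocab size 61 and the default of 25 are taken from the traceback; the shape helper is a simplified stand-in, not tflearn's actual encoding code):

```python
def one_hot_shape(seed_len, vocab_size):
    """Shape of a single one-hot encoded seed: (batch, timesteps, vocab)."""
    return (1, seed_len, vocab_size)

maxlen = 20               # what input_data(shape=[None, maxlen, 61]) expects
default_seq_maxlen = 25   # SequenceGenerator default when seq_maxlen is not passed

fed = one_hot_shape(default_seq_maxlen, 61)
expected = (None, maxlen, 61)

print(fed)       # (1, 25, 61) -- exactly the shape in the ValueError
print(expected)  # (None, 20, 61) -- what the placeholder 'InputData/X:0' wants
```

Passing `seq_maxlen=maxlen` to the SequenceGenerator constructor makes the two shapes agree, which is why sygi's suggestion fixes the error.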

Answer


Solved?

One workaround is simply to add a "mode" argument to the Python script using an argument parser:

import argparse 
parser = argparse.ArgumentParser() 
parser.add_argument("mode", help="Train or/and test", nargs='+', choices=["train","test"]) 
args = parser.parse_args() 

Then:

if "train" in args.mode:   # args.mode is a list because of nargs='+' 
    # define your model 
    # train the model 
    model.save('my_model.tflearn') 

if "test" in args.mode: 
    model.load('my_model.tflearn') 
    # do whatever you want with your model 
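Note that because of nargs='+', argparse stores mode as a list, so comparing it to a plain string never matches; a membership test ("train" in args.mode) is the reliable check. A quick stdlib-only sketch, mirroring the parser above:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("mode", help="Train or/and test",
                    nargs='+', choices=["train", "test"])

# nargs='+' collects the positional arguments into a list.
args = parser.parse_args(["train"])
print(args.mode)             # ['train']
print(args.mode == "train")  # False -- it is a list, not a string
print("train" in args.mode)  # True
```

This also lets you run both phases in one call, e.g. `python App.py train test`.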

I don't really understand why this works, or why loading fails when you try to load the model from a different script. But I guess this should be fine for now...
