2016-11-08 97 views
2

我的raw_data是PTB数据集。 我正在通过以下代码生成批次。feed_dict中喂食问题(Tensorflow)

def generate_batches(raw_data, batch_size, unrollings): 
    global data_index 
    data_len = len(raw_data) 
    num_batches = data_len // batch_size 
    inputs = [] 
    labels = [] 
    print (num_batches, data_len, batch_size) 
    for j in xrange(unrollings) : 
     inputs.append([]) 
     labels.append([]) 
     for i in xrange(batch_size) : 
     inputs[j].append(raw_data[i + data_index]) 
     labels[j].append(raw_data[i + data_index + 1])  
     data_index = (data_index + batch_size) % len(raw_data) 
    return inputs, labels 

在会话运行中,生成的相同批生产饲料feed_dict,如以下代码中所示。

for step in xrange(num_steps) : 
batch_inputs, batch_labels = generate_batches(train_dataset, batch_size, unrollings=5) 
feed_dict = dict() 
for i in range(unrollings): 
    feed_dict = {train_inputs : batch_inputs, train_labels : batch_labels} 
    _, l, predictions, lr = session.run([optimizer, loss, train_prediction, learning_rate], feed_dict=feed_dict) 

的培训投入和标签如下:

for _ in range(unrollings) : 
train_data.append(tf.placeholder(shape=[batch_size], dtype=tf.int32)) 
train_label.append(tf.placeholder(shape=[batch_size, 1], dtype=tf.float32)) 
train_inputs = train_data[:unrollings] 
train_labels = train_label[:unrollings] 

首先,我得到了错误TypeError: unhashable type: 'list'到我转换batch_input列表使用tuple(batch_input[i])这是在Python dictionary : TypeError: unhashable type: 'list'解释清楚元组。
解决:然后我得到这个错误TypeError: unhashable type: 'numpy.ndarray'

+1

您试图使用ndarray作为字典的关键字,它应该是字符串 –

+0

谢谢+1我更正了代码。 – SupposeXYZ

回答

1

我想你是误解feed_dict是如何工作的。但首先,python dict不接受任何不可关联的类作为关键的实例。 list和numpy.ndarray都不能用作dict键(即使你用一个元组包装它)。我发现list post解释关于字典的关键。

feed_dict如何工作

在图形中,应该有象征意义的张量创建占位符。假设您的原始数据是2D:(num_samples,num_features),第一个维度对应于样本的大小,第二个维度对应于特征的数量。假设标签是一种热门编码,并且总共有num_classes。

train_data = tf.placeholder(shape=[batch_size, num_features], dtype=tf.float32) 
train_labels = tf.placeholder(shape=[batch_size, num_classes], dtype=tf.float32) 
在会话建立feed_dict时

然后,使用这些符号占位张量的关键和采样batch_data的价值。

feed_dict = {train_data:batch_inputs, train_labels:batch_labels}