我的raw_data是PTB数据集。 我正在通过以下代码生成批次。feed_dict中喂食问题(Tensorflow)
def generate_batches(raw_data, batch_size, unrollings):
global data_index
data_len = len(raw_data)
num_batches = data_len // batch_size
inputs = []
labels = []
print (num_batches, data_len, batch_size)
for j in xrange(unrollings) :
inputs.append([])
labels.append([])
for i in xrange(batch_size) :
inputs[j].append(raw_data[i + data_index])
labels[j].append(raw_data[i + data_index + 1])
data_index = (data_index + batch_size) % len(raw_data)
return inputs, labels
在会话运行中,生成的相同批生产饲料feed_dict,如以下代码中所示。
for step in xrange(num_steps) :
batch_inputs, batch_labels = generate_batches(train_dataset, batch_size, unrollings=5)
feed_dict = dict()
for i in range(unrollings):
feed_dict = {train_inputs : batch_inputs, train_labels : batch_labels}
_, l, predictions, lr = session.run([optimizer, loss, train_prediction, learning_rate], feed_dict=feed_dict)
的培训投入和标签如下:
for _ in range(unrollings) :
train_data.append(tf.placeholder(shape=[batch_size], dtype=tf.int32))
train_label.append(tf.placeholder(shape=[batch_size, 1], dtype=tf.float32))
train_inputs = train_data[:unrollings]
train_labels = train_label[:unrollings]
首先,我得到了错误TypeError: unhashable type: 'list'
到我转换batch_input列表使用tuple(batch_input[i])
这是在Python dictionary : TypeError: unhashable type: 'list'解释清楚元组。
解决:然后我得到这个错误TypeError: unhashable type: 'numpy.ndarray'
。
。
您试图使用ndarray作为字典的关键字,它应该是字符串 –
谢谢+1我更正了代码。 – SupposeXYZ