Keras新手在这里。我在一个非常大的CSV文件上做了一些深入的学习实验(keras 2.x,tensorflow作为背景,python3.5)。如何创建一个在Keras模式下读取一个巨型数据框的线程安全生成器fit_generator
将CSV加载到Pandas数据框后,我需要读取数据帧以将数据转换为X_train,y_train/label。因为转换后的X_train非常大,不适合内存。我开始使用generator和model.fit_generator()。我已经了解到,通过创建一个线程安全的生成器,我可以使用多个工作器,并使用use_multiprocessing = True,以便更高效。然而,在我的情况下,在内部生成器中它总是读取相同的数据帧,我想知道如何使它成为线程安全的,因为相同的数据/行不会被多个生成器实例读取并生成?没有线程安全
我的电流发生器的实现是这样的:
data = pd.read_csv("data.csv", header=0, delimiter="\t", quoting=3, encoding="utf-8")
y = data.label
X_train, X_test, y_train, y_test = train_test_split(data, y, test_size=0.2)
def data_genereator(data, batch_size):
num_rows = int(data.shape[0])
# Initialize a counter
counter = 0
while True:
for content, label in zip(data['content'], data['label']):
X_train[counter%batch_size] = transform(content)
y_train[counter%batch_size] = np.asarray(label)
counter = counter + 1
if(counter%batch_size == 0):
yield X_train, y_train
training_generator = data_genereator(X_train, batch_size=1024)
validation_generator = data_genereator(X_test, batch_size=1024)
model = Sequential()
model.add(LSTM(64, input_shape=(1000, 2400), return_sequences=False,
kernel_initializer='he_normal', dropout=0.15, recurrent_dropout=0.15, implementation=2))
model.add(Dropout(0.3))
model.add(Dense(1, activation = 'sigmoid'))
model.compile(loss='binary_crossentropy',
optimizer='adam',
metrics=['accuracy'])
model.fit_generator(training_generator,
steps_per_epoch=8000,
validation_data=validation_generator,
epochs=3,
verbose=1,
workers=1,
use_multiprocessing=False,
validation_steps=2000)
我可能是完全错误的,但想要得到我的作品和use_multiprocessing参数的了解您的反馈,是多个生成器实例(如生产商)将被启动以将数据馈送到由model.fit_generator()函数创建/维护的队列中,同时将数据从队列中抓取到GPU以用于训练(消费者)。如果使用GPU进行培训不是瓶颈,那么发电机可以生产/生产的数据越多,整个过程就会越快。我默认了max_queue_size = 10,一旦生成器是线程安全的,如何定义正确的max_queue_size?
另外,有没有一种方法可以衡量天气发生器(生产者)或GPU培训(消费者)的瓶颈? 我使用verbose = 1来打印状态栏,以及单个线程生成器产生多少行。现在,它总是喜欢:
行数的产量=(max_queue_size +步数已处理)的batch_size *
所以我真的不能告诉如果发电机太慢喂在数据中或GPU训练是瓶颈的时候,似乎稍后队列总是满员,但我不确定,任何洞察力都非常感谢。谢谢!
Keras建议您使用'Sequence'此:https://keras.io/utils/ –
还是提到[这里](https://开头stanford.edu/~shervine/blog/keras-generator-multiprocessing.html),使用一个简单的锁定机制使迭代器/生成器线程安全 – scarecrow
感谢Daniel,再次:)我没有发现除https以外的太多示例: //gist.github.com/alxndrkalinin/6cc4228e9178ec4af7b2696a0d1ad5a1,会试试看。在我使用model.fit_generator()时,我注意到,在第二个时期,由于已经完成了半个步骤,准确度开始下降,它一直下降得很厉害,并且从未再次上升。你能否对这种情况有所了解?这是否在同一个时代过度适应?您能不能请我纠正我对Queue,多处理工作者和吞吐量瓶颈测量的理解? –