我能够训练我的模型并使用ML引擎进行预测,但我的结果不包含任何识别信息。在提交一行时一次提交预测,但在提交多行时,这种方法可以正常工作,但我无法将预测连接回原始输入数据。 GCP documentation讨论了使用实例密钥,但我找不到任何使用实例密钥进行训练和预测的示例代码。以GCP人口普查为例,我将如何更新输入函数以通过图表传递一个唯一的ID,并在训练期间忽略它,但返回具有预测的唯一ID?或者,如果有人知道已经在使用密钥的另一个示例,那也可以提供帮助。使用实例密钥进行训练和预测
def serving_input_fn():
feature_placeholders = {
column.name: tf.placeholder(column.dtype, [None])
for column in INPUT_COLUMNS
}
features = {
key: tf.expand_dims(tensor, -1)
for key, tensor in feature_placeholders.items()
}
return input_fn_utils.InputFnOps(
features,
None,
feature_placeholders
)
def generate_input_fn(filenames,
num_epochs=None,
shuffle=True,
skip_header_lines=0,
batch_size=40):
def _input_fn():
files = tf.concat([
tf.train.match_filenames_once(filename)
for filename in filenames
], axis=0)
filename_queue = tf.train.string_input_producer(
files, num_epochs=num_epochs, shuffle=shuffle)
reader = tf.TextLineReader(skip_header_lines=skip_header_lines)
_, rows = reader.read_up_to(filename_queue, num_records=batch_size)
row_columns = tf.expand_dims(rows, -1)
columns = tf.decode_csv(row_columns, record_defaults=CSV_COLUMN_DEFAULTS)
features = dict(zip(CSV_COLUMNS, columns))
# Remove unused columns
for col in UNUSED_COLUMNS:
features.pop(col)
if shuffle:
features = tf.train.shuffle_batch(
features,
batch_size,
capacity=batch_size * 10,
min_after_dequeue=batch_size*2 + 1,
num_threads=multiprocessing.cpu_count(),
enqueue_many=True,
allow_smaller_final_batch=True
)
label_tensor = parse_label_column(features.pop(LABEL_COLUMN))
return features, label_tensor
return _input_fn
更新: 我能够使用建议的代码this answer below我只是需要改变它略微以更新model_fn_ops而不只是预测字典的输出方案。但是,这只有在我的服务输入功能针对类似于this的json输入进行编码时才有效。我的服务输入功能先前是在Census Core Sample中的CSV服务输入功能之后建模的。
我觉得我的问题来自build_standardized_signature_def函数,甚至更多,所以它调用的功能is_classification_problem。使用csv服务函数的输入字典长度为1,因此该逻辑使用classification_signature_def结束,其最终只显示分数(结果实际上是probabilities),而输入字典长度大于1且具有json服务输入功能而是使用包含所有输出的predict_signature_def。
这是ModelServer中的分类标记(CMLE用于推理)中的已知问题。在1.2中,EstimatorSpec允许您选择自己的导出方法,因此希望能够为您解决问题,但是需要重写才能使用tf.estimator.Estimator而不是tf.contrib.learn.Estimator。 –