2017-09-05 240 views
0

我已经训练了&测试了ML模型(GBTClassificationModel或RandomForestClassificationModel)。然后,我想保存经过训练的模型以供将来使用。所以我做了以下工作:如何加载训练过的RandomForestClassificationModel模型?

model.save("..."); 

例如,在保存它之后以GBTClassificationModel为例。保存的文件是包含“数据,元数据和treesMetadata”的目录。我的问题是如何使用这个保存的模型以供将来使用?例如,我想要做类似如下的事情:

model = spark.load("..."); 
Dataset<Row> predict_data= model_model.transform(dataset_test1) 

任何建议吗?谢谢。

UPDATE:

它原来是非常简单的:

GBTClassificationModel model1 = GBTClassificationModel.load("..."); 
Dataset<Row> predict_data= model1.transform(dataset_test) 

回答

2

您应该使用RandomForestClassificationModel.load方法。

负载(路径:字符串):RandomForestClassificationModel读取从输入路径,的read.load(path)快捷方式的ML实例。

在Scala中,你的情况,这将会是如下:

import org.apache.spark.ml.classification.RandomForestClassificationModel 
val model = RandomForestClassificationModel.load("/analytics_shared/qoe/km_model") 

我强烈建议使用星火MLlib的ML Pipeline功能:

ML管道提供一个统一的在DataFrame之上构建的一组高级API,可帮助用户创建和调整实用的机器学习管道。

随着ML管道它会这么容易,你只需用PipelineModel取代RandomForestClassificationModel

import org.apache.spark.ml.PipelineModel 
val model = PipelineModel.load("...")