
I managed to get my decision tree classifier working with the RDD-based API, but now I am trying to switch to the DataFrame-based API in Spark, and to use a decision tree classifier on a DataFrame that has string fields.

I have a dataset like this (but with many more fields):

country, destination, duration, label

Belgium, France, 10, 0 
Bosnia, USA, 120, 1 
Germany, Spain, 30, 0 

First, I load my CSV file into a DataFrame:

val data = session.read 
    .format("org.apache.spark.csv") 
    .option("header", "true") 
    .csv("/home/Datasets/data/dataset.csv") 

Then I convert the string columns into numeric columns:

val stringColumns = Array("country", "destination") 

val index_transformers = stringColumns.map(
    cname => new StringIndexer() 
    .setInputCol(cname) 
    .setOutputCol(s"${cname}_index") 
) 
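
For context, each StringIndexer is an estimator, so it has to be fit on the data before it can transform anything. A minimal standalone sketch of what one of them produces (the values are just an assumption based on the sample rows above):

import org.apache.spark.ml.feature.StringIndexer 

// Fit a single indexer on the country column and look at the numeric output 
val countryIndexer = new StringIndexer() 
    .setInputCol("country") 
    .setOutputCol("country_index") 

val countryIndexed = countryIndexer.fit(data).transform(data) 
countryIndexed.select("country", "country_index").show(3) 
// e.g. Belgium -> 0.0, Bosnia -> 1.0, Germany -> 2.0 (the order depends on label frequency) 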

Then I assemble all my features into a single vector, using a VectorAssembler, like this:

val assembler = new VectorAssembler() 
    .setInputCols(Array("country_index", "destination_index", "duration_index")) 
    .setOutputCol("features") 

I split my data into training and test sets:

val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) 

Then I create my DecisionTree classifier:

val dt = new DecisionTreeClassifier() 
    .setLabelCol("label") 
    .setFeaturesCol("features") 

Then I use a Pipeline to run all the transformations:

val pipeline = new Pipeline() 
    .setStages(Array(index_transformers, assembler, dt)) 

I train my model and use it to make predictions:

val model = pipeline.fit(trainingData) 

val predictions = model.transform(testData) 

But I get errors that I don't understand. When I run my code as written above, I get this error:

[error] found : Array[org.apache.spark.ml.feature.StringIndexer] 
[error] required: org.apache.spark.ml.PipelineStage 
[error]   .setStages(Array(index_transformers, assembler,dt)) 

So what I did was add a pipeline right after the index_transformers val and right before the assembler val:

val index_pipeline = new Pipeline().setStages(index_transformers) 
val index_model = index_pipeline.fit(data) 
val df_indexed = index_model.transform(data) 

Then I use my new df_indexed DataFrame for my training and test sets, and my pipeline keeps only the assembler and dt:

val Array(trainingData, testData) = df_indexed.randomSplit(Array(0.7, 0.3)) 

val pipeline = new Pipeline() 
    .setStages(Array(assembler,dt)) 

With the index_transformers removed, I get this error:

Exception in thread "main" java.lang.IllegalArgumentException: Data type StringType is not supported. 

It basically says that I am using the VectorAssembler on strings, even though I told it to use df_indexed, which now has the numeric *_index columns; it just does not seem to use them in the VectorAssembler, and I don't understand why...
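
A quick check that might help reproduce this: printing the schema of df_indexed shows which columns are numeric after indexing and which ones are still strings:

// Sanity check: list the column names and types after indexing, 
// so it is clear which *_index columns exist and what is still StringType 
df_indexed.printSchema() 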

Thanks

EDIT

Now I have almost managed to get it working:

val data = session.read 
    .format("org.apache.spark.csv") 
    .option("header", "true") 
    .csv("/home/hvfd8529/Datasets/dataOINIS/dataset.csv") 

val stringColumns = Array("country_index", "destination_index", "duration_index") 

val stringColumns_index = stringColumns.map(c => s"${c}_index") 

val index_transformers = stringColumns.map(
    cname => new StringIndexer() 
    .setInputCol(cname) 
    .setOutputCol(s"${cname}_index") 
) 

val assembler = new VectorAssembler() 
    .setInputCols(stringColumns_index) 
    .setOutputCol("features") 

val labelIndexer = new StringIndexer() 
    .setInputCol("label") 
    .setOutputCol("indexedLabel") 

val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) 

// Train a DecisionTree model. 
val dt = new DecisionTreeClassifier() 
    .setLabelCol("indexedLabel") 
    .setFeaturesCol("features") 
    .setImpurity("entropy") 
    .setMaxBins(1000) 
    .setMaxDepth(15) 

// Convert indexed labels back to original labels. 
val labelConverter = new IndexToString() 
    .setInputCol("prediction") 
    .setOutputCol("predictedLabel") 
    .setLabels(labelIndexer.labels()) 

val stages = index_transformers :+ assembler :+ labelIndexer :+ dt :+ labelConverter 

val pipeline = new Pipeline() 
    .setStages(stages) 


// Train model. This also runs the indexers. 
val model = pipeline.fit(trainingData) 

// Make predictions. 
val predictions = model.transform(testData) 

// Select example rows to display. 
predictions.select("predictedLabel", "label", "features").show(5) 

// Select (prediction, true label) and compute test error. 
val evaluator = new MulticlassClassificationEvaluator() 
    .setLabelCol("indexedLabel") 
    .setPredictionCol("prediction") 
    .setMetricName("accuracy") 
val accuracy = evaluator.evaluate(predictions) 
println("accuracy = " + accuracy) 

val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] 
println("Learned classification tree model:\n" + treeModel.toDebugString) 

Only now I get an error that says:

value labels is not a member of org.apache.spark.ml.feature.StringIndexer 

which I don't understand, since I am following the example from the Spark docs :/

Answers


For my first problem, this is what I did:

val stages = index_transformers :+ assembler :+ labelIndexer :+ rf :+ labelConverter 

val pipeline = new Pipeline() 
    .setStages(stages) 

For my second problem with the labels, I needed to use .fit(data), like this:

val labelIndexer = new StringIndexer() 
    .setInputCol("label_fraude") 
    .setOutputCol("indexedLabel") 
    .fit(data) 
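
With .fit(data), labelIndexer is a StringIndexerModel rather than a StringIndexer, and only the fitted model exposes the labels. A minimal sketch of how it then feeds the IndexToString converter (same names as above):

// labelIndexer is now a fitted StringIndexerModel, so .labels is available 
// (note: no parentheses, it is a val) and can be handed to the converter 
val labelConverter = new IndexToString() 
    .setInputCol("prediction") 
    .setOutputCol("predictedLabel") 
    .setLabels(labelIndexer.labels) 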

Another answer: it should be:

val pipeline = new Pipeline() 
    .setStages(index_transformers ++ Array(assembler, dt): Array[PipelineStage]) 
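
Putting the two fixes together, the pipeline ends up looking roughly like this (a sketch based on the snippets above, combining the up-cast to PipelineStage with the fitted label indexer; not tested against the original data):

import org.apache.spark.ml.{Pipeline, PipelineStage} 

// Concatenate the indexers with the remaining stages and up-cast everything 
// to PipelineStage, which is the element type setStages expects 
val stages: Array[PipelineStage] = 
    index_transformers ++ Array(assembler, labelIndexer, dt, labelConverter) 

val pipeline = new Pipeline().setStages(stages) 
val model = pipeline.fit(trainingData) 
val predictions = model.transform(testData) 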

I still get the same error :( I also tried val stages = index_transformers :+ assembler :+ dt and val pipeline = new Pipeline().setStages(stages), but it does not work: java.lang.IllegalArgumentException: Data type StringType is not supported. –