2017-03-02 100 views

I want to read a dataset and perform classification with a decision tree. Once I reach the training step, I get an error (shown below) when trying to train the decision tree classifier.

What I have done so far:

Step 1: Read the data

I have a .txt file in tab-separated (text \t label) format:

val data = sparkSession.read.format("com.databricks.spark.csv") 
     .option("delimiter", "\t") 
     .load("data.txt") 

This produces the following:

+--------+-------+ 
| text   | label | 
+--------+-------+ 
| text_1 | 0     | 
| text_2 | 1     | 
| text_3 | 1     | 
| text_4 | 0     | 
+--------+-------+ 
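(As an aside: without `.option("header", "true")`, the CSV reader names the columns `_c0` and `_c1` rather than `text` and `label`. A minimal sketch of one way to get the names shown above, assuming the file itself has no header row:)

```scala
val data = sparkSession.read.format("com.databricks.spark.csv")
  .option("delimiter", "\t")
  .load("data.txt")
  .toDF("text", "label")   // rename the default _c0/_c1 columns
```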

Step 2: Split the data

val splits = data.randomSplit(Array(0.7, 0.3)) 
val (trainingData, testData) = (splits(0), splits(1)) 

Step 3: Parameter tuning

val numClasses = 2 
val categoricalFeaturesInfo = Map[String, Int]() 
val impurity = "gini" 
val maxDepth = 5 
val maxBins = 32 

Step 4: Training

val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, 
    impurity, maxDepth, maxBins) 

At this step I get the following error:

Main.scala:63: overloaded method value trainClassifier with alternatives: 
    (input: org.apache.spark.api.java.JavaRDD[org.apache.spark.mllib.regression.LabeledPoint],numClasses: Int,categoricalFeaturesInfo: java.util.Map[Integer,Integer],impurity: String,maxDepth: Int,maxBins: Int)org.apache.spark.mllib.tree.model.DecisionTreeModel <and> 
    (input: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint],numClasses: Int,categoricalFeaturesInfo: scala.collection.immutable.Map[Int,Int],impurity: String,maxDepth: Int,maxBins: Int)org.apache.spark.mllib.tree.model.DecisionTreeModel 
cannot be applied to (org.apache.spark.sql.Dataset[org.apache.spark.sql.Row], Int, scala.collection.immutable.Map[String,Int], String, Int, Int) 
     val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, 
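(Reading the error term by term: both overloads of `trainClassifier` want an `RDD[LabeledPoint]` and a `Map[Int, Int]` keyed by feature index, but the call passes a `Dataset[Row]` and a `Map[String, Int]`. A hedged sketch of bridging to the RDD API follows; the `featurize` function is a placeholder I am inventing for illustration, since `LabeledPoint` needs a numeric `Vector` and the raw text must first go through some feature-extraction step such as hashing TF:)

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Placeholder feature extraction: NOT part of the original code. In practice
// this would be e.g. a hashing-TF transform producing a fixed-length vector.
def featurize(text: String): Array[Double] = ???

val labeled = trainingData.rdd.map { row =>
  val label    = row.getString(1).toDouble        // the "label" column (read as String by the CSV source)
  val features = Vectors.dense(featurize(row.getString(0)))
  LabeledPoint(label, features)
}

// categoricalFeaturesInfo must be keyed by feature index (Int), not column name (String)
val model = DecisionTree.trainClassifier(
  labeled, numClasses, Map[Int, Int](), impurity, maxDepth, maxBins)
```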

The full code is shown below:

import org.apache.spark.sql.SparkSession 
import org.apache.spark.mllib.util.MLUtils 
import org.apache.spark.mllib.tree.DecisionTree 
import org.apache.spark.mllib.tree.model.DecisionTreeModel 

object DC_classifier { 
    def main() { 

     val sparkSession = SparkSession.builder 
      .master("local") 
      .appName("Decision tree") 
      .getOrCreate() 

     val sc = sparkSession.sparkContext 
     import sparkSession.implicits._ 

     val data = sparkSession.read.format("com.databricks.spark.csv") 
      .option("delimiter", "\t") 
      .load("data.txt") 

     val splits = data.randomSplit(Array(0.7, 0.3)) 
     val (trainingData, testData) = (splits(0), splits(1)) 

     val numClasses = 2 
     val categoricalFeaturesInfo = Map[String, Int]() 
     val impurity = "gini" 
     val maxDepth = 5 
     val maxBins = 32 

     val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, 
      impurity, maxDepth, maxBins) 

    } 

} 

DC_classifier.main() 

Any help would be greatly appreciated.


Wrong API: `RDD` -> `org.apache.spark.mllib`, `Dataset` -> `org.apache.spark.ml` (in this case https://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier). Not to mention the correct types (`LabeledPoint`, `Vector` columns) and feature extraction/selection. – zero323


Could you be more specific? Thanks. –


First go to https://spark.apache.org/docs/latest/ml-guide.html#example-pipeline, decide which API you are interested in - the main one (= DataFrame/Dataset) or RDD - and follow the examples. For the RDD API, be sure to check the signatures. – zero323
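(Following the pipeline example linked in the comment above, here is a sketch of the Dataset-based `spark.ml` route. The `Tokenizer`/`HashingTF` feature-extraction stages and the column names are illustrative choices, not from the original post:)

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.{HashingTF, StringIndexer, Tokenizer}

// Feature extraction: split the text column into tokens, then hash the
// tokens into a fixed-size numeric feature vector.
val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")

// Index the string label column into numeric label indices.
val indexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel")

val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("features")
  .setImpurity("gini")
  .setMaxDepth(5)

// The pipeline accepts the Dataset directly - no RDD[LabeledPoint] needed.
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, indexer, dt))
val model = pipeline.fit(trainingData)
val predictions = model.transform(testData)
```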

Answers