I want to read a dataset and use a decision tree to perform classification, but once I reach the training step I get an error (shown below) when trying to train the decision tree classifier.
What I have done so far:
Step 1: Read the data
I have a .txt file in the format (text \t label), tab-separated:
val data = sparkSession.read.format("com.databricks.spark.csv")
  .option("delimiter", "\t")
  .load("data.txt")
This produces the following:
+---------+-------+
| text | label |
+---------+-------+
| text_1 | 0 |
| text_2 | 1 |
| text_3 | 1 |
| text_4 | 0 |
+---------+-------+
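An aside on the schema: since no header row or schema is supplied to the reader, spark-csv would normally name these columns _c0 and _c1 rather than text and label. One hedged way (assuming data.txt really has no header line) to get the column names shown above is:

```scala
// Assumption: data.txt has no header row, so spark-csv assigns default
// column names (_c0, _c1); toDF renames them to match the table above.
// inferSchema makes the label column come back numeric instead of String.
val data = sparkSession.read.format("com.databricks.spark.csv")
  .option("delimiter", "\t")
  .option("inferSchema", "true")
  .load("data.txt")
  .toDF("text", "label")
```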
Step 2: Split the data
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
Step 3: Parameter tuning
val numClasses = 2
val categoricalFeaturesInfo = Map[String, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32
Step 4: Training
val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  impurity, maxDepth, maxBins)
At this step I get the following error:
Main.scala:63: overloaded method value trainClassifier with alternatives:
(input: org.apache.spark.api.java.JavaRDD[org.apache.spark.mllib.regression.LabeledPoint],numClasses: Int,categoricalFeaturesInfo: java.util.Map[Integer,Integer],impurity: String,maxDepth: Int,maxBins: Int)org.apache.spark.mllib.tree.model.DecisionTreeModel <and>
(input: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint],numClasses: Int,categoricalFeaturesInfo: scala.collection.immutable.Map[Int,Int],impurity: String,maxDepth: Int,maxBins: Int)org.apache.spark.mllib.tree.model.DecisionTreeModel
cannot be applied to (org.apache.spark.sql.Dataset[org.apache.spark.sql.Row], Int, scala.collection.immutable.Map[String,Int], String, Int, Int)
val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
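As the error message shows, the old org.apache.spark.mllib API expects an RDD[LabeledPoint] and a Map[Int, Int], not a Dataset[Row] and a Map[String, Int]. A minimal sketch of a conversion that would satisfy that signature, assuming the first column is the raw text and the second the label as in the table above (the HashingTF feature-extraction step here is an illustrative choice, not part of the original code):

```scala
import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree

// Turn each row into a LabeledPoint: a Double label plus a feature
// Vector. HashingTF maps the tokenized text to a term-frequency Vector.
val tf = new HashingTF(numFeatures = 1000)
val labeled = trainingData.rdd.map { row =>
  LabeledPoint(row.getString(1).toDouble,
               tf.transform(row.getString(0).split(" ").toSeq))
}

// categoricalFeaturesInfo must be keyed by feature index (Int -> arity),
// not by column name (String).
val model = DecisionTree.trainClassifier(labeled, numClasses,
  Map[Int, Int](), impurity, maxDepth, maxBins)
```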
The complete code is shown below:
import org.apache.spark.sql.SparkSession
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
object DC_classifier {
  def main() {
    val sparkSession = SparkSession.builder
      .master("local")
      .appName("Decision tree")
      .getOrCreate()

    val sc = sparkSession.sparkContext
    import sparkSession.implicits._

    val data = sparkSession.read.format("com.databricks.spark.csv")
      .option("delimiter", "\t")
      .load("data.txt")

    val splits = data.randomSplit(Array(0.7, 0.3))
    val (trainingData, testData) = (splits(0), splits(1))

    val numClasses = 2
    val categoricalFeaturesInfo = Map[String, Int]()
    val impurity = "gini"
    val maxDepth = 5
    val maxBins = 32

    val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
      impurity, maxDepth, maxBins)
  }
}
DC_classifier.main()
Any help would be greatly appreciated.
Wrong APIs: 'RDD' -> 'org.apache.spark.mllib', 'Dataset' -> 'org.apache.spark.ml' (in this case https://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier). Not to mention the correct types ('LabeledPoint', 'Vector' column) and feature extraction/selection. – zero323
Could you be more specific? Thank you. –
First go to https://spark.apache.org/docs/latest/ml-guide.html#example-pipeline, decide which API you are interested in - the main one (= DataFrame/Dataset) or RDD - and follow the examples. For the RDD API, be sure to check the signatures. – zero323
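Following that pointer, a hedged sketch of the DataFrame-based org.apache.spark.ml route could look like the following, assuming a DataFrame with "text" and a numeric "label" column as above (the Tokenizer/HashingTF feature-extraction stages are illustrative choices):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

// Tokenize the text, hash the tokens into a feature Vector, then fit
// the ml (DataFrame-based) decision tree on the resulting columns.
val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
val dt = new DecisionTreeClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setImpurity("gini")
  .setMaxDepth(5)
  .setMaxBins(32)

val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, dt))
val model = pipeline.fit(trainingData)   // works directly on the Dataset
val predictions = model.transform(testData)
```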