1
尝试使用Spark mllib中的渐变增强树来运行多分类分类。但它给出了一个错误“仅支持二进制分类”。因变量有8个等级。数据有276列和7000个实例。运行模型后使用渐变增强多类分类Spark中的树:仅支持二进制分类
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.feature.ChiSqSelector
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.mllib.util.MLUtils
val data = sc.textFile("data/mllib/train.csv")
val parsedData = data.map { line =>
val parts = line.split(',').map(_.toDouble)
LabeledPoint(parts(0), Vectors.dense(parts.tail))
}
// Discretize data in 10 equal bins since ChiSqSelector requires categorical features
// Even though features are doubles, the ChiSqSelector treats each unique value as a category
val discretizedData = parsedData.map { lp =>
LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => (x/20).floor }))
}
// Create ChiSqSelector that will select top 5 features
val selector = new ChiSqSelector(200)
// Create ChiSqSelector model (selecting features)
val transformer = selector.fit(discretizedData)
// Filter the top 5 features from each feature vector
val filteredData = discretizedData.map { lp =>
LabeledPoint(lp.label, transformer.transform(lp.features))
}
//Splitting the data
val splits = filteredData.randomSplit(Array(0.7, 0.3), seed = 11L)
val training = splits(0).cache()
val test = splits(1)
// Train a GradientBoostedTrees model.
// The defaultParams for Classification use LogLoss by default.
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 20 // Note: Use more iterations in practice.
boostingStrategy.treeStrategy.numClasses = 8
boostingStrategy.treeStrategy.maxDepth = 10
// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()
val model = GradientBoostedTrees.train(training, boostingStrategy)
错误:
java.lang.IllegalArgumentException: requirement failed: Only binary classificati
on is supported for boosting.
at scala.Predef$.require(Predef.scala:233)
at org.apache.spark.mllib.tree.configuration.BoostingStrategy.assertVali
d(BoostingStrategy.scala:60)
at org.apache.spark.mllib.tree.GradientBoostedTrees$.org$apache$spark$ml
lib$tree$GradientBoostedTrees$$boost(GradientBoostedTrees.scala:173)
at org.apache.spark.mllib.tree.GradientBoostedTrees.run(GradientBoostedT
rees.scala:71)
at org.apache.spark.mllib.tree.GradientBoostedTrees$.train(GradientBoost
edTrees.scala:143)
是否有可以做到这一点任何其他方式?
感谢Ben的建议 –