2016-12-05 300 views
2

我有一个 Spark (1.5.2) DataFrame 和一个已训练好的 RandomForestClassificationModel。我可以很容易地对数据做出预测，但我想深入分析：在每种二分类情形下，哪些边（特征/取值对）最常参与决策（随机森林分析）。

过去我在 RDD 上做过类似的事情：自己计算预测，从而跟踪各个特征的使用情况。在下面的代码中，我记录了计算预测时所用到的特征列表。然而 DataFrame 似乎不像 RDD 那样直截了当。

/**
 * Classifies `features` with one spark.mllib decision tree, recording (in visiting
 * order) the index of every feature tested on the way down to the leaf.
 *
 * @param node     the node to classify from (e.g. the tree's root)
 * @param features the feature vector being classified
 * @param path_in  feature indices already tested above `node`
 * @return the leaf's `predict.predict` and `predict.prob` values together with the
 *         accumulated path of tested feature indices
 */
def predict(node:Node, features: Vector, path_in:Array[Int]) : (Double,Double,Array[Int]) =
  if (!node.isLeaf) {
    val split = node.split.get
    // Record this split's feature before descending into a child.
    val visited = path_in :+ split.feature
    val value = features(split.feature)
    // Continuous splits go left on `<= threshold`; categorical splits go left when
    // the value is one of the split's listed categories.
    val goLeft = split.featureType match {
      case FeatureType.Continuous => value <= split.threshold
      case _                      => split.categories.contains(value)
    }
    val next = if (goLeft) node.leftNode.get else node.rightNode.get
    predict(next, features, visited)
  } else {
    (node.predict.predict, node.predict.prob, path_in)
  }

我想做类似上面代码的事情，但要针对每个特征向量返回所有"特征/边取值"对的列表。请注意，我的数据集中所有特征都是类别型的，并且在构建森林时已适当设置了分箱（bin）参数。

回答

0

我最终编写了一个自定义 UDF 来实现这一点：

/**
 * Runs every tree of a random forest on one feature vector and keeps each tree's result.
 *
 * @param m     the trained forest
 * @param point the feature vector to classify
 * @return `(meanPrediction, perTree)`, where `perTree` holds one
 *         `(prediction, impurity, (featureIndex, featureValue))` entry per tree, as
 *         produced by `predict(tree.rootNode, point)`, and `meanPrediction` is the
 *         arithmetic mean of those per-tree predictions. For 0/1 labels this is the
 *         fraction of trees predicting 1; it is not itself a class label.
 */
def predicForest(m:RandomForestClassificationModel, point: Vector) : (Double, Array[(Double,Double,(Int,Double))]) = {
  val results = m.trees.map(tree => predict(tree.rootNode, point))
  (results.map(_._1).sum / results.length, results)
}

/**
 * Walks one spark.ml decision tree for a single feature vector and reports the leaf
 * it reaches together with the edge that led into that leaf.
 *
 * @param node     the node to start from (the forest code passes the tree's `rootNode`)
 * @param features the feature vector to classify
 * @return `(prediction, impurity, (featureIndex, featureValue))` where `prediction`
 *         and `impurity` belong to the leaf that was reached, and the pair is the
 *         feature index and value tested on the final edge into that leaf.
 *         Only `CategoricalSplit`s are followed. If `node` is itself a leaf, or any
 *         other split type is met, traversal stops at that node and
 *         `(node.prediction, node.impurity, (-1, -1.0))` is returned. In the
 *         non-categorical case that prediction is the internal node's, not a leaf's.
 */
def predict(node:Node, features: Vector) : (Double,Double,(Int,Double)) = node match {
  case internal: InternalNode =>
    internal.split match {
      case split: CategoricalSplit =>
        val featureValue = features(split.featureIndex)
        val child =
          if (split.leftCategories.contains(featureValue)) internal.leftChild
          else internal.rightChild
        child match {
          // Report the leaf's own prediction/impurity: the internal node's values
          // describe the parent, not the branch that was actually taken.
          case leaf: LeafNode =>
            (leaf.prediction, leaf.impurity, (split.featureIndex, featureValue))
          case _ =>
            predict(child, features)
        }
      case _ =>
        // Unhandled split type: stop here with a sentinel edge.
        (node.prediction, node.impurity, (-1, -1.0))
    }
  case _ =>
    // Starting node is not internal: nothing to traverse, sentinel edge.
    (node.prediction, node.impurity, (-1, -1.0))
}

// Placeholder: replace the right-hand side with your trained RandomForestClassificationModel
// (it is passed to treeAnalyzer, which requires that type).
val rfModel = yourInstanceOfRandomForestClassificationModel 

/**
 * Builds a Spark SQL UDF that runs the forest on a feature vector in a trackable way,
 * delegating to `predicForest`.
 *
 * @param m the trained forest captured by the UDF closure
 * @return a UDF mapping a feature `Vector` to the `(meanPrediction, perTreeResults)`
 *         pair returned by `predicForest`
 */
def treeAnalyzer(m:RandomForestClassificationModel) = udf { (point: Vector) =>
  predicForest(m, point)
}

// Apply the tracking UDF to every row of testData. The new `prediction` column holds the
// pair returned by predicForest: the mean of the per-tree predictions, and the array of
// per-tree (prediction, impurity, (featureIndex, featureValue)) results.
// NOTE(review): if testData already has a `prediction` column, withColumn presumably
// replaces it — confirm against your Spark version.
val df3 = {
  val analyze = treeAnalyzer(rfModel)
  testData.withColumn("prediction", analyze(testData.col("indexedFeatures")))
}