2017-01-16 109 views
3

我想确保我正在对数据的分层样本进行培训。使用Spark和Java进行分层抽样

看来这是由Spark 2.1和早期版本通过JavaPairRDD.sampleByKey(...)JavaPairRDD.sampleByKeyExact(...)支持的,如here的解释。

但是:我的数据存储在Dataset<Row>中,而不是JavaPairRDD。第一列是标签,其他都是功能(从libsvm格式的文件导入)。

获得我的数据集实例的分层样本并最终获得Dataset<Row>的最简单方法是什么?

在某种程度上,这个问题与Dealing with unbalanced datasets in Spark MLlib有关。

这个possible duplicate根本没有提到Dataset<Row>,它也不在Java中。它不回答我的问题。

回答

3

好吧,既然the question here的回答竟是不打算的Java,我在的Java改写它。

推理仍然是一样的想法。我们仍在使用sampleByKeyExact。没有现成的奇迹功能现在(火花2.1.0

所以在这里你去:

package org.awesomespark.examples; 

import org.apache.spark.api.java.JavaPairRDD; 
import org.apache.spark.api.java.JavaRDD; 
import org.apache.spark.api.java.function.PairFunction; 
import org.apache.spark.sql.*; 
import scala.Tuple2; 

import java.util.Map; 

public class StratifiedDatasets { 
    public static void main(String[] args) { 
     SparkSession spark = SparkSession.builder() 
       .appName("Stratified Datasets") 
       .getOrCreate(); 

     Dataset<Row> data = spark.read().format("libsvm").load("sample_libsvm_data.txt"); 

     JavaPairRDD<Double, Row> rdd = data.toJavaRDD().keyBy(x -> x.getDouble(0)); 
     Map<Double, Double> fractions = rdd.map(Tuple2::_1) 
       .distinct() 
       .mapToPair((PairFunction<Double, Double, Double>) (Double x) -> new Tuple2(x, 0.8)) 
       .collectAsMap(); 

     JavaRDD<Row> sampledRDD = rdd.sampleByKeyExact(false, fractions, 2L).values(); 
     Dataset<Row> sampledData = spark.createDataFrame(sampledRDD, data.schema()); 

     sampledData.show(); 
     sampledData.printSchema(); 
    } 
} 

现在让我们打包和提交:

$ sbt package 
[...] 
// [success] Total time: 2 s, completed Jan 16, 2017 1:45:51 PM 

$ spark-submit --class org.awesomespark.examples.StratifiedDatasets target/scala-2.10/java-stratified-dataset_2.10-1.0.jar 
[...] 

// +-----+--------------------+ 
// |label|   features| 
// +-----+--------------------+ 
// | 0.0|(692,[127,128,129...| 
// | 1.0|(692,[158,159,160...| 
// | 1.0|(692,[124,125,126...| 
// | 1.0|(692,[152,153,154...| 
// | 1.0|(692,[151,152,153...| 
// | 0.0|(692,[129,130,131...| 
// | 1.0|(692,[99,100,101,...| 
// | 0.0|(692,[154,155,156...| 
// | 0.0|(692,[127,128,129...| 
// | 1.0|(692,[154,155,156...| 
// | 0.0|(692,[151,152,153...| 
// | 1.0|(692,[129,130,131...| 
// | 0.0|(692,[154,155,156...| 
// | 1.0|(692,[150,151,152...| 
// | 0.0|(692,[124,125,126...| 
// | 0.0|(692,[152,153,154...| 
// | 1.0|(692,[97,98,99,12...| 
// | 1.0|(692,[124,125,126...| 
// | 1.0|(692,[156,157,158...| 
// | 1.0|(692,[127,128,129...| 
// +-----+--------------------+ 
// only showing top 20 rows 

// root 
// |-- label: double (nullable = true) 
// |-- features: vector (nullable = true) 

对于python用户,你也可以查看我的回答Stratified sampling with pyspark