Combine multiple columns in a Spark DataFrame [Java]

How do I combine multiple columns (e.g., 3) of a DataFrame into a single Spark DenseVector, in a single column of a new DataFrame? Similar to this thread, but in Java, and with a few tweaks mentioned below.

I tried using a UDF like this:

private UDF3<Double, Double, Double, Row> toColumn = new UDF3<Double, Double, Double, Row>() {

    private static final long serialVersionUID = 1L;

    public Row call(Double first, Double second, Double third) throws Exception {
        // build a dense vector from the three values and wrap it in a Row
        Row row = RowFactory.create(Vectors.dense(first, second, third));
        return row;
    }
};

and then registered the UDF:

sqlContext.udf().register("toColumn", toColumn, dataType); 

where dataType is:

StructType dataType = DataTypes.createStructType(new StructField[]{
    new StructField("bla", new VectorUDT(), false, Metadata.empty()),
});

When I apply this UDF to a DataFrame with 3 columns and print out the schema of the new DataFrame, I get this:

root
 |-- features: struct (nullable = true)
 |    |-- bla: vector (nullable = false)

The problem here is that I need the vector to be at the top level, not inside a struct. Something like this:

root
 |-- features: vector (nullable = true)

I don't see how to get this, since the register function requires the UDF's return type to be a DataType (which, in turn, does not offer a VectorType).

Answer


You actually nested the vector type inside a struct manually, by using this data type:

new StructField("bla", new VectorUDT(), false, Metadata.empty()), 

If you remove the outer StructField, you get exactly what you want. Of course, in that case you need to change the signature of your function definition; that is, it has to return the Vector type. Note that VectorUDT itself extends DataType, so it can be passed directly to register.
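
In essence, the fix boils down to the following sketch, pulled out of the full test further below:

private UDF3<Double, Double, Double, Vector> toColumn = new UDF3<Double, Double, Double, Vector>() {

    private static final long serialVersionUID = 1L;

    public Vector call(Double first, Double second, Double third) throws Exception {
        // return the Vector directly -- no Row wrapper, hence no enclosing struct
        return Vectors.dense(first, second, third);
    }
};

// VectorUDT is itself a DataType, so it can serve as the registered return type
sqlContext.udf().register("toColumn", toColumn, new VectorUDT());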

See below for a concrete example of what I mean, in the form of a simple JUnit test.

package sample.spark.test;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF3;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.junit.Test;

import java.io.Serializable;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class ToVectorTest implements Serializable {
    private static final long serialVersionUID = 2L;

    // the UDF now returns a Vector directly instead of a Row wrapping one
    private UDF3<Double, Double, Double, Vector> toColumn = new UDF3<Double, Double, Double, Vector>() {

        private static final long serialVersionUID = 1L;

        public Vector call(Double first, Double second, Double third) throws Exception {
            return Vectors.dense(first, second, third);
        }
    };

    @Test
    public void testUDF() {
        // context
        final JavaSparkContext sc = new JavaSparkContext("local", "ToVectorTest");
        final SQLContext sqlContext = new SQLContext(sc);

        // test input
        final DataFrame input = sqlContext.createDataFrame(
            sc.parallelize(
                Arrays.asList(
                    RowFactory.create(1.0, 2.0, 3.0),
                    RowFactory.create(4.0, 5.0, 6.0),
                    RowFactory.create(7.0, 8.0, 9.0),
                    RowFactory.create(10.0, 11.0, 12.0)
                )),
            DataTypes.createStructType(
                Arrays.asList(
                    new StructField("feature1", DataTypes.DoubleType, false, Metadata.empty()),
                    new StructField("feature2", DataTypes.DoubleType, false, Metadata.empty()),
                    new StructField("feature3", DataTypes.DoubleType, false, Metadata.empty())
                )
            )
        );
        input.registerTempTable("input");

        // expected output
        final Set<Vector> expectedOutput = new HashSet<>(Arrays.asList(
            Vectors.dense(1.0, 2.0, 3.0),
            Vectors.dense(4.0, 5.0, 6.0),
            Vectors.dense(7.0, 8.0, 9.0),
            Vectors.dense(10.0, 11.0, 12.0)
        ));

        // processing: register with VectorUDT as the return type
        sqlContext.udf().register("toColumn", toColumn, new VectorUDT());
        final DataFrame outputDF = sqlContext.sql("SELECT toColumn(feature1, feature2, feature3) AS x FROM input");
        final Set<Vector> output = new HashSet<>(outputDF.toJavaRDD().map(r -> r.<Vector>getAs("x")).collect());

        // evaluation
        assertEquals(expectedOutput.size(), output.size());
        for (Vector x : output) {
            assertTrue(expectedOutput.contains(x));
        }

        // show the schema and the content
        System.out.println(outputDF.schema());
        outputDF.show();

        sc.stop();
    }
}
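
For reference, printing the schema of outputDF as a tree (e.g. via outputDF.printSchema()) should now show the vector at the top level, along the lines of:

root
 |-- x: vector (nullable = true)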

That's exactly what I needed. Somehow I had failed to consider returning a Vector from the UDF and registering the function with a VectorUDT. Thanks, Robert! – Rajko