
I want to pass a struct in Spark to a UDF, but the struct's field names get changed and replaced with the column positions. How do I fix this? (Spark struct field names changed inside a UDF)

import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.functions._
import org.apache.spark.{SparkConf, SparkContext}

object TestCSV {

  def main(args: Array[String]) {

    val conf = new SparkConf().setAppName("localTest").setMaster("local")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // read a pipe-delimited CSV with a header row
    val inputData = sqlContext.read.format("com.databricks.spark.csv")
      .option("delimiter", "|")
      .option("header", "true")
      .load("test.csv")

    inputData.printSchema()
    inputData.show()

    // combine firstname and lastname into a single struct column
    val groupedData = inputData.withColumn("name", struct(inputData("firstname"), inputData("lastname")))

    val udfApply = groupedData.withColumn("newName", processName(groupedData("name")))
    udfApply.show()
  }

  // the UDF receives the struct as a Row and reads its fields by name
  def processName = udf((input: Row) => {
    println(input)
    println(input.schema)

    Map("firstName" -> input.getAs[String]("firstname"), "lastName" -> input.getAs[String]("lastname"))
  })
}

Output:

root 
|-- id: string (nullable = true) 
|-- firstname: string (nullable = true) 
|-- lastname: string (nullable = true) 

+---+---------+--------+
| id|firstname|lastname|
+---+---------+--------+
|  1|     jack| reacher|
|  2|     john|     Doe|
+---+---------+--------+

Error:

[jack,reacher]
StructType(StructField(i[1],StringType,true), StructField(i[2],StringType,true))
17/03/08 09:45:35 ERROR Executor: Exception in task 0.0 in stage 2.0 (TID 2)
java.lang.IllegalArgumentException: Field "firstname" does not exist.


Why not just pass the two strings (as 'String') to the udf directly? –


That is possible, but you cannot pass more than 10 fields as parameters to a Spark UDF. What I have given here is a simplified use case; sometimes I have to pass 20+ columns into a UDF. How can I achieve that? – hp2326
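As an aside, one common pattern for getting an arbitrary number of columns into a single UDF is to pack them all into one struct, since struct takes a varargs of columns. The sketch below assumes this approach (the names cols, packed and processAll are made up for illustration, and all packed columns are assumed to be strings); reading the fields back by position rather than by name also sidesteps the renaming problem this question is about:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

// Hypothetical list of 20+ column names to forward to the UDF.
val cols = Seq("id", "firstname", "lastname") // ... extend to 20+ columns

// Pack every column into one struct column; struct takes varargs of Columns,
// so the limit on the UDF's own argument count no longer applies.
val packed = inputData.withColumn("packed", struct(cols.map(inputData(_)): _*))

// Read the fields back by position, so the renamed field names
// (i[1], i[2], ...) inside the Row do not matter. Assumes string columns.
val processAll = udf((r: Row) => (0 until r.length).map(i => r.getString(i)).mkString("|"))

packed.withColumn("joined", processAll(packed("packed"))).show()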

Answer


What you are running into is really strange. After playing around with it for a while, I figured out that it is probably related to a problem with the optimizer engine. It seems the problem is not the UDF but the struct function.
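As a side note, you can check what the optimizer does to the plan with the standard explain method; printing the extended plans should show whether the struct expression's field names get rewritten:

// Print the parsed, analyzed, optimized and physical plans for inspection.
udfApply.explain(true)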

I got it working (Spark 1.6.3) when I cache groupedData; without the cache I get the exception you reported:

import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.{SparkConf, SparkContext}

object Demo {

  def main(args: Array[String]): Unit = {

    val sc = new SparkContext(new SparkConf().setAppName("Demo").setMaster("local[1]"))
    val sqlContext = new HiveContext(sc)
    import sqlContext.implicits._
    import org.apache.spark.sql.functions._

    def processName = udf((input: Row) => {
      Map("firstName" -> input.getAs[String]("firstname"), "lastName" -> input.getAs[String]("lastname"))
    })

    val inputData =
      sc.parallelize(
        Seq(("1", "Kevin", "Costner"))
      ).toDF("id", "firstname", "lastname")

    val groupedData = inputData.withColumn("name", struct(inputData("firstname"), inputData("lastname")))
      .cache() // does not work without cache

    val udfApply = groupedData.withColumn("newName", processName(groupedData("name")))
    udfApply.show()
  }
}

Alternatively, you can build your struct using the RDD API, but this is not as nice:

case class Name(firstname: String, lastname: String) // define outside main

val groupedData = inputData.rdd
  .map { r =>
    (r.getAs[String]("id"),
      Name(
        r.getAs[String]("firstname"),
        r.getAs[String]("lastname")
      ))
  }
  .toDF("id", "name")
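Another workaround worth noting, as a minimal untested sketch using only standard Row accessors: assuming the field order inside the struct is preserved (the output in the question suggests it is, since only the names are mangled to i[1], i[2], ...), you can read the fields by position and skip both the cache and the RDD detour. Applied to the struct column from the original question's groupedData (processNameByIndex is a made-up name):

// Read the struct's fields positionally, so the renamed field names are irrelevant.
// Assumes the struct was built as struct(firstname, lastname), in that order.
def processNameByIndex = udf((input: Row) => {
  Map("firstName" -> input.getString(0), "lastName" -> input.getString(1))
})

val udfApply = groupedData.withColumn("newName", processNameByIndex(groupedData("name")))
udfApply.show()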

Thanks @Raphael Roth. This works for me now. I will accept this answer. – hp2326