2016-08-02 188 views
5

我正在处理大火花DataFrame中的一列数字,我想创建一个新列,该列存储该列中出现的唯一数字的汇总列表。有没有办法将限制参数传递给Spark中的functions.collect_set?

基本上正是functions.collect_set所做的。不过,我只需要聚合列表中最多1000个元素。有什么办法可以通过某种方式将参数传递给函数.collect_set()或者其他任何方式来获取聚合列表中最多1000个元素,而不使用UDAF?

由于列太大,我希望避免收集所有元素并在之后修剪列表。

谢谢!

回答

1

使用采取

val firstThousand = rdd.take(1000) 

将返回第1000 收集也有可以提供过滤功能。这可以让你对返回的内容更加具体。

+0

感谢您的回答。但是, 1)我只喜欢_distinct_值的列表。我看到有一个rdd.distinct(),但似乎没有限制参数 2)不知道如何在collect中使用过滤器函数。我将如何使用过滤器来获取一定数量的值? – user1500142

+0

此外,理想情况下我想避免使用rdds。我目前像df.groupBy()。agg( user1500142

1
scala> df.show 
    +---+-----+----+--------+ 
    | C0| C1| C2|  C3| 
    +---+-----+----+--------+ 
    | 10| Name|2016| Country| 
    | 11|Name1|2016|country1| 
    | 10| Name|2016| Country| 
    | 10| Name|2016| Country| 
    | 12|Name2|2017|Country2| 
    +---+-----+----+--------+ 

scala> df.groupBy("C1").agg(sum("C0")) 
res36: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint] 

scala> res36.show 
+-----+-------+ 
| C1|sum(C0)| 
+-----+-------+ 
|Name1|  11| 
|Name2|  12| 
| Name|  30| 
+-----+-------+ 

scala> df.limit(2).groupBy("C1").agg(sum("C0")) 
    res33: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint] 

    scala> res33.show 
    +-----+-------+ 
    | C1|sum(C0)| 
    +-----+-------+ 
    | Name|  10| 
    |Name1|  11| 
    +-----+-------+ 



    scala> df.groupBy("C1").agg(sum("C0")).limit(2) 
res2: org.apache.spark.sql.DataFrame = [C1: string, sum(C0): bigint] 

scala> res2.show 
+-----+-------+ 
| C1|sum(C0)| 
+-----+-------+ 
|Name1|  11| 
|Name2|  12| 
+-----+-------+ 

scala> df.distinct 
res8: org.apache.spark.sql.DataFrame = [C0: int, C1: string, C2: int, C3: string] 

scala> res8.show 
+---+-----+----+--------+ 
| C0| C1| C2|  C3| 
+---+-----+----+--------+ 
| 11|Name1|2016|country1| 
| 10| Name|2016| Country| 
| 12|Name2|2017|Country2| 
+---+-----+----+--------+ 

scala> df.dropDuplicates(Array("c1")) 
res11: org.apache.spark.sql.DataFrame = [C0: int, C1: string, C2: int, C3: string] 

scala> res11.show 
+---+-----+----+--------+              
| C0| C1| C2|  C3| 
+---+-----+----+--------+ 
| 11|Name1|2016|country1| 
| 12|Name2|2017|Country2| 
| 10| Name|2016| Country| 
+---+-----+----+--------+ 
+0

感谢您的答案,但这并不完全符合我的要求。如果我想从一列中取得1000个不同的值,“df.limit(1000)”将对返回值的数量设置一个硬性上限,但是我可能会丢失不同的值,否则我应该添加其他值。 – user1500142

+0

你有两种不同的方法,你可以在limit,groupby和agg方法之前执行dropDuplicates。 Distinct将查看所有列,droDuplicates允许您控制要比较哪些列以识别重复项。 @ user1500142 – mark

2

我正在使用collect_set和collect_list函数的修改副本;由于代码范围的原因,修改后的副本必须与原始文件位于相同的包路径中。链接的代码适用于Spark 2.1.0;如果您使用的是先前版本,方法签名可能会有所不同。

抛出此文件(https://gist.github.com/lokkju/06323e88746c85b2ce4de3ea9cdef9bc)为您的项目的src/main /组织/阿帕奇/火花/ SQL /催化剂/表达/ collect_limit.scala

使用它作为:

import org.apache.spark.sql.catalyst.expression.collect_limit._ 
df.groupBy('set_col).agg(collect_set_limit('set_col,1000) 
3

我的解决办法与Loki's answer with collect_set_limit非常相似。


我会使用一个UDF你想要什么,会做后collect_set(或collect_list)或更难UDAF。

鉴于UDF的更多经验,我会首先考虑。即使UDF没有优化,对于这种用例也没有问题。

val limitUDF = udf { (nums: Seq[Long], limit: Int) => nums.take(limit) } 
val sample = spark.range(50).withColumn("key", $"id" % 5) 

scala> sample.groupBy("key").agg(collect_set("id") as "all").show(false) 
+---+--------------------------------------+ 
|key|all         | 
+---+--------------------------------------+ 
|0 |[0, 15, 30, 45, 5, 20, 35, 10, 25, 40]| 
|1 |[1, 16, 31, 46, 6, 21, 36, 11, 26, 41]| 
|3 |[33, 48, 13, 38, 3, 18, 28, 43, 8, 23]| 
|2 |[12, 27, 37, 2, 17, 32, 42, 7, 22, 47]| 
|4 |[9, 19, 34, 49, 24, 39, 4, 14, 29, 44]| 
+---+--------------------------------------+ 

scala> sample. 
    groupBy("key"). 
    agg(collect_set("id") as "all"). 
    withColumn("limit(3)", limitUDF($"all", lit(3))). 
    show(false) 
+---+--------------------------------------+------------+ 
|key|all         |limit(3) | 
+---+--------------------------------------+------------+ 
|0 |[0, 15, 30, 45, 5, 20, 35, 10, 25, 40]|[0, 15, 30] | 
|1 |[1, 16, 31, 46, 6, 21, 36, 11, 26, 41]|[1, 16, 31] | 
|3 |[33, 48, 13, 38, 3, 18, 28, 43, 8, 23]|[33, 48, 13]| 
|2 |[12, 27, 37, 2, 17, 32, 42, 7, 22, 47]|[12, 27, 37]| 
|4 |[9, 19, 34, 49, 24, 39, 4, 14, 29, 44]|[9, 19, 34] | 
+---+--------------------------------------+------------+ 

functions对象(udf功能的文档)。

相关问题