2017-07-24 69 views
0

我有一个数据帧,看起来像这样:星火合并套常见的元素

+-----------+-----------+ 
| Package | Addresses | 
+-----------+-----------+ 
| Package 1 | address1 | 
| Package 1 | address2 | 
| Package 1 | address3 | 
| Package 2 | address3 | 
| Package 2 | address4 | 
| Package 2 | address5 | 
| Package 2 | address6 | 
| Package 3 | address7 | 
| Package 3 | address8 | 
| Package 4 | address9 | 
| Package 5 | address9 | 
| Package 5 | address1 | 
| Package 6 | address10 | 
| Package 7 | address8 | 
+-----------+-----------+ 

我需要找到被视为一起在不同的软件包的所有地址。输出示例:

+----+------------------------------------------------------------------------+ 
| Id |        Addresses        | 
+----+------------------------------------------------------------------------+ 
| 1 | [address1, address2, address3, address4, address5, address6, address9] | 
| 2 | [address7, address8]             | 
| 3 | [address10]               | 
+----+------------------------------------------------------------------------+ 

所以,我有DataFrame。我被package分组它(而不是分组):

val rdd = packages.select($"package", $"address"). 
    map{ 
    x => { 
     (x(0).toString(), x(1).toString()) 
    } 
    }.rdd.combineByKey(
    (source) => { 
    Set[String](source) 
    }, 

    (acc: Set[String], v) => { 
    acc + v 
    }, 

    (acc1: Set[String], acc2: Set[String]) => { 
    acc1 ++ acc2 
    } 
) 

然后,我合并具有共同地址行:

val result = rdd.treeAggregate(
    Set.empty[Set[String]] 
)(
    (map: Set[Set[String]], row) => { 
    val vals = row._2 
    val sets = map + vals 

    // copy-paste from here https://stackoverflow.com/a/25623014/772249 
    sets.foldLeft(Set.empty[Set[String]])((cum, cur) => { 
     val (hasCommon, rest) = cum.partition(_ & cur nonEmpty) 
     rest + (cur ++ hasCommon.flatten) 
    }) 
    }, 
    (map1, map2) => { 
    val sets = map1 ++ map2 

    // copy-paste from here https://stackoverflow.com/a/25623014/772249 
    sets.foldLeft(Set.empty[Set[String]])((cum, cur) => { 
     val (hasCommon, rest) = cum.partition(_ & cur nonEmpty) 
     rest + (cur ++ hasCommon.flatten) 
    }) 
    }, 
    10 
) 

但是,无论我做什么,treeAggregate正在很长,我不能完成单一任务。原始数据大小约为250GB。我尝试过不同的群集,但treeAggregate花费的时间太长。

treeAggregate之前的所有内容都很好用,但之后就会出现问题。

我试过了不同的spark.sql.shuffle.partitions(默认值是2000,10000),但它似乎并不重要。

我试过不同depthtreeAggregate,但没有注意到区别。

相关问题:

  1. Merge Sets of Sets that contain common elements in Scala
  2. Spark complex grouping
+0

我不知道我明白你想要做什么。为什么不这样做:packages.groupBy(“packages”)。agg(collect_set(“address”))? –

+0

@AssafMendelson因为它会给我完全不同的结果比我需要。请仔细观察预期结果。如果我会分组,我会得到7个不同的结果,但我预计只有三个。 – twoface88

+0

@AssafMendelson示例:address4和address1属于一起,即使它们属于不同的包,因为地址3已经在package1和package2中看到。因此,来自package1和package2的所有地址都属于同一个地址,依此类推。 – twoface88

回答

3

看看你的数据,就好像它在这里的地址是顶点的图,他们有一个连接,如果有包两个都。那么解决您的问题将是图的connected components

Sparks gpraphX库具有优化函数以查找connected components。它将返回不同连接组件中的顶点,将它们视为每个连接组件的ID。

然后有了id,你可以收集连接到它的所有其他地址,如果需要的话。

看看this article他们如何使用图表来实现与您相同的分组。

+0

这非常有趣,谢谢。 – twoface88