2016-01-13 49 views
-4

我是斯卡拉的新手,希望能够解决以下问题。斯卡拉:可变数量的数组之间的笛卡尔积

输入

我有一个Map[String, Array[Double]],看起来像如下:

Map(foo -> Array(12, 25, 100), bar -> Array(0.1, 0.001)) 

地图可以包含1和10键之间(在我的应用程序依赖于一些参数)。

处理

我想申请所有键的阵列之间的笛卡尔积并生成包含所有阵列的所有值的所有可能的组合的结构。

在上例中,笛卡尔产品将创建3x2=6不同的组合:(12, 0.1), (12, 0.001), (25, 0.1), (25, 0.01), (100,0.1) and (100, 0.01)

为了另一个例子,在某些情况下我可能有三个键:第一个键有4个值的数组,第二个键有5个值的数组,第三个键有3个数值的数组,在这个产品必须产生4x5x3=60不同的组合。

所需的输出

喜欢的东西:

Map(config1 -> (foo -> 12, bar -> 0.1), config2 -> (foo -> 12, bar -> 0.001), config3 -> (foo -> 25, bar -> 0.1), config4 -> (foo -> 25, bar -> 0.001), config5 -> (foo -> 100, bar -> 0.1), config6 -> (foo -> 100, bar -> 0.001)) 
+0

你认为有什么解决办法/试过这么远? – Dima

+0

TBH,无,任何解决方案将不胜感激 – Rami

回答

1

您可以使用为理解创建两个列表,数组的笛卡尔积...

val start = Map(
    'foo -> Array(12, 25, 100), 
    'bar -> Array(0.1, 0.001), 
    'baz -> Array(2)) 

// transform arrays into lists with values paired with map key 
val pairedWithKey = start.map { case (k,v) => v.map(i => k -> i).toList } 

val accumulator = pairedWithKey.head.map(x => Vector(x)) 
val cartesianProd = pairedWithKey.tail.foldLeft(accumulator)((acc, elem) => 
    for { x <- acc; y <- elem } yield x :+ y 
) 

cartesianProd foreach println 
// Vector(('foo,12), ('bar,0.1), ('baz,2)) 
// Vector(('foo,12), ('bar,0.001), ('baz,2)) 
// Vector(('foo,25), ('bar,0.1), ('baz,2)) 
// Vector(('foo,25), ('bar,0.001), ('baz,2)) 
// Vector(('foo,100), ('bar,0.1), ('baz,2)) 
// Vector(('foo,100), ('bar,0.001), ('baz,2)) 

你可能需要在使用headtail之前添加一些检查。

+0

感谢您的整洁解决方案彼得 – Rami

+2

这不是问题:( – Dima

1

由于数组的数量是动态的,所以有无法获得元组作为结果

你可以,但是,使用递归你的目的:

def process(a: Map[String, Seq[Double]]) = { 
    def product(a: List[(String, Seq[Double])]): Seq[List[(String, Double)]] = 
     a match { 
      case (name, values) :: tail => 
       for { 
        result <- product(tail) 
        value <- values 
       } yield (name, value) :: result 

      case Nil => Seq(List()) 
     } 

    product(a.toList) 
} 


val a = Map("foo" -> List(12.0, 25.0, 100.0), "bar" -> List(0.1, 0.001)) 
println(process(a)) 

其中给出的结果是:

List(List((foo,12.0), (bar,0.1)), List((foo,25.0), (bar,0.1)), List((foo,100.0), (bar,0.1)), List((foo,12.0), (bar,0.001)), List((foo,25.0), (bar,0.001)), List((foo,100.0), (bar,0.001)))