2016-11-09 98 views
2

假设我有一个抽象类A。我也有类BC从类A继承。Spark Scala:将子类型传递给接受父类型的函数

abstract class A { 
    def x: Int 
} 
case class B(i: Int) extends A { 
    override def x = -i 
} 
case class C(i: Int) extends A { 
    override def x = i 
} 

鉴于这些类中,我们构建了以下RDD:

val data = sc.parallelize(Seq(
     Set(B(1), B(2)), 
     Set(B(1), B(3)), 
     Set(B(1), B(5)) 
    )).cache 
     .zipWithIndex 
     .map {case(k, v) => (v, k)} 

我还具有以下功能得到一个RDD作为输入,并返回每个元素的计数:

def f(data: RDD[(Long, Set[A])]) = { 
    data.flatMap({ 
    case (k, v) => v map { af => 
     (af, 1) 
    } 
    }).reduceByKey(_ + _) 
} 

请注意,RDD正在接受类型A。现在,我希望val x = f(data)返回预期的计数,作为B是子类型的A,但我得到以下编译错误:

type mismatch; 
found : org.apache.spark.rdd.RDD[(Long, scala.collection.immutable.Set[B])] 
required: org.apache.spark.rdd.RDD[(Long, Set[A])] 
    val x = f(data) 

这个错误消失,如果我改变函数签名f(data: RDD[(Long, Set[B])]);但是,我不能这样做,因为我想在RDD中使用其他子类(如C)。

我也曾尝试以下方法:

def f[T <: A](data: RDD[(Long, Set[T])]) = { 
    data.flatMap({ 
    case (k, v) => v map { af => 
     (af, 1) 
    } 
    }) reduceByKey(_ + _) 
} 

然而,这也给了我以下运行时错误:

value reduceByKey is not a member of org.apache.spark.rdd.RDD[(T, Int)] 
possible cause: maybe a semicolon is missing before `value reduceByKey'? 
     }) reduceByKey(_ + _) 

我感谢有这方面的帮助。

+2

仅仅因为B是A的子类型并不代表集[B]设置[A]的亚型。这是因为'Set'是不变的。你需要确保你的集合是一个集合[A] – puhlen

回答

2

Set[T]T不变,这意味着给定亚型的BASet[A]不是亚型也不的Set[B] RDD[T]的超类型也不变上T进一步限制的选项,因为,即使使用一个协变Collection[+T](例如一个List[+T])会出现相同的情况。

我们可以求助于替代方法的多态形式: 上面的版本中缺少的是Spark需要在擦除后保留类信息。

这应该工作:

import scala.reflect.{ClassTag} 
def f[T:ClassTag](data: RDD[(Long, Set[T])]) = { 
    data.flatMap({ 
    case (k, v) => v map { af => 
     (af, 1) 
    } 
    }) reduceByKey(_ + _) 
} 

让我们来看看:

val intRdd = sparkContext.parallelize(Seq((1l, Set(1,2,3)), (2L, Set(4,5,6)))) 
val res1= f(intRdd).collect 
// Array[(Int, Int)] = Array((4,1), (1,1), (5,1), (6,1), (2,1), (3,1)) 

val strRdd = sparkContext.parallelize(Seq((1l, Set("a","b","c")), (2L, Set("d","e","f")))) 
val res2 = f(strRdd).collect 
// Array[(String, Int)] = Array((d,1), (e,1), (a,1), (b,1), (f,1), (c,1)) 
+0

当RDD只包含一个对象的实例(例如在你的例子中是int和string)时,这是完美的。但假设你有这样的RDD:'val mixRdd = sc.parallelize(Seq((1l,Set(1,2,3)),(2L,Set(4,5,6)),(3L,Set “a”,“b”))))'。在这种情况下,代码失败。 – Ashkan

+0

这就是b/c Scala推断一个产品类型与'Set [_>:String with Int]'并且找不到ClassTag。如果你想将类型绑定到一个具体类型,那么它会工作:'val mixRdd:RDD [(Long,Set [Any])] = sc.parallelize(Seq((1l,Set(1,2,3)) ,(2L,Set(4,5,6)),(3L,Set(“a”,“b”))))' – maasg

相关问题