2017-04-21 114 views
1

我想从PostgreSQL数据库将大约1M行加载到Spark中。使用Spark时,大约需要10秒钟。但是,使用psycopg2驱动程序加载相同的查询需要2s。我使用PostgreSQL JDBC驱动版本42.0.0Spark从Postgres JDBC表缓慢读取

def _loadFromPostGres(name): 
    url_connect = "jdbc:postgresql:"+dbname 
    properties = {"user": "postgres", "password": "postgres"} 
    df = SparkSession.builder.getOrCreate().read.jdbc(url=url_connect, table=name, properties=properties) 
    return df 

df = _loadFromPostGres(""" 
    (SELECT "seriesId", "companyId", "userId", "score" 
    FROM user_series_game 
    WHERE "companyId"=655124304077004298) as 
user_series_game""") 

print measure(lambda : len(df.collect())) 

输出是 -

--- 10.7214591503 seconds --- 
1076131 

使用psycopg2 -

import psycopg2 
conn = psycopg2.connect(conn_string) 
cur = conn.cursor() 

def _exec(): 
    cur.execute("""(SELECT "seriesId", "companyId", "userId", "score" 
     FROM user_series_game 
     WHERE "companyId"=655124304077004298)""") 
    return cur.fetchall() 
print measure(lambda : len(_exec())) 
cur.close() 
conn.close() 

输出是 -

--- 2.27961301804 seconds --- 
1076131 

的测量功能 -

def measure(func) : 
    start_time = time.time() 
    x = func() 
    print("--- %s seconds ---" % (time.time() - start_time)) 
    return x 

请帮我看看这个问题的原因。


编辑1

我做了几个标准。使用Scala和JDBC -

import java.sql._; 
import scala.collection.mutable.ArrayBuffer; 

def exec() { 

val url = ("jdbc:postgresql://prod.caumccqvmegm.ap-southeast-1.rds.amazonaws.com/prod"+ 
    "?tcpKeepAlive=true&prepareThreshold=-1&binaryTransfer=true&defaultRowFetchSize=10000") 

val conn = DriverManager.getConnection(url,"postgres","postgres"); 

val sqlText = """SELECT "seriesId", "companyId", "userId", "score" 
     FROM user_series_game 
     WHERE "companyId"=655124304077004298""" 

val t0 = System.nanoTime() 

val stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) 

val rs = stmt.executeQuery() 

val list = new ArrayBuffer[(Long, Long, Long, Double)]() 

while (rs.next()) { 
    val seriesId = rs.getLong("seriesId") 
    val companyId = rs.getLong("companyId") 
    val userId = rs.getLong("userId") 
    val score = rs.getDouble("score") 
    list.append((seriesId, companyId, userId, score)) 
} 

val t1 = System.nanoTime() 

println("Elapsed time: " + (t1 - t0) * 1e-9 + "s") 

println(list.size) 

rs.close() 
stmt.close() 
conn.close() 
} 

exec() 

产量为 -

Elapsed time: 1.922102285s 
1143402 

当我没有收集()在星火+斯卡拉 -

import org.apache.spark.sql.SparkSession 

def exec2() { 

    val spark = SparkSession.builder().getOrCreate() 

    val url = ("jdbc:postgresql://prod.caumccqvmegm.ap-southeast-1.rds.amazonaws.com/prod"+ 
    "?tcpKeepAlive=true&prepareThreshold=-1&binaryTransfer=true&defaultRowFetchSize=10000") 

    val sqlText = """(SELECT "seriesId", "companyId", "userId", "score" 
     FROM user_series_game 
     WHERE "companyId"=655124304077004298) as user_series_game""" 

    val t0 = System.nanoTime() 

    val df = spark.read 
      .format("jdbc") 
      .option("url", url) 
      .option("dbtable", sqlText) 
      .option("user", "postgres") 
      .option("password", "postgres") 
      .load() 

    val list = df.collect() 

    val t1 = System.nanoTime() 

    println("Elapsed time: " + (t1 - t0) * 1e-9 + "s") 

    print (list.size) 
} 

exec2() 

产量为

Elapsed time: 1.486141076s 
1143445 

因此,在Python中花费了4倍的额外时间序列化。我知道会有一些惩罚,但这似乎太多了。

回答

0

原因很简单,有两个同时的原因。

首先我会告诉你psycopg2是如何工作的。

这个lib psycopg2的工作方式与其他任何lib连接到一个RDMS。这个lib会将查询发送到你的postgres的引擎,它会将数据返回给你。像这样直向前进。

康涅狄格州 - >查询 - > ReturnData - > FetchData

当您使用的火花在两个方面有一点点不同。 Spark不像一个运行在单线程中的编程语言。它有一个分布式系统来工作。即使你在本地机器上运行。参见Spark有一个基本的司机(主)和工人的概念。

驱动程序接收到执行查询的请求到Postgres,驱动程序将不会为每个工作者请求来自Postgres的信息。

如果你看到的文档here你会SE这样一张纸条:

不要在大型​​集群上并行创建分区太多;否则Spark可能会导致外部数据库系统崩溃。

本说明表示每位工作人员都有责任为您的postgres请求数据。这是开始这个​​过程的一个小开销,但没有什么大的。但是在这里有一个开销,将数据发送给每个工作人员。

Seccond点,你收集的这部分代码:

print measure(lambda : len(df.collect())) 

Collect函数将发送一个命令为所有的工人将数据发送到驱动程序。为了存储你的驱动程序的内存,它就像一个Reduce,它在过程中创建Shuffle。洗牌是数据发送给其他员工的过程的一部分。在收集的情况下,每个工作人员都会将其发送给您的司机。

所以,在你的代码的JDBC星火的步骤是:

(工人)美国康涅狄格州 - >(工人)查询 - >(工人)FetchData - >(驱动器) 请求中的数据 - >(工人)Shuffle - >(驱动程序)收集

那么在Spark中发生的一堆其他东西,如QueryPlan,构建DataFrame和其他东西。

这就是为什么你的Python简单代码比Spark更快的原因。

+0

我们正在将数据从postgresql加载到spark中遇到重大问题。基本上,我们的想法是将驱动程序中的所有数据加载到熊猫数据框中,并将其转换为火花数据帧,然后运行spark分布式。你会建议什么? – Sandeep

+0

请不要将所有的数据加载到熊猫,这将是不好的。如果你有一个Spark Cluster,你应该使用python的JDBC工具从postgres加载数据,直接将数据加载到worker。 https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.DataFrameReader.jdbc –