这可以使用聚合来完成,但是这种方法比pandas方法具有更高的复杂度。但是您可以使用UDF实现类似的性能。它不会像大熊猫一样优雅,但:
假定该数据集的节日:
holidays = ['2016-01-03', '2016-09-09', '2016-12-12', '2016-03-03']
index = spark.sparkContext.broadcast(sorted(holidays))
而2016年的日期数据集的数据帧:
from datetime import datetime, timedelta
dates_array = [(datetime(2016, 1, 1) + timedelta(i)).strftime('%Y-%m-%d') for i in range(366)]
from pyspark.sql import Row
df = spark.createDataFrame([Row(date=d) for d in dates_array])
UDF可以使用熊猫searchsorted
,但需要在执行者上安装熊猫。 insted的,你可以使用Python的计划是这样的:
def nearest_holiday(date):
last_holiday = index.value[0]
for next_holiday in index.value:
if next_holiday >= date:
break
last_holiday = next_holiday
if last_holiday > date:
last_holiday = None
if next_holiday < date:
next_holiday = None
return (last_holiday, next_holiday)
from pyspark.sql.types import *
return_type = StructType([StructField('last_holiday', StringType()), StructField('next_holiday', StringType())])
from pyspark.sql.functions import udf
nearest_holiday_udf = udf(nearest_holiday, return_type)
,可与withColumn
使用:
df.withColumn('holiday', nearest_holiday_udf('date')).show(5, False)
+----------+-----------------------+
|date |holiday |
+----------+-----------------------+
|2016-01-01|[null,2016-01-03] |
|2016-01-02|[null,2016-01-03] |
|2016-01-03|[2016-01-03,2016-01-03]|
|2016-01-04|[2016-01-03,2016-03-03]|
|2016-01-05|[2016-01-03,2016-03-03]|
+----------+-----------------------+
only showing top 5 rows
谢谢,这看起来不错。我需要将它移植到scala中;) –
你指的是什么'sorted(holidays)'操作?它是一个pyspark api吗? –
这是python的。它对UDF进行排序,我可以通过它来查找匹配的日期。 – Mariusz