Pyspark: window sum based on a condition

Problem Description

Consider the simple DataFrame:

from pyspark import SparkContext
import pyspark
import pandas as pd  # used by the pandas UDF below
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import *
from pyspark.sql.functions import pandas_udf, PandasUDFType
spark = SparkSession.builder.appName('Trial').getOrCreate()

simpleData = (
    ("2000-04-17", "144", 1),
    ("2000-07-06", "015", 1),
    ("2001-01-23", "015", -1),
    ("2001-01-18", "144", -1),
    ("2001-04-17", "198", 1),
    ("2001-04-18", "036", -1),
    ("2001-04-19", "012", -1),
    ("2001-04-19", "188", 1),
    ("2001-04-25", "188", 1),
    ("2001-04-27", "015", 1),
)

columns = ["dates", "id", "eps"]
df = spark.createDataFrame(data=simpleData, schema=columns)
df.printSchema()
df.show(truncate=False)

Out:

root
 |-- dates: string (nullable = true)
 |-- id: string (nullable = true)
 |-- eps: long (nullable = true)

+----------+---+---+
|dates     |id |eps|
+----------+---+---+
|2000-04-17|144|1  |
|2000-07-06|015|1  |
|2001-01-23|015|-1 |
|2001-01-18|144|-1 |
|2001-04-17|198|1  |
|2001-04-18|036|-1 |
|2001-04-19|012|-1 |
|2001-04-19|188|1  |
|2001-04-25|188|1  |
|2001-04-27|015|1  |
+----------+---+---+

I would like to sum the values in the eps column over a rolling window, keeping only the last value for any given ID in the id column. For example, defining a window of 5 rows and assuming we are on 2001-04-17, I want to sum only the last eps value for each unique ID. In those 5 rows there are only 3 different IDs, so the sum must be over 3 elements: -1 for ID 144 (fourth row), -1 for ID 015 (third row) and 1 for ID 198 (fifth row), for a total of -1.
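To make the arithmetic concrete, here is the same last-per-id logic applied to those five rows in plain pandas (a minimal sketch, independent of Spark):

import pandas as pd

# The first five rows of the DataFrame, in their original order.
window_rows = pd.DataFrame({
    'id':  ['144', '015', '015', '144', '198'],
    'eps': [1, 1, -1, -1, 1],
})

# Keep only the last 'eps' per id, then sum: -1 (144) + -1 (015) + 1 (198) = -1.
print(window_rows.groupby('id')['eps'].last().sum())  # -1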

In my mind, within the rolling window I should do something like F.sum(groupBy('id').agg(F.last('eps'))), which of course is not possible to achieve inside a rolling window.

I obtained the desired result using a UDF.

@pandas_udf(IntegerType(), PandasUDFType.GROUPED_AGG)
def fun_sum(id, eps):
    # Build a small pandas frame from the two window columns.
    df = pd.DataFrame({'id': id, 'eps': eps})
    # Keep only the last 'eps' value per id within the window, then sum them.
    return df.groupby('id')['eps'].last().sum()

Then:

w = Window.orderBy('dates').rowsBetween(-5,0)
df = df.withColumn('sum', fun_sum(F.col('id'), F.col('eps')).over(w))

The problem is that my dataset contains more than 8 million rows, and performing this task with this UDF takes about 2 hours.

I was wondering whether there is a way to achieve the same result with built-in PySpark functions, avoiding the UDF, or at least whether there is a way to improve the performance of my UDF.
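One thing worth trying is the Spark 3 style of defining the same grouped-aggregate UDF with Python type hints. This is a sketch of the same logic under the assumption that Spark 3.x and pandas are available (fun_sum_v2 is a hypothetical name), and it is mainly an API modernization rather than a guaranteed speed-up:

import pandas as pd
from pyspark.sql.functions import pandas_udf

# Spark 3.x type-hinted form of the same grouped-aggregate UDF (sketch).
@pandas_udf("int")
def fun_sum_v2(id: pd.Series, eps: pd.Series) -> int:
    # Last 'eps' per id within the window, summed.
    return int(pd.DataFrame({'id': id, 'eps': eps}).groupby('id')['eps'].last().sum())

# Usage is unchanged:
# df = df.withColumn('sum', fun_sum_v2(F.col('id'), F.col('eps')).over(w))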

For completeness, the desired output should be:

+----------+---+---+----+
|dates     |id |eps|sum |
+----------+---+---+----+
|2000-04-17|144|1  |1   |
|2000-07-06|015|1  |2   |
|2001-01-23|015|-1 |0   |
|2001-01-18|144|-1 |-2  |
|2001-04-17|198|1  |-1  |
|2001-04-18|036|-1 |-2  |
|2001-04-19|012|-1 |-3  |
|2001-04-19|188|1  |-1  |
|2001-04-25|188|1  |0   |
|2001-04-27|015|1  |0   |
+----------+---+---+----+

The result must also be achievable using a .rangeBetween() window.
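For reference, a range-based window over the dates can be defined by ordering on the date cast to unix seconds. This is a minimal sketch, not part of the original post: the 5-day bound and the cast are assumptions, and whether a grouped-aggregate pandas UDF can run over a bounded range window depends on the Spark version.

# Sketch: a 5-day range window instead of a 5-row window. rangeBetween()
# needs a numeric ordering column, so the date string is cast to unix seconds.
five_days = 5 * 86400  # assumed bound, in seconds

w_range = (
    Window
    .orderBy(F.col('dates').cast('timestamp').cast('long'))
    .rangeBetween(-five_days, 0)
)

df_range = df.withColumn('sum', fun_sum(F.col('id'), F.col('eps')).over(w_range))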

Recommended Answer

In case you haven't figured it out yet, here's one way of achieving it.

Assuming that df is defined and initialised the way you defined and initialised it in your question.

Import the required functions and classes:

from pyspark.sql.functions import row_number, col
from pyspark.sql.window import Window

Create the necessary WindowSpec:

window_spec = (
    Window
    # Partition by 'id'.
    .partitionBy(df.id)
    # Order by 'dates', latest dates first.
    .orderBy(df.dates.desc())
)

Create a DataFrame with the partitioned data:

partitioned_df = (
    df
    # Use the window function 'row_number()' to populate a new column
    # containing a sequential number starting at 1 within a window partition.
    .withColumn('row', row_number().over(window_spec))
    # Only select the first entry in each partition (i.e. the latest date).
    .where(col('row') == 1)
)

Just in case you want to double-check the data:

partitioned_df.show()

# +----------+---+---+---+
# |     dates| id|eps|row|
# +----------+---+---+---+
# |2001-04-19|012| -1|  1|
# |2001-04-25|188|  1|  1|
# |2001-04-27|015|  1|  1|
# |2001-04-17|198|  1|  1|
# |2001-01-18|144| -1|  1|
# |2001-04-18|036| -1|  1|
# +----------+---+---+---+

Aggregate the data:

sum_rows = (
    partitioned_df
    # Aggregate data.
    .groupBy()
    # Sum all rows in 'eps' column.
    .sum('eps')
    # Get all records as a list of Rows.
    .collect()
)

Get the result:

print(f"sum eps: {sum_rows[0][0]})
# sum eps: 0
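Equivalently (a small sketch, reusing the F alias imported in the question), the total can be read from an aggregation directly, without collect():

total = partitioned_df.agg(F.sum('eps')).first()[0]
print(f"sum eps: {total}")
# sum eps: 0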
