


I have a Spark DataFrame that contains sales prediction data for some products in some stores over a time period. How do I calculate the rolling sum of Prediction over a window of the next N values (the current row included), where N is taken from a column?


Input data:

| ProductId | StoreId |    Date    | Prediction | N |
|         1 |     100 | 2019-07-01 | 0.92       | 2 |
|         1 |     100 | 2019-07-02 | 0.62       | 2 |
|         1 |     100 | 2019-07-03 | 0.89       | 2 |
|         1 |     100 | 2019-07-04 | 0.57       | 2 |
|         2 |     200 | 2019-07-01 | 1.39       | 3 |
|         2 |     200 | 2019-07-02 | 1.22       | 3 |
|         2 |     200 | 2019-07-03 | 1.33       | 3 |
|         2 |     200 | 2019-07-04 | 1.61       | 3 |


Expected output:

| ProductId | StoreId |    Date    | Prediction | N |       RollingSum       |
|         1 |     100 | 2019-07-01 | 0.92       | 2 | sum(0.92, 0.62)        |
|         1 |     100 | 2019-07-02 | 0.62       | 2 | sum(0.62, 0.89)        |
|         1 |     100 | 2019-07-03 | 0.89       | 2 | sum(0.89, 0.57)        |
|         1 |     100 | 2019-07-04 | 0.57       | 2 | sum(0.57)              |
|         2 |     200 | 2019-07-01 | 1.39       | 3 | sum(1.39, 1.22, 1.33)  |
|         2 |     200 | 2019-07-02 | 1.22       | 3 | sum(1.22, 1.33, 1.61 ) |
|         2 |     200 | 2019-07-03 | 1.33       | 3 | sum(1.33, 1.61)        |
|         2 |     200 | 2019-07-04 | 1.61       | 3 | sum(1.61)              |
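To pin down the semantics of the expected output, here is a minimal plain-Python sketch (no Spark involved) that computes the same rolling sums on the sample rows. The function name rolling_sum_next_n and the tuple layout are made up for illustration; it is only meant as a small-data reference to check a Spark solution against, not as something that scales.

```python
from collections import defaultdict


def rolling_sum_next_n(rows):
    """rows: iterable of (product_id, store_id, date, prediction, n) tuples.

    Returns a dict {(product_id, store_id, date): rolling_sum}.
    Hypothetical helper, used only to illustrate the expected result.
    """
    groups = defaultdict(list)
    for product_id, store_id, date, prediction, n in rows:
        groups[(product_id, store_id)].append((date, prediction, n))

    result = {}
    for (product_id, store_id), items in groups.items():
        # ISO-formatted date strings sort chronologically.
        items.sort(key=lambda item: item[0])
        predictions = [prediction for _, prediction, _ in items]
        for i, (date, _, n) in enumerate(items):
            # Current value plus the following n - 1 values; the slice
            # simply stops at the end of the group if fewer remain.
            result[(product_id, store_id, date)] = sum(predictions[i:i + n])
    return result


rows = [
    (1, 100, "2019-07-01", 0.92, 2),
    (1, 100, "2019-07-02", 0.62, 2),
    (1, 100, "2019-07-03", 0.89, 2),
    (1, 100, "2019-07-04", 0.57, 2),
    (2, 200, "2019-07-01", 1.39, 3),
    (2, 200, "2019-07-02", 1.22, 3),
    (2, 200, "2019-07-03", 1.33, 3),
    (2, 200, "2019-07-04", 1.61, 3),
]
sums = rolling_sum_next_n(rows)
```

Looking up sums[(2, 200, "2019-07-03")] gives the sum of only the two remaining values (1.33 and 1.61), matching the shortened windows at the end of each group in the table above.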


There are lots of questions and answers about this problem for Python, but I couldn't find any for PySpark.


Similar Question 1
There is a similar question here, but in that one the frame size is fixed to 3. The provided answer uses the rangeBetween function, which only works with fixed-size frames, so I cannot use it for varying sizes.


Similar Question 2
There is also a similar question here. It suggests writing a case for every possible size, but that is not applicable in my case since I don't know how many distinct frame sizes I need to calculate.


Solution attempt 1
I've tried to solve the problem using a pandas udf:

rolling_sum_predictions = predictions.groupBy('ProductId', 'StoreId').apply(calculate_rolling_sums)

calculate_rolling_sums is a pandas UDF in which I solve the problem in Python. This solution works with a small amount of test data. However, when the data gets bigger (in my case, the input df has around 1B rows), the calculation takes too long.

Solution attempt 2
I used a workaround based on the answer to Similar Question 1 above. I find the biggest possible N, collect a list of that many following predictions for each row, and then compute the sum of predictions by slicing the list down to the row's own N.

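import numpy as np
from pyspark.sql import functions as F, Window
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType

# number the days within each product/store group
predictions = predictions.withColumn('DayIndex', F.rank().over(Window.partitionBy('ProductId', 'StoreId').orderBy('Date')))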

# find the biggest period
biggest_period = predictions.agg({"N": "max"}).collect()[0][0]

# calculate rolling predictions starting from the DayIndex
w = (Window.partitionBy(F.col("ProductId"), F.col("StoreId")).orderBy(F.col('DayIndex')).rangeBetween(0, biggest_period - 1))
rolling_prediction_lists = predictions.withColumn("next_preds", F.collect_list("Prediction").over(w))

# calculate rolling forecast sums
pred_sum_udf = udf(lambda preds, period: float(np.sum(preds[:period])), FloatType())
rolling_pred_sums = rolling_prediction_lists \
    .withColumn("RollingSum", pred_sum_udf("next_preds", "N"))


This solution also works with the test data. I haven't had a chance to test it with the original data yet, but whether it works or not, I don't like this solution. Is there a smarter way to solve this?


If you're using Spark 2.4+, you can use the new array functions slice and aggregate (the latter is a higher-order function) to implement your requirement efficiently without any UDFs:

from pyspark.sql import functions as F, Window

# frame: the current row and every following row of the same product/store
w = Window.partitionBy("ProductId", "StoreId").orderBy("Date").rowsBetween(Window.currentRow, Window.unboundedFollowing)
summed_predictions = predictions\
   .withColumn("summed", F.collect_list("Prediction").over(w))\
   .withColumn("summed", F.expr("aggregate(slice(summed,1,N), cast(0 as double), (acc,d) -> acc + d)"))

|ProductId|StoreId|               Date|Prediction|  N|            summed|
|        1|    100|2019-07-01 00:00:00|      0.92|  2|              1.54|
|        1|    100|2019-07-02 00:00:00|      0.62|  2|              1.51|
|        1|    100|2019-07-03 00:00:00|      0.89|  2|              1.46|
|        1|    100|2019-07-04 00:00:00|      0.57|  2|              0.57|
|        2|    200|2019-07-01 00:00:00|      1.39|  3|              3.94|
|        2|    200|2019-07-02 00:00:00|      1.22|  3|              4.16|
|        2|    200|2019-07-03 00:00:00|      1.33|  3|2.9400000000000004|
|        2|    200|2019-07-04 00:00:00|      1.61|  3|              1.61|


