我想计算在一个分组的Spark数据框架的列中有多少记录是真的,但是我不知道如何在Python中这样做。例如,我有一个带有regionsalaryIsUnemployed列的数据,其中IsUnemployed是一个布尔值。我想看看每个地区有多少失业者。我知道我们可以先做一个filter然后再做一个groupby但是我想同时生成两个聚合,如下所示

from pyspark.sql import functions as F
data.groupby("Region").agg(F.avg("Salary"), F.count("IsUnemployed"))

最佳答案

可能最简单的解决方案是一个简单的(c样式,其中CAST->1,TRUE->0)和FALSE

(data
    .groupby("Region")
    .agg(F.avg("Salary"), F.sum(F.col("IsUnemployed").cast("long"))))

一个更加通用和惯用的解决方案是SUMwithCASE WHEN
(data
    .groupby("Region")
    .agg(
        F.avg("Salary"),
        F.count(F.when(F.col("IsUnemployed"), F.col("IsUnemployed")))))

但这显然是一种过度杀戮。

10-07 16:57