Problem description
My dataframe is as follows:
cust_id  req  req_met
-------  ---  -------
1        r1   1
1        r2   0
1        r2   1
2        r1   1
3        r1   1
3        r2   1
4        r1   0
5        r1   1
5        r2   0
5        r1   1
I have to look at customers, see how many requirements they have, and see whether each requirement has been met at least once. There can be multiple records for the same customer and requirement, one met and one not met. In the above case my output should be
cust_id
-------
1
2
3
What I did is
# say the initial dataframe is df
import pyspark.sql.functions as fn

df1 = df.groupby('cust_id').agg(
    fn.countDistinct('req').alias('num_of_req'),
    fn.sum('req_met').alias('sum_req_met'))
df2 = df1.filter(df1.num_of_req == df1.sum_req_met)
But in a few cases it does not produce correct results; one such case is sketched below. How can this be done?
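One failure mode, assuming df holds the sample data above (built as in the answer below): customer 5 has two distinct requirements, but r1 is met twice, so sum('req_met') reaches 2 and matches countDistinct('req') even though r2 was never met.

df1 = df.groupby('cust_id').agg(
    fn.countDistinct('req').alias('num_of_req'),
    fn.sum('req_met').alias('sum_req_met'))
# For cust_id 5: num_of_req = 2 (r1, r2) and sum_req_met = 2 (r1 met twice),
# so the filter wrongly keeps customer 5 although r2 was never met.
df1.filter(df1.num_of_req == df1.sum_req_met).select('cust_id').orderBy('cust_id').show()
# prints 1, 2, 3 and, incorrectly, 5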
Recommended answer
First, I'll prepare a toy dataset from the one given above,
from pyspark.sql.functions import col, udf
from pyspark.sql.types import IntegerType
import pyspark.sql.functions as fn

df = spark.createDataFrame([[1, 'r1', 1],
                            [1, 'r2', 0],
                            [1, 'r2', 1],
                            [2, 'r1', 1],
                            [3, 'r1', 1],
                            [3, 'r2', 1],
                            [4, 'r1', 0],
                            [5, 'r1', 1],
                            [5, 'r2', 0],
                            [5, 'r1', 1]], schema=['cust_id', 'req', 'req_met'])
df = df.withColumn('req_met', col('req_met').cast(IntegerType()))
df = df.withColumn('cust_id', col('cust_id').cast(IntegerType()))
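As a quick sanity check, df.show() should reproduce the table from the question (row order should match here since the frame is built locally from a small list):

df.show()
# +-------+---+-------+
# |cust_id|req|req_met|
# +-------+---+-------+
# |      1| r1|      1|
# |      1| r2|      0|
# |      1| r2|      1|
# |      2| r1|      1|
# |      3| r1|      1|
# |      3| r2|      1|
# |      4| r1|      0|
# |      5| r1|      1|
# |      5| r2|      0|
# |      5| r1|      1|
# +-------+---+-------+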
I do the same thing: group by cust_id and req, then sum req_met per group. After that, I create a function to floor those sums to just 0 or 1.
def floor_req(r):
    # Collapse any positive met-count to 1; zero stays 0.
    if r >= 1:
        return 1
    else:
        return 0

udf_floor_req = udf(floor_req, IntegerType())

gr = df.groupby(['cust_id', 'req'])
df_grouped = gr.agg(fn.sum(col('req_met')).alias('sum_req_met'))
df_grouped_floor = df_grouped.withColumn('sum_req_met', udf_floor_req('sum_req_met'))
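As an aside, the same 0/1 flooring can be had without a Python UDF: since req_met only takes the values 0 and 1, the per-group maximum is 1 exactly when at least one record met the requirement. A minimal alternative sketch, not part of the original answer:

# UDF-free equivalent: the max of a 0/1 column is 1 iff any record in the group met.
df_grouped_floor = (df.groupby(['cust_id', 'req'])
                      .agg(fn.max(col('req_met')).alias('sum_req_met')))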
Now, we can check whether each customer has met all requirements by comparing the distinct number of requirements with the total number of requirements met.
df_req = df_grouped_floor.groupby('cust_id').agg(fn.sum('sum_req_met').alias('sum_req'),
                                                 fn.count('req').alias('n_req'))
Finally, you just have to check whether the two columns are equal:
df_req.filter(df_req['sum_req'] == df_req['n_req'])[['cust_id']].orderBy('cust_id').show()
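With the sample data this should print customers 1, 2 and 3, matching the expected output in the question. For completeness, the asker's original single-pass idea can also be repaired by counting distinct met requirements instead of summing req_met; a sketch, assuming the same df:

# One-pass variant: count distinct requirements overall, and distinct
# requirements met at least once. countDistinct ignores the nulls that
# when() (with no otherwise) produces for unmet records.
result = (df.groupby('cust_id')
            .agg(fn.countDistinct('req').alias('n_req'),
                 fn.countDistinct(fn.when(col('req_met') == 1, col('req'))).alias('n_met'))
            .filter(col('n_req') == col('n_met'))
            .select('cust_id')
            .orderBy('cust_id'))
result.show()
# Expected:
# +-------+
# |cust_id|
# +-------+
# |      1|
# |      2|
# |      3|
# +-------+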