There’s a DataFrame in pyspark with data as below:
user_id object_id score
user_1 object_1 3
user_1 object_1 1
user_1 object_2 2
user_2 object_1 5
user_2 object_2 2
user_2 object_2 6
What I expect is to get back, for each group of rows sharing the same user_id, the 2 records with the highest score. The result should look like the following:
user_id object_id score
user_1 object_1 3
user_1 object_2 2
user_2 object_2 6
user_2 object_1 5
I’m really new to pyspark. Could anyone give me a code snippet or a pointer to the relevant documentation for this problem? Many thanks!
Answers:
Method 1
I believe you need to use window functions to obtain the rank of each row within its user_id group, ordered by score, and then filter the results to keep only the first two rows of each group.
from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col
window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())
# Wrap the chain in parentheses so it can span several lines.
(df.select('*', rank().over(window).alias('rank'))
   .filter(col('rank') <= 2)
   .show())
#+-------+---------+-----+----+
#|user_id|object_id|score|rank|
#+-------+---------+-----+----+
#| user_1| object_1|    3|   1|
#| user_1| object_2|    2|   2|
#| user_2| object_2|    6|   1|
#| user_2| object_1|    5|   2|
#+-------+---------+-----+----+
In general, the official programming guide is a good place to start learning Spark.
Data
rdd = sc.parallelize([("user_1", "object_1", 3),
("user_1", "object_2", 2),
("user_2", "object_1", 5),
("user_2", "object_2", 2),
("user_2", "object_2", 6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
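To make the window logic above concrete, here is a minimal plain-Python sketch (no Spark involved) of the same three steps: partition by user_id, order each partition by score descending, and keep the first two rows. It uses the six rows from the question; the helper name top_n_per_group is hypothetical and exists only for this illustration.

```python
from collections import defaultdict

# The six rows from the question: (user_id, object_id, score).
rows = [
    ("user_1", "object_1", 3),
    ("user_1", "object_1", 1),
    ("user_1", "object_2", 2),
    ("user_2", "object_1", 5),
    ("user_2", "object_2", 2),
    ("user_2", "object_2", 6),
]

def top_n_per_group(rows, n):
    """Hypothetical helper: keep the n highest-scoring rows per user_id."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[0]].append(row)  # plays the role of partitionBy(user_id)
    result = []
    for user_id in sorted(groups):
        # plays the role of orderBy(score desc) inside one partition
        ordered = sorted(groups[user_id], key=lambda r: r[2], reverse=True)
        result.extend(ordered[:n])  # keep the first n rows of the partition
    return result

top2 = top_n_per_group(rows, 2)
for row in top2:
    print(row)
```

Because the slice keeps exactly n rows per group, rows tied at the cut-off are kept in input order (Python's sort is stable). Spark makes no such ordering guarantee for ties, so treat this only as an illustration of the idea.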
Method 2
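Top-n is more accurate with row_number than with rank when scores tie: rank gives tied rows the same rank, so filtering on rank <= n can return more than n rows per group, whereas row_number gives every row in a group a distinct number.
from pyspark.sql.functions import col, row_number
# `window` is the Window specification defined in Method 1.
n = 5
(df.select(col('*'), row_number().over(window).alias('row_number'))
   .where(col('row_number') <= n)
   .limit(20)
   .toPandas())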
Note
The limit(20).toPandas() trick is used instead of show() because it gives nicer formatting in Jupyter notebooks.
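The difference between the two numbering schemes on tied scores can be shown with a small plain-Python sketch (no Spark involved). The scores below are hypothetical and stand for one group that is already sorted in descending order.

```python
# Hypothetical scores of one group, already sorted in descending order.
scores = [6, 5, 5, 2]

# row_number: consecutive positions 1, 2, 3, ... regardless of ties.
row_numbers = list(range(1, len(scores) + 1))

# rank: tied scores share a rank, and the next distinct score skips ahead.
ranks = []
for i, s in enumerate(scores):
    if i > 0 and s == scores[i - 1]:
        ranks.append(ranks[-1])
    else:
        ranks.append(i + 1)

n = 2
kept_by_rank = [s for s, r in zip(scores, ranks) if r <= n]
kept_by_row_number = [s for s, r in zip(scores, row_numbers) if r <= n]

print(ranks)               # [1, 2, 2, 4]
print(kept_by_rank)        # [6, 5, 5]
print(kept_by_row_number)  # [6, 5]
```

With n = 2, filtering on rank keeps three rows because of the tie, while filtering on row_number keeps exactly two.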
Method 3
I know the question asks about pyspark, but I was looking for a similar answer in Scala, i.e.
Retrieve top n values in each group of a DataFrame in Scala
Here is the Scala version of @mtoto’s answer.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.rank
import org.apache.spark.sql.functions.col
val window = Window.partitionBy("user_id").orderBy(col("score").desc)
val rankByScore = rank().over(window)
df1.select(col("*"), rankByScore.as("rank")).filter(col("rank") <= 2).show()
// You can change the value 2 to any number you want. Here 2 means the top 2 values.
More examples can be found here.
Method 4
With Python 3 and Spark 2.4:
from pyspark.sql import Window
import pyspark.sql.functions as f
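def get_topN(df, group_by_columns, order_by_column, n=1):
    # group_by_columns: list of column names; order_by_column: a single column name.
    window_group_by_columns = Window.partitionBy(group_by_columns)
    # Number the rows of each group, highest order_by_column value first.
    ordered_df = df.select(df.columns + [
        f.row_number().over(window_group_by_columns.orderBy(f.col(order_by_column).desc())).alias('row_rank')])
    # Keep the first n rows of each group and drop the helper column.
    topN_df = ordered_df.filter(f"row_rank <= {n}").drop("row_rank")
    return topN_df

# For the DataFrame in the question: group by user_id, order by score.
top_n_df = get_topN(your_dataframe, ["user_id"], "score", 1)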
Method 5
Here is another solution, without a window function, to get the top N records from a pySpark DataFrame. Note that it returns the top N rows of the whole DataFrame, not the top N rows per group.
# Import Libraries
from pyspark.sql.functions import col
# Sample Data
rdd = sc.parallelize([("user_1", "object_1", 3),
("user_1", "object_2", 2),
("user_2", "object_1", 5),
("user_2", "object_2", 2),
("user_2", "object_2", 6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
# Get top n records as Row Objects
row_list = df.orderBy(col("score").desc()).head(5)
# Convert row objects to DF
sorted_df = spark.createDataFrame(row_list)
# Display DataFrame
sorted_df.show()
Output
+-------+---------+-----+
|user_id|object_id|score|
+-------+---------+-----+
| user_1| object_2|    2|
| user_2| object_2|    2|
| user_1| object_1|    3|
| user_2| object_1|    5|
| user_2| object_2|    6|
+-------+---------+-----+
If you are interested in more window functions in Spark you can refer to one of my blogs: https://medium.com/expedia-group-tech/deep-dive-into-apache-spark-window-functions-7b4e39ad3c86
Method 6
To find the Nth highest value with a PySpark SQL query using the ROW_NUMBER() function:
SELECT * FROM (
SELECT e.*,
ROW_NUMBER() OVER (ORDER BY col_name DESC) rn
FROM Employee e
)
WHERE rn = N
Here N is the position of the required value in descending order, e.g. N = 2 selects the second highest value in the column.
Output:
+-----------+
|col_name   |
+-----------+
|1183395    |
+-----------+
The query returns the row holding the Nth highest value.
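As a rough plain-Python analogue of the query (no Spark involved), the sketch below numbers the values from the largest down and keeps the one whose number equals N. The values are hypothetical. Like the query, it has no PARTITION BY, so it works across the whole column rather than per group.

```python
# Hypothetical column values, for illustration only.
values = [40, 10, 30, 20]
N = 2

# ROW_NUMBER() OVER (ORDER BY col_name DESC): number rows from the largest value down.
numbered = list(enumerate(sorted(values, reverse=True), start=1))

# WHERE rn = N: keep the row whose number equals N.
nth_highest = [v for rn, v in numbered if rn == N]
print(nth_highest)  # [30]
```

Duplicate values each get their own row number here, just as with ROW_NUMBER(), so the Nth row is not necessarily the Nth distinct value.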
All methods were sourced from stackoverflow.com or stackexchange.com and are licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.