Pivot String column on Pyspark Dataframe

I have a simple dataframe like this:

rdd = sc.parallelize(
    [
        (0, "A", 223,"201603", "PORT"), 
        (0, "A", 22,"201602", "PORT"), 
        (0, "A", 422,"201601", "DOCK"), 
        (1,"B", 3213,"201602", "DOCK"), 
        (1,"B", 3213,"201601", "PORT"), 
        (2,"C", 2321,"201601", "DOCK")
    ]
)
df_data = sqlContext.createDataFrame(rdd, ["id","type", "cost", "date", "ship"])

df_data.show()
+---+----+----+------+----+
| id|type|cost|  date|ship|
+---+----+----+------+----+
|  0|   A| 223|201603|PORT|
|  0|   A|  22|201602|PORT|
|  0|   A| 422|201601|DOCK|
|  1|   B|3213|201602|DOCK|
|  1|   B|3213|201601|PORT|
|  2|   C|2321|201601|DOCK|
+---+----+----+------+----+

and I need to pivot it by date:

df_data.groupby(df_data.id, df_data.type).pivot("date").avg("cost").show()

+---+----+------+------+------+
| id|type|201601|201602|201603|
+---+----+------+------+------+
|  2|   C|2321.0|  null|  null|
|  0|   A| 422.0|  22.0| 223.0|
|  1|   B|3213.0|3213.0|  null|
+---+----+------+------+------+

Everything works as expected. But now I need to pivot it and get a non-numeric column:

df_data.groupby(df_data.id, df_data.type).pivot("date").avg("ship").show()

and of course I would get an exception:

AnalysisException: u'"ship" is not a numeric column. Aggregation function can only be applied on a numeric column.;'

I would like to generate something along the lines of:

+---+----+------+------+------+
| id|type|201601|201602|201603|
+---+----+------+------+------+
|  2|   C|  DOCK|  null|  null|
|  0|   A|  DOCK|  PORT|  PORT|
|  1|   B|  PORT|  DOCK|  null|
+---+----+------+------+------+

Is that possible with pivot?

Answers:


Method 1

Assuming that (id, type, date) combinations are unique and your only goal is pivoting, not aggregation, you can use first (or any other function that is not restricted to numeric values):

from pyspark.sql.functions import first

(df_data
    .groupby(df_data.id, df_data.type)
    .pivot("date")
    .agg(first("ship"))
    .show())

## +---+----+------+------+------+
## | id|type|201601|201602|201603|
## +---+----+------+------+------+
## |  2|   C|  DOCK|  null|  null|
## |  0|   A|  DOCK|  PORT|  PORT|
## |  1|   B|  PORT|  DOCK|  null|
## +---+----+------+------+------+
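To make the semantics concrete, here is a plain-Python sketch (no Spark involved) of what the pivot with first computes on the sample rows. The helper name pivot_first is hypothetical and only for illustration. It keeps the first value it encounters for each cell, whereas Spark's first makes no ordering promise when a cell has several rows, which is why the uniqueness assumption matters.

```python
rows = [
    (0, "A", 223, "201603", "PORT"),
    (0, "A", 22, "201602", "PORT"),
    (0, "A", 422, "201601", "DOCK"),
    (1, "B", 3213, "201602", "DOCK"),
    (1, "B", 3213, "201601", "PORT"),
    (2, "C", 2321, "201601", "DOCK"),
]


def pivot_first(rows):
    """Hypothetical helper: one entry per (id, type), one slot per date."""
    # The distinct dates become the pivoted columns, in sorted order.
    dates = sorted({r[3] for r in rows})
    table = {}
    for id_, type_, _cost, date, ship in rows:
        cells = table.setdefault((id_, type_), {})
        # Keep the first ship seen for this (id, type, date) cell.
        cells.setdefault(date, ship)
    # Missing cells become None, like the nulls in the Spark output.
    pivoted = {key: [cells.get(d) for d in dates] for key, cells in table.items()}
    return dates, pivoted


dates, pivoted = pivot_first(rows)
print(dates)
for (id_, type_), values in sorted(pivoted.items()):
    print(id_, type_, values)
```

With the sample data every (id, type, date) cell holds exactly one row, so the choice of "first" is unambiguous.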

If these assumptions do not hold, you’ll have to pre-aggregate your data. For example, to keep the most common ship value:

from pyspark.sql.functions import max, struct

(df_data
    .groupby("id", "type", "date", "ship")
    .count()
    .groupby("id", "type")
    .pivot("date")
    .agg(max(struct("count", "ship")))
    .show())

## +---+----+--------+--------+--------+
## | id|type|  201601|  201602|  201603|
## +---+----+--------+--------+--------+
## |  2|   C|[1,DOCK]|    null|    null|
## |  0|   A|[1,DOCK]|[1,PORT]|[1,PORT]|
## |  1|   B|[1,PORT]|[1,DOCK]|    null|
## +---+----+--------+--------+--------+
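The reason max(struct("count", "ship")) picks the most common value is that, as far as I know, Spark compares structs field by field in declaration order, so the count decides first. Python tuples compare the same way, which allows a small plain-Python analogy. The ships list below is hypothetical data for a single (id, type, date) cell, chosen so that one value repeats.

```python
from collections import Counter

# Hypothetical ship values seen for a single (id, type, date) cell.
ships = ["PORT", "DOCK", "PORT"]

# Step 1: count each ship value (the groupby(...).count() step).
counts = Counter(ships)

# Step 2: put the count first so it dominates the comparison,
# mirroring struct("count", "ship").
pairs = [(n, ship) for ship, n in counts.items()]

# Step 3: tuples compare element by element, so max() picks the highest
# count and falls back to the larger ship string on a tie.
best = max(pairs)
print(best)
print(best[1])
```

Note that on a tie in counts the lexicographically larger ship value wins, which may or may not be the tie-break you want.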

Method 2

In case someone is looking for a SQL-style approach:

rdd = spark.sparkContext.parallelize(
    [
        (0, "A", 223,"201603", "PORT"), 
        (0, "A", 22,"201602", "PORT"), 
        (0, "A", 422,"201601", "DOCK"), 
        (1,"B", 3213,"201602", "DOCK"), 
        (1,"B", 3213,"201601", "PORT"), 
        (2,"C", 2321,"201601", "DOCK")
    ]
)
df_data = spark.createDataFrame(rdd, ["id","type", "cost", "date", "ship"])
df_data.createOrReplaceTempView("df")
df_data.show()

dt_vals = spark.sql("select collect_set(date) from df").collect()[0][0]
## e.g. ['201601', '201602', '201603'] (collect_set does not guarantee any order)

dt_vals_colstr = ",".join(["'" + c + "'" for c in sorted(dt_vals)])
## "'201601','201602','201603'"

Part-1 (note the f prefix, which makes the query an f-string so that {dt_vals_colstr} is interpolated)

spark.sql(f"""
select * from 
(select id , type, date, ship from df)
pivot (
first(ship) for date in ({dt_vals_colstr})
)
""").show(100,truncate=False)

+---+----+------+------+------+
|id |type|201601|201602|201603|
+---+----+------+------+------+
|1  |B   |PORT  |DOCK  |null  |
|2  |C   |DOCK  |null  |null  |
|0  |A   |DOCK  |PORT  |PORT  |
+---+----+------+------+------+

Part-2

spark.sql(f"""
select * from 
(select id , type, date, ship from df)
pivot (
case when count(*)=0 then null 
else struct(count(*),first(ship)) end for date in ({dt_vals_colstr})
)
""").show(100,truncate=False)

+---+----+---------+---------+---------+
|id |type|201601   |201602   |201603   |
+---+----+---------+---------+---------+
|1  |B   |[1, PORT]|[1, DOCK]|null     |
|2  |C   |[1, DOCK]|null     |null     |
|0  |A   |[1, DOCK]|[1, PORT]|[1, PORT]|
+---+----+---------+---------+---------+
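Both queries above splice the pivot values directly into the SQL text, so it may be worth validating them before building the string. The following is a minimal sketch, assuming every pivot value is a six-digit yyyymm string as in the sample data. The helper name quoted_in_list is hypothetical and not part of Spark.

```python
import re


def quoted_in_list(values):
    """Hypothetical helper: build a sorted, quoted list for the IN (...) clause."""
    for v in values:
        # Assumption: every pivot value is a six-digit yyyymm string.
        if not re.fullmatch(r"\d{6}", v):
            raise ValueError(f"unexpected pivot value: {v!r}")
    return ",".join(f"'{v}'" for v in sorted(values))


print(quoted_in_list(["201603", "201601", "201602"]))
```

The returned string could be used in place of dt_vals_colstr. Rejecting anything that is not six digits keeps stray quotes out of the generated query.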


All methods were sourced from stackoverflow.com or stackexchange.com and are licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
