Rename a nested field in a Spark DataFrame

Given a DataFrame df in Spark with the following schema:

 |-- array_field: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- a: string (nullable = true)
 |    |    |-- b: long (nullable = true)
 |    |    |-- c: long (nullable = true)

How can I rename the field array_field.a to array_field.a_renamed?
[Update]:
.withColumnRenamed() does not work with nested fields, so I tried this hacky and unsafe method:

# First alter the schema:
schema = df.schema
schema['array_field'].dataType.elementType['a'].name = 'a_renamed'

ind = schema['array_field'].dataType.elementType.names.index('a')
schema['array_field'].dataType.elementType.names[ind] = 'a_renamed'

# Then set dataframe's schema with altered schema
df._schema = schema

I know that setting a private attribute is not good practice, but I don’t know any other way to set the schema for df.

I think I am on the right track, but df.printSchema() still shows the old name for array_field.a, even though df.schema == schema is True.

Answers:


Method 1

Python

It is not possible to modify a single nested field in place. You have to recreate the whole structure. In this particular case, the simplest solution is to use cast.

First a bunch of imports:

from collections import namedtuple
from pyspark.sql.functions import col
from pyspark.sql.types import (
    ArrayType, LongType, StringType, StructField, StructType)

and example data:

Record = namedtuple("Record", ["a", "b", "c"])

df = sc.parallelize([([Record("foo", 1, 3)], )]).toDF(["array_field"])

Let’s confirm that the schema is the same as in your case:

df.printSchema()
root
 |-- array_field: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- a: string (nullable = true)
 |    |    |-- b: long (nullable = true)
 |    |    |-- c: long (nullable = true)

You can define the new schema, for example, as a string:

str_schema = "array<struct<a_renamed:string,b:bigint,c:bigint>>"

df.select(col("array_field").cast(str_schema)).printSchema()
root
 |-- array_field: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- a_renamed: string (nullable = true)
 |    |    |-- b: long (nullable = true)
 |    |    |-- c: long (nullable = true)
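With many fields, typing the DDL string by hand gets error-prone. A minimal sketch that builds the same string from (name, type) pairs; the helper array_of_struct_ddl and its renames parameter are hypothetical, not part of PySpark. In PySpark the pairs could presumably be taken from [(f.name, f.dataType.simpleString()) for f in df.schema["array_field"].dataType.elementType.fields].

```python
from typing import Dict, List, Optional, Tuple


def array_of_struct_ddl(
    fields: List[Tuple[str, str]],
    renames: Optional[Dict[str, str]] = None,
) -> str:
    """Build an 'array<struct<...>>' DDL type string, renaming fields on the way.

    Hypothetical helper: `fields` is a list of (field_name, ddl_type) pairs and
    `renames` maps old field names to new ones (unlisted names are kept).
    """
    renames = renames or {}
    members = ",".join(
        f"{renames.get(name, name)}:{ddl_type}" for name, ddl_type in fields
    )
    return f"array<struct<{members}>>"


str_schema = array_of_struct_ddl(
    [("a", "string"), ("b", "bigint"), ("c", "bigint")],
    renames={"a": "a_renamed"},
)
print(str_schema)  # array<struct<a_renamed:string,b:bigint,c:bigint>>
```

The result is meant to be passed to col("array_field").cast(str_schema) just like the hand-written string. Field names containing special characters would also need backquoting, which this sketch does not handle.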

or a DataType:

struct_schema = ArrayType(StructType([
    StructField("a_renamed", StringType()),
    StructField("b", LongType()),
    StructField("c", LongType())
]))

df.select(col("array_field").cast(struct_schema)).printSchema()
root
 |-- array_field: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- a_renamed: string (nullable = true)
 |    |    |-- b: long (nullable = true)
 |    |    |-- c: long (nullable = true)

Scala

The same techniques can be used in Scala:

import spark.implicits._  // for toDF and the $"..." syntax (already imported in spark-shell)

case class Record(a: String, b: Long, c: Long)

val df = Seq(Tuple1(Seq(Record("foo", 1, 3)))).toDF("array_field")

val strSchema = "array<struct<a_renamed:string,b:bigint,c:bigint>>"

df.select($"array_field".cast(strSchema))

or

import org.apache.spark.sql.types._

val structSchema = ArrayType(StructType(Seq(
    StructField("a_renamed", StringType),
    StructField("b", LongType),
    StructField("c", LongType)
)))

df.select($"array_field".cast(structSchema))

Possible improvements:

If you use an expressive data manipulation or JSON processing library, it can be easier to dump the data types to a dict or JSON string and work from there, for example (Python / toolz):

from toolz.curried import pipe, update_in, map
from operator import attrgetter

# Update name to "a_renamed" if name is "a"
rename_field = update_in(
    keys=["name"], func=lambda x: "a_renamed" if x == "a" else x)

updated_schema = pipe(
   #  Get schema of the field as a dict
   df.schema["array_field"].jsonValue(),
   # Update fields with rename
   update_in(
       keys=["type", "elementType", "fields"],
       func=lambda x: pipe(x, map(rename_field), list)),
   # Load schema from dict
   StructField.fromJson,
   # Get data type
   attrgetter("dataType"))

df.select(col("array_field").cast(updated_schema)).printSchema()
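The same dict manipulation also works without toolz. A minimal sketch, assuming the dict has the shape that StructField.jsonValue() returns for an array-of-struct column (a "type" entry whose "elementType" holds a "fields" list). rename_element_field is a hypothetical helper name. It returns an updated copy that could then go through StructField.fromJson and attrgetter("dataType") as in the pipeline above.

```python
import copy
from typing import Any, Dict


def rename_element_field(field_json: Dict[str, Any], old: str, new: str) -> Dict[str, Any]:
    """Return a copy of an array<struct> field's JSON with one struct member renamed."""
    updated = copy.deepcopy(field_json)  # never mutate the caller's dict
    for member in updated["type"]["elementType"]["fields"]:
        if member["name"] == old:
            member["name"] = new
    return updated


# Example input in the assumed shape of df.schema["array_field"].jsonValue().
array_field_json = {
    "name": "array_field",
    "type": {
        "type": "array",
        "elementType": {
            "type": "struct",
            "fields": [
                {"name": "a", "type": "string", "nullable": True, "metadata": {}},
                {"name": "b", "type": "long", "nullable": True, "metadata": {}},
                {"name": "c", "type": "long", "nullable": True, "metadata": {}},
            ],
        },
        "containsNull": True,
    },
    "nullable": True,
    "metadata": {},
}

renamed = rename_element_field(array_field_json, "a", "a_renamed")
print([f["name"] for f in renamed["type"]["elementType"]["fields"]])
# ['a_renamed', 'b', 'c']
```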

Method 2

You can recurse over the data frame’s schema to create a new schema with the required changes.

A schema in PySpark is a StructType, which holds a list of StructFields; each StructField holds either a primitive type or a complex one, such as another StructType or an ArrayType (whose elements may themselves be structs).

This means we can decide whether to recurse based on the field’s type: StructType and ArrayType need recursion, while anything else is a leaf.

Below is an annotated sample implementation that shows you how you can implement the above idea.

# Some imports
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import ArrayType, DataType, StructField, StructType

# We take a dataframe and return a new one with the required changes
def cleanDataFrame(df: DataFrame) -> DataFrame:
    # Returns a new sanitized field name (this function can be anything really)
    def sanitizeFieldName(s: str) -> str:
        return (s.replace("-", "_").replace("&", "_").replace('"', "_")
                 .replace("[", "_").replace("]", "_").replace(".", "_"))

    # We call this on all fields to build a new field with any changes we
    # want (here: a sanitized name and a recursively cleaned data type).
    def sanitizeField(field: StructField) -> StructField:
        return StructField(sanitizeFieldName(field.name),
                           cleanSchema(field.dataType),
                           field.nullable,
                           field.metadata)

    def cleanSchema(dataType: DataType) -> DataType:
        # If the type is a StructType or an ArrayType we need to recurse,
        # otherwise we have reached a leaf node and return it unchanged.
        # Building new objects (instead of mutating shallow copies) keeps
        # StructType.names in sync with the renamed fields.
        if isinstance(dataType, StructType):
            # We call our sanitizer for all top level fields
            return StructType([sanitizeField(f) for f in dataType.fields])
        elif isinstance(dataType, ArrayType):
            return ArrayType(cleanSchema(dataType.elementType), dataType.containsNull)
        return dataType

    # Now since we have the new schema we can create a new DataFrame
    # by using the old frame's RDD as data and the new schema as the
    # schema for the data (using the active SparkSession)
    spark = SparkSession.builder.getOrCreate()
    return spark.createDataFrame(df.rdd, cleanSchema(df.schema))
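The same recursion can also be written over the JSON form of the schema (df.schema.jsonValue()), which avoids copying PySpark type objects. A minimal sketch, assuming PySpark's JSON schema layout (structs as {"type": "struct", "fields": [...]}, arrays with "elementType", maps with "keyType"/"valueType", primitives as plain strings); map_field_names is a hypothetical name.

```python
from typing import Any, Callable, Dict, Union

# A JSON data type is either a primitive name ("string", "long", ...) or a dict.
JsonType = Union[str, Dict[str, Any]]


def map_field_names(dtype: JsonType, rename: Callable[[str], str]) -> JsonType:
    """Return a new JSON schema with `rename` applied to every struct field name."""
    if isinstance(dtype, str):
        # Primitive types are leaves: there is nothing to rename.
        return dtype
    kind = dtype.get("type")
    if kind == "struct":
        # Rename each field and recurse into its type.
        return {
            **dtype,
            "fields": [
                {**field,
                 "name": rename(field["name"]),
                 "type": map_field_names(field["type"], rename)}
                for field in dtype["fields"]
            ],
        }
    if kind == "array":
        return {**dtype, "elementType": map_field_names(dtype["elementType"], rename)}
    if kind == "map":
        return {
            **dtype,
            "keyType": map_field_names(dtype["keyType"], rename),
            "valueType": map_field_names(dtype["valueType"], rename),
        }
    # Any other dict node (e.g. a user-defined type) is returned unchanged.
    return dtype


# Example input in the assumed shape of df.schema.jsonValue().
schema_json = {
    "type": "struct",
    "fields": [{
        "name": "array-field",
        "type": {
            "type": "array",
            "elementType": {
                "type": "struct",
                "fields": [{"name": "a.b", "type": "string",
                            "nullable": True, "metadata": {}}],
            },
            "containsNull": True,
        },
        "nullable": True,
        "metadata": {},
    }],
}

cleaned = map_field_names(schema_json, lambda s: s.replace("-", "_").replace(".", "_"))
print(cleaned["fields"][0]["name"])  # array_field
```

With PySpark available, StructType.fromJson(cleaned) should turn the result back into a schema. That schema still has to be applied to the data, for example with createDataFrame(df.rdd, ...) as cleanDataFrame does.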

Method 3

I found a much easier way than the one provided by @zero323, along the lines of @MaxPY:

PySpark 2.4:

from pyspark.sql.types import StructField

# Get the schema from the dataframe df
schema = df.schema

# Override `fields` with a list of new StructFields, equal to the previous ones except for the names
schema.fields = [StructField(field.name + "_renamed", field.dataType, field.nullable, field.metadata)
                 for field in schema.fields]

# Override also `names` with the same mechanism, so it stays consistent with `fields`
schema.names = [name + "_renamed" for name in schema.names]

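Now df.schema returns the renamed fields. Note, however, that this only rewrites the schema object that PySpark caches on the Python side, and only for top-level fields: df.printSchema() is computed by the JVM and still shows the old names (the same symptom described in the question). To actually apply the renamed schema, you still need to build a new DataFrame from it, for example with spark.createDataFrame(df.rdd, schema).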

Method 4

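Another, much easier solution (if it works for you, as it did for me) is to flatten the structure and then rename. Because array_field is an array of structs, star expansion (array_field.*) is not enough, as it only works on struct columns; the array first has to be exploded into one row per element, for example with the inline generator:

Using Scala:

val df_flat = df.selectExpr("inline(array_field)")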

Now the rename works:

val df_renamed = df_flat.withColumnRenamed("a", "a_renamed")

Of course, this only works if you don’t need the hierarchy (although I suppose it can be recreated if needed).

Method 5

Using the answer provided by Leo C in https://stackoverflow.com/a/55363153/5475506, I have built what I consider a more human-friendly/Pythonic script:

    import pyspark.sql.types as sql_types

    path_table = "<PATH_TO_DATA>"
    table_name = "<TABLE_NAME>"

    def recur_rename(schema: sql_types.StructType, old_char, new_char):
        schema_new = []
        for struct_field in schema:
            new_name = struct_field.name.replace(old_char, new_char)
            data_type = struct_field.dataType
            if isinstance(data_type, sql_types.StructType):
                # Recursive call to rename the fields of a nested struct
                data_type = sql_types.StructType(recur_rename(data_type, old_char, new_char))
            elif isinstance(data_type, sql_types.ArrayType) and isinstance(data_type.elementType, sql_types.StructType):
                # Recursive call to loop over all fields of the Array's struct elements
                data_type = sql_types.ArrayType(
                    sql_types.StructType(recur_rename(data_type.elementType, old_char, new_char)),
                    data_type.containsNull)
            # Any other type (including arrays of primitives) keeps its data type unchanged
            schema_new.append(sql_types.StructField(new_name, data_type, struct_field.nullable, struct_field.metadata))
        return schema_new

    def rename_columns(schema: sql_types.StructType, old_char, new_char):
        return sql_types.StructType(recur_rename(schema, old_char, new_char))

    df = spark.read.format("json").load(path_table)  # Read data whose schema has to be changed.
    # Replace special characters in schema (more special characters not allowed in the Spark/Hive metastore: ':', ',', ';')
    newSchema = rename_columns(df.schema, ":", "_")
    # Apply the new schema positionally. Re-reading the JSON with the renamed schema would not work,
    # because the JSON reader matches fields by name and the renamed fields would come back as null.
    df2 = spark.createDataFrame(df.rdd, newSchema)

I consider the code self-explanatory (it also has comments), but in short: it recursively loops over all the fields in the schema, replacing old_char with new_char in each field name. If a field’s type is nested (a StructType, or an ArrayType of structs), recursive calls are made.
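Before renaming, it can help to list which nested field names actually contain the characters you plan to replace. A minimal sketch over the JSON form of the schema (df.schema.jsonValue()), assuming PySpark's JSON schema layout and using the characters from the comment in the script (':', ',', ';'); find_fields_with_chars is a hypothetical helper.

```python
from typing import Any, Dict, Iterable, List, Union


def find_fields_with_chars(
    dtype: Union[str, Dict[str, Any]], chars: Iterable[str], prefix: str = ""
) -> List[str]:
    """Return dotted paths of struct fields whose names contain any of `chars`."""
    chars = tuple(chars)
    found: List[str] = []
    if isinstance(dtype, str):
        # Primitive types have no field names.
        return found
    kind = dtype.get("type")
    if kind == "struct":
        for field in dtype["fields"]:
            path = f"{prefix}.{field['name']}" if prefix else field["name"]
            if any(ch in field["name"] for ch in chars):
                found.append(path)
            found.extend(find_fields_with_chars(field["type"], chars, path))
    elif kind == "array":
        # Array elements share the array field's path.
        found.extend(find_fields_with_chars(dtype["elementType"], chars, prefix))
    elif kind == "map":
        found.extend(find_fields_with_chars(dtype["keyType"], chars, prefix))
        found.extend(find_fields_with_chars(dtype["valueType"], chars, prefix))
    return found


# Example input in the assumed shape of df.schema.jsonValue().
event_schema_json = {
    "type": "struct",
    "fields": [
        {"name": "id", "type": "long", "nullable": True, "metadata": {}},
        {"name": "geo:point",
         "type": {"type": "struct", "fields": [
             {"name": "lat;deg", "type": "double", "nullable": True, "metadata": {}},
         ]},
         "nullable": True, "metadata": {}},
    ],
}

print(find_fields_with_chars(event_schema_json, ":,;"))
# ['geo:point', 'geo:point.lat;deg']
```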


All methods were sourced from stackoverflow.com or stackexchange.com and are licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
