Introduction to PySpark Part 3 – Adding, Updating and Removing Columns
This is the third part in a series of blog posts as an introduction to PySpark.
The other parts of this blog post series can be found here:
- Part 1 – Creating Data Frames and Reading Data from Files
- Part 2 – Selecting, Filtering and Sorting Data
- Part 3 – Adding, Updating and Removing Columns
- Part 4 – Summarising Data
- Part 5 – Aggregating Data
In this part I’ll be covering:
- Adding Columns
- Summing Two Columns
- expr for Cumulative Sum
- Other Arithmetic Operators
- String Concatenation
- Updating Columns
- Renaming Columns
- Casting Columns
- When…Otherwise
- UDFs
- Filling Nulls
- Removing Columns
Adding Columns
We can add columns in PySpark using the withColumn method of DataFrames. Let’s say we have the following DataFrame:
from pyspark.sql import Row
fruit_data = [
{'day': 1, 'apples': 3, 'oranges': 9},
{'day': 2, 'apples': 7, 'oranges': 1},
{'day': 3, 'apples': 2, 'oranges': 3},
{'day': 4, 'apples': 4, 'oranges': 2},
]
fruit_df = spark.createDataFrame(map(lambda x: Row(**x), fruit_data))
display(fruit_df)
day | apples | oranges |
---|---|---|
1 | 3 | 9 |
2 | 7 | 1 |
3 | 2 | 3 |
4 | 4 | 2 |
Summing Two Columns
We can create a new column named "fruit", the sum of these two columns, as follows:
from pyspark.sql.functions import col
fruit_df = fruit_df.withColumn(
"fruit",
col('apples') + col('oranges')
)
display(fruit_df)
day | apples | oranges | fruit |
---|---|---|---|
1 | 3 | 9 | 12 |
2 | 7 | 1 | 8 |
3 | 2 | 3 | 5 |
4 | 4 | 2 | 6 |
expr for Cumulative Sum
We can use the expr function to provide a SQL-like expression. Here we’ll use it to calculate a cumulative sum:
from pyspark.sql.functions import expr
fruit_df = fruit_df.withColumn('cumsum', expr('sum(fruit) over (order by day)'))
display(fruit_df)
day | apples | oranges | fruit | cumsum |
---|---|---|---|---|
1 | 3 | 9 | 12 | 12 |
2 | 7 | 1 | 8 | 20 |
3 | 2 | 3 | 5 | 25 |
4 | 4 | 2 | 6 | 31 |
If we wanted to partition by another column first and then do our cumulative sum, we could do so using:
sum(value) over (partition by partition_column order by order_column)
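For example, if our fruit data covered several shops, a partitioned cumulative sum might look like the sketch below (the shop column here is hypothetical and not part of our DataFrame):
fruit_by_shop_df = fruit_df.withColumn(
    'cumsum_per_shop',
    # assumes a hypothetical shop column in the data
    expr('sum(fruit) over (partition by shop order by day)')
)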
Other Arithmetic Operators
It’s not just addition we can use to create new columns. Let’s say we have the following data, with y as the ground truth and y_hat as our prediction.
error_data = [
{'y': 1.7, 'y_hat': 2.1},
{'y': 3.4, 'y_hat': 3.2},
{'y': 7.2, 'y_hat': 7.0},
{'y': 4.1, 'y_hat': 4.6},
]
error_df = spark.createDataFrame(map(lambda x: Row(**x), error_data))
display(error_df)
y | y_hat |
---|---|
1.7 | 2.1 |
3.4 | 3.2 |
7.2 | 7.0 |
4.1 | 4.6 |
We can start by subtracting our prediction from our ground truth to get our error values. We get some floating-point errors here, so we’ll also round this column to 2 decimal places.
from pyspark.sql.functions import round as col_round
error_df = error_df.withColumn(
'error',
col_round(col('y') - col('y_hat'), 2)
)
display(error_df)
y | y_hat | error |
---|---|---|
1.7 | 2.1 | -0.4 |
3.4 | 3.2 | 0.2 |
7.2 | 7.0 | 0.2 |
4.1 | 4.6 | -0.5 |
Now if we want the absolute value of this error, we can import a function from pyspark.sql.functions
to help us do that
from pyspark.sql.functions import abs as col_abs
error_df = error_df.withColumn(
'absolute_error',
col_abs(col('error'))
)
display(error_df)
y | y_hat | error | absolute_error |
---|---|---|---|
1.7 | 2.1 | -0.4 | 0.4 |
3.4 | 3.2 | 0.2 | 0.2 |
7.2 | 7.0 | 0.2 | 0.2 |
4.1 | 4.6 | -0.5 | 0.5 |
Alternatively, we might want the squared error, which we can get as follows:
error_df = error_df.withColumn(
'squared_error',
col_round(col('error') ** 2, 3)
)
display(error_df)
y | y_hat | error | absolute_error | squared_error |
---|---|---|---|---|
1.7 | 2.1 | -0.4 | 0.4 | 0.16 |
3.4 | 3.2 | 0.2 | 0.2 | 0.04 |
7.2 | 7.0 | 0.2 | 0.2 | 0.04 |
4.1 | 4.6 | -0.5 | 0.5 | 0.25 |
We can even go as far as using the sine function on a column of data, and we can chain our withColumn calls.
import math
from pyspark.sql.functions import sin
angle_df = spark.createDataFrame(
[Row(angle=0.0), Row(angle=math.pi * 0.25),
Row(angle=math.pi * 0.5), Row(angle=math.pi)]
)
angle_df = angle_df.withColumn(
'sine',
col_round(sin(angle_df['angle']), 5)
).withColumn(
'sine_squared',
col_round(sin(angle_df['angle']) ** 2, 5)
)
display(angle_df)
angle | sine | sine_squared |
---|---|---|
0.0 | 0.0 | 0.0 |
0.7853981633974483 | 0.70711 | 0.5 |
1.5707963267948966 | 1.0 | 1.0 |
3.141592653589793 | 0.0 | 0.0 |
String Concatenation
We can concatenate strings from two columns using either the concat or the concat_ws function; concat_ws takes a separator as its first parameter. In the example below we use lit to create a literal column, which supplies a constant value across all rows.
car_data = [
{'make': 'Ford', 'model': 'GT40'},
{'make': 'Volkswagen', 'model': 'Golf'},
{'make': 'Aston Martin', 'model': 'Vanquish'},
{'make': 'Mazda', 'model': 'MX-5 RF'},
]
car_df = spark.createDataFrame(map(lambda x: Row(**x), car_data))
display(car_df)
make | model |
---|---|
Ford | GT40 |
Volkswagen | Golf |
Aston Martin | Vanquish |
Mazda | MX-5 RF |
from pyspark.sql.functions import concat, concat_ws, lit
display(
car_df.withColumn(
'name',
concat(col('make'), lit(' '), col('model'))
)
)
make | model | name |
---|---|---|
Ford | GT40 | Ford GT40 |
Volkswagen | Golf | Volkswagen Golf |
Aston Martin | Vanquish | Aston Martin Vanquish |
Mazda | MX-5 RF | Mazda MX-5 RF |
car_df = car_df.withColumn(
'name',
concat_ws(' ', col('make'), col('model'))
)
display(car_df)
make | model | name |
---|---|---|
Ford | GT40 | Ford GT40 |
Volkswagen | Golf | Volkswagen Golf |
Aston Martin | Vanquish | Aston Martin Vanquish |
Mazda | MX-5 RF | Mazda MX-5 RF |
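One difference worth noting: concat returns null if any of its inputs are null, whereas concat_ws simply skips null values.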
Updating Columns
Renaming Columns
Renaming columns in PySpark is relatively straightforward using the withColumnRenamed method of Spark DataFrames. Let’s take a look using the DataFrame from above.
car_df = car_df.withColumnRenamed('make', 'manufacturer')
display(car_df)
manufacturer | model | name |
---|---|---|
Ford | GT40 | Ford GT40 |
Volkswagen | Golf | Volkswagen Golf |
Aston Martin | Vanquish | Aston Martin Vanquish |
Mazda | MX-5 RF | Mazda MX-5 RF |
Casting Columns
We can cast columns from one type to another using the cast method of a column. Let’s say we read in a CSV and, by default, all the columns are strings; we would then want to cast the columns to the correct types.
city_population = spark.read.csv('/mnt/tmp/city_population.csv', header=True)
display(city_population)
Population | Year | City |
---|---|---|
6.9 | 1991 | London |
7.2 | 2001 | London |
8.2 | 2011 | London |
2.5 | 2001 | Manchester |
2.7 | 2011 | Manchester |
print(city_population.dtypes)
from pyspark.sql.types import IntegerType, DoubleType
city_population = city_population.withColumn(
'Population', col('Population').cast(DoubleType())
).withColumn(
'Year', col('Year').cast(IntegerType())
)
print(city_population.dtypes)
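Assuming the CSV was read with every column as a string, the two dtypes printouts should look something like:
[('Population', 'string'), ('Year', 'string'), ('City', 'string')]
[('Population', 'double'), ('Year', 'int'), ('City', 'string')]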
We can also cast columns to dates using the to_date
function, for example:
from pyspark.sql.functions import to_date
date_df = spark.createDataFrame(
[
Row(date='2021-01-01'),
Row(date='2021-01-02'),
Row(date='2021-01-03'),
Row(date='2021-01-04'),
]
)
date_df = date_df.withColumn('date', to_date(col('date')))
display(date_df)
date |
---|
2021-01-01 |
2021-01-02 |
2021-01-03 |
2021-01-04 |
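If our dates were stored in a different format, to_date also accepts an optional format string. A minimal sketch, assuming day/month/year strings:
display(
    spark.createDataFrame([Row(date='04/01/2021')])
        .withColumn('date', to_date(col('date'), 'dd/MM/yyyy'))
)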
When…Otherwise
We can use the when function together with otherwise as an IF…ELSE statement: if a condition is met, return one value; otherwise, return another value.
Let’s say we have the following transaction data and we want to work out our net revenue.
transaction_data = [
{'price_per_item': 5.5, 'quantity': 2, 'transaction_type': 'sale'},
{'price_per_item': 3.25, 'quantity': 4, 'transaction_type': 'return'},
{'price_per_item': 4.99, 'quantity': 9, 'transaction_type': 'sale'},
{'price_per_item': 7.10, 'quantity': 3, 'transaction_type': 'return'},
]
transaction_df = spark.createDataFrame(map(lambda x: Row(**x), transaction_data))
display(transaction_df)
price_per_item | quantity | transaction_type |
---|---|---|
5.5 | 2 | sale |
3.25 | 4 | return |
4.99 | 9 | sale |
7.1 | 3 | return |
For transactions that are sales we want price_per_item multiplied by quantity, but for transactions that are returns we want the negative of this value.
from pyspark.sql.functions import when
transaction_df = transaction_df.withColumn(
'revenue',
when(
col('transaction_type') == 'sale',
col_round(col('price_per_item') * col('quantity'), 2)
).otherwise(
col_round(col('price_per_item') * col('quantity') * -1, 2)
)
)
display(transaction_df)
price_per_item | quantity | transaction_type | revenue |
---|---|---|---|
5.5 | 2 | sale | 11.0 |
3.25 | 4 | return | -13.0 |
4.99 | 9 | sale | 44.91 |
7.1 | 3 | return | -21.3 |
We can also combine conditions inside a when clause and chain multiple when clauses too (a sketch of chained when clauses follows the example below). Let’s take a look at combined conditions first. Let’s say a “bulk purchase” is a sale with a quantity of more than 5 items; we can do:
transaction_df = transaction_df.withColumn(
'bulk_purchase',
when(
(col('quantity') > 5) & (col('transaction_type') == 'sale'),
True
).otherwise(
False
)
)
display(transaction_df)
price_per_item | quantity | transaction_type | revenue | bulk_purchase |
---|---|---|---|---|
5.5 | 2 | sale | 11.0 | false |
3.25 | 4 | return | -13.0 | false |
4.99 | 9 | sale | 44.91 | true |
7.1 | 3 | return | -21.3 | false |
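And here is a sketch of chaining when clauses; the order-size labels below are made up purely for illustration:
display(
    transaction_df.withColumn(
        'order_size',
        when(col('quantity') >= 5, 'large')
        .when(col('quantity') >= 3, 'medium')
        .otherwise('small')
    )
)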
UDFs
UDFs (user-defined functions) take in values from one or more columns and return a value, which we can use to create a new column. We define a UDF with the @udf decorator, passing the output column’s type as its argument.
Let’s take a look at a very simple UDF in action:
cost_df = spark.createDataFrame(
[Row(cost=1.0), Row(cost=11.25),
Row(cost=0.4), Row(cost=0.0)]
)
display(cost_df)
cost |
---|
1.0 |
11.25 |
0.4 |
0.0 |
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
@udf(StringType())
def format_cost(cost_value):
return "£%.2f" % cost_value
cost_df = cost_df.withColumn('cost_formatted', format_cost('cost'))
display(cost_df)
cost | cost_formatted |
---|---|
1.0 | £1.00 |
11.25 | £11.25 |
0.4 | £0.40 |
0.0 | £0.00 |
Let’s take a look at a more complex example that requires multiple columns as inputs and returns a single output.
In this case we’re calculating the great-circle distance between two sets of latitude/longitude coordinates using the haversine formula.
location_data = [
{'start_lat': 52.05489, 'start_long': -1.02910, 'end_lat': 51.49749, 'end_long': -0.05131},
{'start_lat': 48.86728, 'start_long': 2.26679, 'end_lat': 51.94667, 'end_long': 7.683055},
{'start_lat': 51.89246, 'start_long': 4.42011, 'end_lat': 45.04524, 'end_long': 7.74897},
{'start_lat': 41.3962, 'start_long': 2.11298, 'end_lat': 52.29743, 'end_long': 21.28412},
]
location_df = spark.createDataFrame(map(lambda x: Row(**x), location_data))
display(location_df)
end_lat | end_long | start_lat | start_long |
---|---|---|---|
51.49749 | -0.05131 | 52.05489 | -1.0291 |
51.94667 | 7.683055 | 48.86728 | 2.26679 |
45.04524 | 7.74897 | 51.89246 | 4.42011 |
52.29743 | 21.28412 | 41.3962 | 2.11298 |
from math import sin, cos, sqrt, atan2, radians
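# note: sin and cos here are the plain Python math versions (shadowing the pyspark sin imported earlier), which is what we want for row-by-row calculations inside a UDF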
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType
@udf(DoubleType())
def get_distance(lat1, long1, lat2, long2):
R = 6373.0
lat1 = radians(lat1)
lon1 = radians(long1)
lat2 = radians(lat2)
lon2 = radians(long2)
dlon = lon2 - lon1
dlat = lat2 - lat1
a = sin(dlat / 2)**2 + cos(lat1) * cos(lat2) * sin(dlon / 2)**2
c = 2 * atan2(sqrt(a), sqrt(1 - a))
distance = R * c
return distance
points = [
col('start_lat'),
col('start_long'),
col('end_lat'),
col('end_long')
]
location_df = location_df.withColumn('distance', get_distance(*points))
display(location_df)
start_lat | start_long | end_lat | end_long | distance |
---|---|---|---|---|
52.05489 | -1.0291 | 51.49749 | -0.05131 | 91.49887592401357 |
48.86728 | 2.26679 | 51.94667 | 7.683055 | 514.2846661188934 |
51.89246 | 4.42011 | 45.04524 | 7.74897 | 799.9819123478403 |
41.3962 | 2.11298 | 52.29743 | 21.28412 | 1886.377347021304 |
Filling Null Values
In our previous blog post on filtering data we looked at dropping null values; if we wanted to fill these instead, we can do so in a number of ways.
Let’s create some data with a missing value.
data = [
Row(a='A0', b='B0', c='C0'),
Row(a='A1', b='B1', c='C1'),
Row(a='A2', b='B2', c=None),
Row(a='A3', b='B3', c='C3'),
]
data_df = spark.createDataFrame(data)
display(data_df)
a | b | c |
---|---|---|
A0 | B0 | C0 |
A1 | B1 | C1 |
A2 | B2 | null |
A3 | B3 | C3 |
We can fill null values in a particular column with a literal using fillna. For example, if we wanted to fill in missing values in the “c” column with “C2”, we can do that as follows:
display(data_df.fillna(value="C2", subset=["c"]))
a | b | c |
---|---|---|
A0 | B0 | C0 |
A1 | B1 | C1 |
A2 | B2 | C2 |
A3 | B3 | C3 |
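fillna can also take a dictionary mapping column names to fill values, which is handy when filling several columns at once; the same fill as above could be written as:
display(data_df.fillna({'c': 'C2'}))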
However, we could also use our when … otherwise to fill the value from another column, for example:
display(
data_df.withColumn(
'c',
when(
col('c').isNull(),
col('b')
).otherwise(
col('c')
)
)
)
a | b | c |
---|---|---|
A0 | B0 | C0 |
A1 | B1 | C1 |
A2 | B2 | B2 |
A3 | B3 | C3 |
Because of the way rows are distributed across nodes, we need window functions to process rows in an order that lets us forward fill or backward fill null values; we’ll cover these properly in the part on aggregating data, but a small preview is sketched below.
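As a small preview, a forward fill can be sketched with the last function over a window that ignores nulls; ordering by the a column here is just for illustration, and on real data you would normally partition the window first:
from pyspark.sql import Window
from pyspark.sql.functions import last

# carry the last non-null value of c forward, ordered by a
fill_window = Window.orderBy('a').rowsBetween(Window.unboundedPreceding, 0)
display(
    data_df.withColumn('c_filled', last(col('c'), ignorenulls=True).over(fill_window))
)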
Removing Columns
In order to remove columns we can use the drop
method. Let’s load in the Pokemon dataset we looked at in the last post:
pokemon = spark.read.parquet('/mnt/tmp/pokemon.parquet')
pokemon = pokemon.select('#', 'Name', 'Type1', 'Type2', 'Generation')
display(pokemon.limit(10))
# | Name | Type1 | Type2 | Generation |
---|---|---|---|---|
1 | Bulbasaur | Grass | Poison | 1 |
2 | Ivysaur | Grass | Poison | 1 |
3 | Venusaur | Grass | Poison | 1 |
3 | VenusaurMega Venusaur | Grass | Poison | 1 |
4 | Charmander | Fire | null | 1 |
5 | Charmeleon | Fire | null | 1 |
6 | Charizard | Fire | Flying | 1 |
6 | CharizardMega Charizard X | Fire | Dragon | 1 |
6 | CharizardMega Charizard Y | Fire | Flying | 1 |
7 | Squirtle | Water | null | 1 |
To drop a column, we just pass that column to the drop
method of our DataFrame:
pokemon = pokemon.drop('Generation')
display(pokemon.limit(5))
# | Name | Type1 | Type2 |
---|---|---|---|
1 | Bulbasaur | Grass | Poison |
2 | Ivysaur | Grass | Poison |
3 | Venusaur | Grass | Poison |
3 | VenusaurMega Venusaur | Grass | Poison |
4 | Charmander | Fire | null |
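The drop method also accepts several column names at once, so we could remove more than one column in a single call, for example:
display(pokemon.drop('Type1', 'Type2').limit(5))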