Unit Testing with Databricks
Part 1 – PySpark Unit Testing using Databricks Connect
On my most recent project, I’ve been working with Databricks for the first time.
At first I found using Databricks to write production code somewhat jarring – using the notebooks in the web portal isn't the most developer-friendly experience, and it felt akin to using Jupyter notebooks for writing production code.
Then came the game-changer: Databricks Connect, a way of remotely executing code on your Databricks cluster. Suddenly I was back at home, developing in the comfort of my IDE and running PySpark commands on a cluster in the cloud. I now really enjoy using Databricks and would happily recommend it to anyone who needs to do distributed data engineering on big data.
One of the things you'll certainly need to do if you're writing production code in Databricks is write unit tests. This blog post, and the next part, aim to help you do this with a super simple example of unit testing functionality in PySpark.
To follow along with this blog post you'll need:
- Python 3.7
- A Databricks Workspace in Microsoft Azure with a cluster running Databricks Runtime 7.3 LTS
Quick disclaimer: at the time of writing, I am a Microsoft employee.
Setting up your local environment
Start by cloning the repository that goes along with this blog post here.
Now create a new virtual environment and run:
pip install -r requirements.txt
This will install the testing requirements, Databricks Connect, and the Python package defined in our git repository.
Then you'll need to set up Databricks Connect. You can do this by running:
databricks-connect configure
and then following the instructions given in the Databricks Connect documentation.
You can test that Databricks Connect is working correctly by running:
databricks-connect test
Function to test
We’re going to test a function that takes in some data as a Spark DataFrame and returns some transformed data as a Spark DataFrame.
Let’s say we start with some data that looks like this, where we have 3 pumps that are pumping liquid:
pump_id | start_time | end_time | litres_pumped |
---|---|---|---|
1 | 2021-02-01 01:05:32 | 2021-02-01 01:09:13 | 24 |
2 | 2021-02-01 01:09:14 | 2021-02-01 01:14:17 | 41 |
1 | 2021-02-01 01:14:18 | 2021-02-01 01:15:58 | 11 |
2 | 2021-02-01 01:15:59 | 2021-02-01 01:18:26 | 13 |
1 | 2021-02-01 01:18:27 | 2021-02-01 01:26:26 | 45 |
3 | 2021-02-01 01:26:27 | 2021-02-01 01:38:57 | 15 |
And we want to know the average litres pumped per second for each of these pumps. (British English spelling of “litres” 😉).
So we want an output that reads something like this:
pump_id | total_duration_seconds | total_litres_pumped | avg_litres_per_second |
---|---|---|---|
1 | 800 | 80 | 0.1 |
2 | 450 | 54 | 0.12 |
3 | 750 | 15 | 0.02 |
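To make the arithmetic concrete, here's pump 1's row worked out by hand (a quick sanity check rather than code from the repository):
# Pump 1 appears in three rows of the input data:
#   01:05:32 -> 01:09:13  = 221 seconds, 24 litres
#   01:14:18 -> 01:15:58  = 100 seconds, 11 litres
#   01:18:27 -> 01:26:26  = 479 seconds, 45 litres
total_duration_seconds = 221 + 100 + 479                              # 800
total_litres_pumped = 24 + 11 + 45                                    # 80
avg_litres_per_second = total_litres_pumped / total_duration_seconds  # 0.1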
We can create a function as follows to do this:
from pyspark.sql.functions import col, sum as col_sum


def get_litres_per_second(pump_data_df):
    """Aggregate pump data to get the total duration, total litres pumped,
    and average litres pumped per second for each pump."""
    aggregated_df = pump_data_df.withColumn(
        "duration_seconds",
        # Casting a timestamp to long gives the Unix epoch in seconds, so the
        # difference between the two casts is the duration of each pumping interval.
        (
            col("end_time").cast('timestamp').cast('long')
            - col("start_time").cast('timestamp').cast('long')
        )
    ).groupBy("pump_id").agg(
        col_sum("duration_seconds").alias("total_duration_seconds"),
        col_sum('litres_pumped').alias("total_litres_pumped")
    ).withColumn(
        "avg_litres_per_second",
        col("total_litres_pumped") / col("total_duration_seconds")
    )
    return aggregated_df
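Here's a minimal sketch of how you might try the function out from a Databricks Connect session (assuming Databricks Connect is configured; the sample data below is made up and isn't part of the repository):
from pyspark.sql import SparkSession

# With Databricks Connect configured, getOrCreate() connects to your remote cluster.
spark = SparkSession.builder.getOrCreate()

sample_df = spark.createDataFrame(
    [(1, '2021-02-01 01:05:32', '2021-02-01 01:09:13', 24)],
    ['pump_id', 'start_time', 'end_time', 'litres_pumped']
)
get_litres_per_second(sample_df).show()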
This function can be found in our repository at databricks_pkg/databricks_pkg/pump_utils.py.
Unit testing our function
The unit test for our function can be found in the repository at databricks_pkg/test/test_pump_utils.py.
To run the unit tests, run:
pytest -v ./databricks_pkg/test/test_pump_utils.py
If everything is working correctly, the unit test should pass.
First I’ll show the unit testing file, then go through it line-by-line.
import unittest

import pandas as pd
from pyspark.sql import Row, SparkSession
from pyspark.sql.dataframe import DataFrame

from databricks_pkg.pump_utils import get_litres_per_second


class TestGetLitresPerSecond(unittest.TestCase):
    def setUp(self):
        self.spark = SparkSession.builder.getOrCreate()

    def test_get_litres_per_second(self):
        test_data = [
            # pump_id, start_time, end_time, litres_pumped
            (1, '2021-02-01 01:05:32', '2021-02-01 01:09:13', 24),
            (2, '2021-02-01 01:09:14', '2021-02-01 01:14:17', 41),
            (1, '2021-02-01 01:14:18', '2021-02-01 01:15:58', 11),
            (2, '2021-02-01 01:15:59', '2021-02-01 01:18:26', 13),
            (1, '2021-02-01 01:18:27', '2021-02-01 01:26:26', 45),
            (3, '2021-02-01 01:26:27', '2021-02-01 01:38:57', 15)
        ]
        test_data = [
            {
                'pump_id': row[0],
                'start_time': row[1],
                'end_time': row[2],
                'litres_pumped': row[3]
            } for row in test_data
        ]
        test_df = self.spark.createDataFrame(map(lambda x: Row(**x), test_data))

        output_df = get_litres_per_second(test_df)

        self.assertIsInstance(output_df, DataFrame)

        output_df_as_pd = output_df.sort('pump_id').toPandas()

        expected_output_df = pd.DataFrame([
            {
                'pump_id': 1,
                'total_duration_seconds': 800,
                'total_litres_pumped': 80,
                'avg_litres_per_second': 0.1
            },
            {
                'pump_id': 2,
                'total_duration_seconds': 450,
                'total_litres_pumped': 54,
                'avg_litres_per_second': 0.12
            },
            {
                'pump_id': 3,
                'total_duration_seconds': 750,
                'total_litres_pumped': 15,
                'avg_litres_per_second': 0.02
            },
        ])

        pd.testing.assert_frame_equal(expected_output_df, output_df_as_pd)
We’ll now go through this file line-by-line:
Imports
The unit test file starts with some imports: first the builtins, then external packages, and finally internal packages, which include the function we'll be testing:
import unittest

import pandas as pd
from pyspark.sql import Row, SparkSession
from pyspark.sql.dataframe import DataFrame

from databricks_pkg.pump_utils import get_litres_per_second
Define testing class
The testing class is a child class of unittest.TestCase, which comes with a range of methods to help with unit testing.
Although in this case we're only running one test, we might run multiple tests, and they can all share the same Spark session: SparkSession.builder.getOrCreate() returns the existing session if one has already been created.
class TestGetLitresPerSecond(unittest.TestCase):
    def setUp(self):
        self.spark = SparkSession.builder.getOrCreate()
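If you'd rather create the session once per test class instead of before every test, a setUpClass variation would also work. This is just a sketch of an alternative, not what the accompanying repository does:
class TestGetLitresPerSecond(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Created once for all tests in this class; getOrCreate() would return the
        # same underlying session anyway, so this is mostly a readability choice.
        cls.spark = SparkSession.builder.getOrCreate()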
Define Test Data
The next thing for us to do is define some test data; we'll use the test data shown at the top of this post. It's first defined as a list of tuples, and then a list comprehension converts it to a list of dicts.
I've defined it this way for readability; you can define your test data however you feel comfortable.
def test_get_litres_per_second(self):
    test_data = [
        # pump_id, start_time, end_time, litres_pumped
        (1, '2021-02-01 01:05:32', '2021-02-01 01:09:13', 24),
        (2, '2021-02-01 01:09:14', '2021-02-01 01:14:17', 41),
        (1, '2021-02-01 01:14:18', '2021-02-01 01:15:58', 11),
        (2, '2021-02-01 01:15:59', '2021-02-01 01:18:26', 13),
        (1, '2021-02-01 01:18:27', '2021-02-01 01:26:26', 45),
        (3, '2021-02-01 01:26:27', '2021-02-01 01:38:57', 15)
    ]
    test_data = [
        {
            'pump_id': row[0],
            'start_time': row[1],
            'end_time': row[2],
            'litres_pumped': row[3]
        } for row in test_data
    ]
Convert to Spark DataFrame
This is the first place we use our Spark session, so the command runs on our Databricks cluster. It converts our list of dicts into a Spark DataFrame:
test_df = self.spark.createDataFrame(map(lambda x: Row(**x), test_data))
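As an aside, you could skip the list-of-dicts step and pass the original list of tuples plus the column names straight to createDataFrame. This is a small variation on the test above, where test_data_tuples is a stand-in for the tuple list before the dict conversion:
test_df = self.spark.createDataFrame(
    test_data_tuples,
    ['pump_id', 'start_time', 'end_time', 'litres_pumped']
)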
Run our function
We now run the function we’re testing with our test DataFrame.
output_df = get_litres_per_second(test_df)
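If you want to eyeball the result while developing (not part of the test itself), you can preview it from your IDE and the work still runs on the cluster via Databricks Connect:
output_df.printSchema()  # check the column names and types
output_df.show()         # preview the aggregated rows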
Test the output of the function
The first thing to check is whether the output of our function is the data type we expect. We can do this using the unittest.TestCase method assertIsInstance:
self.assertIsInstance(output_df, DataFrame)
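If you wanted a slightly stricter check, you could also assert on the output columns. This is an optional extra rather than part of the repository's test:
self.assertEqual(
    set(output_df.columns),
    {'pump_id', 'total_duration_seconds', 'total_litres_pumped', 'avg_litres_per_second'}
)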
We'll then convert our Spark DataFrame into a pandas DataFrame. We also need to sort the DataFrame: there's no guarantee that the processed output is in any particular order, particularly as rows are partitioned and processed on different nodes.
output_df_as_pd = output_df.sort('pump_id').toPandas()
We can then check that this output DataFrame is equal to our expected output:
expected_output_df = pd.DataFrame([
    {
        'pump_id': 1,
        'total_duration_seconds': 800,
        'total_litres_pumped': 80,
        'avg_litres_per_second': 0.1
    },
    {
        'pump_id': 2,
        'total_duration_seconds': 450,
        'total_litres_pumped': 54,
        'avg_litres_per_second': 0.12
    },
    {
        'pump_id': 3,
        'total_duration_seconds': 750,
        'total_litres_pumped': 15,
        'avg_litres_per_second': 0.02
    },
])
pd.testing.assert_frame_equal(expected_output_df, output_df_as_pd)
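One thing to watch with assert_frame_equal is floating point precision. The expected averages here happen to compare cleanly, but if yours don't, you can make the float comparison explicitly approximate (rtol and atol are available in pandas 1.1+):
pd.testing.assert_frame_equal(
    expected_output_df,
    output_df_as_pd,
    check_exact=False,  # compare floats approximately rather than bit-for-bit
    rtol=1e-5           # relative tolerance for the comparison
)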
Hopefully this blog post has helped you understand the basics of PySpark unit testing using Databricks and Databricks Connect.
In the next part of this blog post series, we’ll be diving into how we can integrate this unit testing into our CI pipeline.
For part 2, see here.