Multiclass Text Classification with PySpark
In this post we’ll explore the use of PySpark for multiclass classification of text documents. The data I’ll be using here contains Stack Overflow questions and associated tags. It is available from https://storage.googleapis.com/tensorflow-workshop-examples/stack-overflow-data.csv.
The idea is to use PySpark to create a pipeline that analyses this data and builds a classifier for the questions. In the future, questions could be auto-tagged by such a classifier, or tags could be recommended to users prior to posting.
Quick disclaimer: at the time of writing I am a Microsoft employee, so naturally this was all carried out using Databricks on Azure, but it applies to any Spark cluster.
We’ll start by loading in our data. We load the data into a Spark DataFrame directly from the CSV file.
filepath = 'dbfs:/FileStore/tables/stack_overflow_data-0b671.csv'
data_df = spark.read.option("header", "true").csv(filepath)
Let’s have a look at our data; we can see that there are posts and tags.
The data has many nuances, including HTML tags and a lot of characters you might find when coding, such as curly braces, semicolons and square brackets.
data_df.show(5)
data_df.head()['post']
We’ll want to get an idea of the distribution of our tags, so let’s do a count on each tag and see how many instances of each tag we have.
data_df.groupby('tags').count().show(30)
Luckily our data is very balanced and we have a good number of samples in each class, so we won’t need to do any resampling to balance out our classes. The notable exception here is the null tag values; we’ll filter out all the observations that don’t have a tag.
data_df = data_df.filter(data_df.tags.isNotNull())
And now we can double check that we have 20 classes, all with 2000 observations each:
data_df.groupby('tags').count().show(30)
Great. Just as we normally would, we’ll split our data into a training DataFrame and a hold-out testing DataFrame to determine how well our model is performing. We’ll use 75% of our data as the training set.
(trainDF, testDF) = data_df.randomSplit((0.75, 0.25), seed=100)
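As a quick sanity check (my addition here, not from the original post), we can count the rows in each split; with ~40,000 posts and a 75/25 split we’d expect roughly 30,000 training and 10,000 test rows.
# Row counts for the training and hold-out sets
print(trainDF.count(), testDF.count())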
We’re now going to define a pipeline to clean up our data. However, the first thing we want to do is remove those HTML tags we see in the posts. As there is no built-in way to do this in PySpark, we’ll define our own custom Transformer, which we’ll call BsTextExtractor, as it’ll use BeautifulSoup to extract just the text from the HTML.
We define a new class that is a child of the built-in Transformer class and wraps a user-defined function (udf) that uses BeautifulSoup to extract the text from the post. The output will be a StringType(). This custom Transformer can then be embedded as a step in our Pipeline, creating a new column with just the extracted text.
from bs4 import BeautifulSoup
from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
class BsTextExtractor(Transformer, HasInputCol, HasOutputCol):

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None):
        super(BsTextExtractor, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _transform(self, dataset):
        # Parse the HTML in the input column and keep only the text content
        def f(s):
            cleaned_post = BeautifulSoup(s).text
            return cleaned_post

        t = StringType()
        out_col = self.getOutputCol()
        in_col = dataset[self.getInputCol()]
        return dataset.withColumn(out_col, udf(f, t)(in_col))
Let’s quickly test our BsTextExtractor class to make sure it does what we’d like it to, i.e. remove HTML tags:
bs_text_extractor = BsTextExtractor(inputCol="post", outputCol="cleaned_post")
text_extracted = bs_text_extractor.transform(data_df)
text_extracted.head().cleaned_post
Looks like it works as expected. Now let’s set up our ML pipeline. We set up a number of Transformers and finish up with an Estimator.
Our pipeline looks like this:
StringIndexer - Converts our tags from string tags to integer labels
BsTextExtractor - Our custom Transformer to strip out the HTML tags and keep just the post text
RegexTokenizer - Tokenizes our posts into words, keeping only alphanumeric characters and a few other select characters (e.g. we want to keep # and + so that any posts that mention c# or c++ keep these as whole tokens); see the short sketch after this list
StopWordsRemover - Removes common English stop words that occur frequently and would not necessarily provide any additional information when attempting to separate classes
CountVectorizer - Our TF-IDF (Term Frequency-Inverse Document Frequency) is split into two parts here, a TF transformer (CountVectorizer) and an IDF transformer (IDF). The CountVectorizer counts word occurrences, keeping only words that appear in at least 4 other posts, because words that appear in fewer posts are unlikely to be useful (e.g. variable names).
IDF - Inverse Document Frequency. Combined with the CountVectorizer, this provides a statistic that indicates how important a word is relative to other documents. If a word appears regularly in a document but also appears regularly across other documents, it likely has little predictive power for classification; if a term appears frequently in documents of one class but rarely elsewhere, it is a good indicator of that class.
- E.g. the words “set”, “query” or “dynamic” may appear regularly in one class but also across classes, so they won’t necessarily provide additional information when trying to classify documents
- Conversely, the words “npm” or “maven” might appear disproportionately frequently in questions about JavaScript or Java, respectively
LogisticRegression - Our estimator. A multinomial logistic regression estimator is used as the model to classify documents into one of our given classes.
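To make the RegexTokenizer behaviour concrete, here is a minimal sketch; the one-row DataFrame and example sentence are made up purely for illustration.
from pyspark.ml.feature import RegexTokenizer

# Hypothetical single-row DataFrame just to show the tokenizer pattern in action
demo_df = spark.createDataFrame([("Should I use C# or C++ for this?",)], ["cleaned_post"])
demo_tokenizer = RegexTokenizer(inputCol="cleaned_post", outputCol="words",
                                pattern="[^0-9a-z#+_]+")
demo_tokenizer.transform(demo_df).select("words").show(truncate=False)
# Expect tokens like [should, i, use, c#, or, c++, for, this]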
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression, OneVsRest
from pyspark.ml.feature import IDF, StringIndexer, StopWordsRemover, CountVectorizer, RegexTokenizer
import nltk
from nltk.corpus import stopwords

nltk.download('stopwords')
stop_words = list(set(stopwords.words('english')))

# Index the string tags into numeric labels
label_stringIdx = StringIndexer(inputCol="tags", outputCol="label")
# Strip the HTML from the raw post text
bs_text_extractor = BsTextExtractor(inputCol="post", outputCol="untagged_post")
# Tokenize, keeping alphanumerics plus #, + and _ so tokens like c# and c++ survive
regex_tokenizer = RegexTokenizer(inputCol=bs_text_extractor.getOutputCol(), outputCol="words",
                                 pattern="[^0-9a-z#+_]+")
# Remove common English stop words
stopword_remover = StopWordsRemover(inputCol=regex_tokenizer.getOutputCol(),
                                    outputCol="filtered_words").setStopWords(stop_words)
# Term frequencies, keeping only words that appear in at least 5 posts
count_vectorizer = CountVectorizer(inputCol=stopword_remover.getOutputCol(), outputCol="countFeatures", minDF=5)
# Inverse document frequency weighting
idf = IDF(inputCol=count_vectorizer.getOutputCol(), outputCol="features")
# Multinomial logistic regression estimator
lr = LogisticRegression(featuresCol=idf.getOutputCol(), labelCol="label")

pipeline = Pipeline(stages=[
    label_stringIdx,
    bs_text_extractor,
    regex_tokenizer,
    stopword_remover,
    count_vectorizer,
    idf,
    lr])
Now that we’ve defined our pipeline, let’s fit it to our training DataFrame trainDF:
model = pipeline.fit(trainDF)
We’ll evaluate how well our fitted pipeline performs by transforming our test DataFrame testDF to get predicted classes. This transformation adds the columns rawPrediction (the raw output of the model, with a value for each class), probability (the predicted probability of each class), and prediction (an integer corresponding to an individual class).
predictions = model.transform(testDF)
We can see this by taking a look at the schema for this DataFrame after the prediction columns have been appended.
predictions.printSchema()
To evaluate our multiclass classification we’ll use a MulticlassClassificationEvaluator, which will evaluate the predictions using the F1 metric. For a single class, F1 is the harmonic mean of precision and recall; the evaluator averages these per-class scores, weighted by class frequency, with a perfect score at 1.0.
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="f1")
evaluator.evaluate(predictions)
Our F1 score here is ~0.66, not bad but there’s room for improvement.
Let’s do some hyperparameter tuning to see if we can nudge that score up a bit. For the most part, our pipeline has stuck to just the default parameters. Here we’ll alter some of these parameters to see if we can improve on our F1 score from before.
We’ll set up a hyperparameter grid and do an exhaustive grid search over these hyperparameters. We start by setting up our hyperparameter grid using the ParamGridBuilder, then we determine their performance using the CrossValidator, which does k-fold cross validation (k=3 in this case).
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
paramGrid = (ParamGridBuilder()
             .addGrid(lr.regParam, [0.01, 0.1, 0.5])
             .addGrid(lr.maxIter, [10, 20, 50])
             .addGrid(lr.elasticNetParam, [0.0, 0.8])
             .build())

crossval = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=paramGrid,
                          evaluator=evaluator,
                          numFolds=3)
With our cross validator set up, we can then fit it to our training data.
Note that this takes a while, as it has to train 54 models: 3 values for regParam x 3 for maxIter x 2 for elasticNetParam gives 18 parameter combinations, and each of these is trained on 3 folds of data (18 x 3 = 54).
model = crossval.fit(trainDF)
We can then make predictions with the best performing model from our cross-validation.
This model significantly outperforms our previous, untuned model, bringing our F1 score up to ~0.76.
predictions = model.transform(testDF)
evaluator.evaluate(predictions)
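If you’re curious which hyperparameter combination won, the fitted CrossValidatorModel keeps the average cross-validated metric for each combination. A small sketch of how you might inspect them (my addition, not part of the original post):
# Pair each hyperparameter combination with its mean cross-validated F1 score;
# here model is the CrossValidatorModel returned by crossval.fit(trainDF)
for param_map, avg_f1 in zip(model.getEstimatorParamMaps(), model.avgMetrics):
    print({p.name: v for p, v in param_map.items()}, round(avg_f1, 3))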
This analysis was done with a relatively simple model, logistic regression. One-vs-All linear Support Vector Machines often perform well on this kind of task; I’ll leave it to the reader to see whether they can improve further on this F1 score.
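As a starting point, here is a rough sketch of how the final stage could be swapped for a One-vs-All linear SVM. The hyperparameter values are placeholders and this isn’t part of the original analysis; it simply reuses the feature stages defined earlier.
from pyspark.ml.classification import LinearSVC, OneVsRest

# LinearSVC is a binary classifier, so wrap it in OneVsRest to get one
# classifier per tag; the feature stages are reused unchanged
svc = LinearSVC(featuresCol=idf.getOutputCol(), labelCol="label", maxIter=50, regParam=0.1)
ovr = OneVsRest(classifier=svc, featuresCol=idf.getOutputCol(), labelCol="label")

svm_pipeline = Pipeline(stages=[
    label_stringIdx,
    bs_text_extractor,
    regex_tokenizer,
    stopword_remover,
    count_vectorizer,
    idf,
    ovr])

svm_model = svm_pipeline.fit(trainDF)
evaluator.evaluate(svm_model.transform(testDF))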