Multiclass Text Classification with PySpark

In this post we’ll explore the use of PySpark for multiclass classification of text documents. The data I’ll be using here contains Stack Overflow questions and associated tags. It is available from https://storage.googleapis.com/tensorflow-workshop-examples/stack-overflow-data.csv.

The idea is to use PySpark to build a pipeline that analyses this data and trains a classifier to predict the tag of a question. In future, questions could be auto-tagged by such a classifier, or tags could be recommended to users prior to posting.

Quick disclaimer: at the time of writing I am a Microsoft employee, so naturally this was all carried out using Databricks on Azure, but it applies to any Spark cluster.

We’ll start by loading in our data. We load the data into a Spark DataFrame directly from the CSV file.

In [2]:
filepath = 'dbfs:/FileStore/tables/stack_overflow_data-0b671.csv'

data_df = spark.read.option("header", "true").csv(filepath)

Let’s have a look at our data. We can see that there are posts and tags.

The data has many nuances, including HTML tags and a lot of characters you might find when coding, such as curly braces, semicolons and square brackets.

In [4]:
data_df.show(5)
post tags
what is causing t… c#
have dynamic html… asp.net
how to convert a … objective-c
.net framework 4 … .net
trying to calcula… python

only showing top 5 rows

In [5]:
data_df.head()['post']
Out[14]: 'what is causing this behavior in our c# datetime type <pre><code>[test] public void sadness() { var datetime = datetime.utcnow; assert.that(datetime is.equalto(datetime.parse(datetime.tostring()))); } </code></pre> failed : <pre><code> expected: 2011-10-31 06:12:44.000 but was: 2011-10-31 06:12:44.350 </code></pre> i wish to know what is happening behind the scenes in tostring() etc to cause this behavior. edit after seeing jon s answer : <pre><code>[test] public void newsadness() { var datetime = datetime.utcnow; assert.that(datetime is.equalto(datetime.parse(datetime.tostring( o )))); } </code></pre> result : <pre><code>expected: 2011-10-31 12:03:04.161 but was: 2011-10-31 06:33:04.161 </code></pre> same result with capital and small o . i m reading up the docs but still unclear.'

We’ll want to get an idea of the distribution of our tags, so let’s do a count on each tag and see how many instances of each tag we have.

In [7]:
data_df.groupby('tags').count().show(30)
tags count
iphone 2000
android 2000
c# 2000
null 20798
html 2000
asp.net 2000
mysql 2000
jquery 2000
javascript 2000
css 2000
sql 2000
c++ 2000
c 2000
objective-c 2000
java 2000
php 2000
.net 2000
ios 2000
python 2000
angularjs 2000
ruby-on-rails 2000

Luckily our data is very balanced and we have a good number of samples in each class, so we won’t need to do any resampling to balance out our classes. The notable exception here is the null tag values. We’ll filter out all the observations that don’t have a tag.

In [9]:
data_df = data_df.filter(data_df.tags.isNotNull())

And now we can double check that we have 20 classes, all with 2000 observations each:

In [11]:
data_df.groupby('tags').count().show(30)
tags count
iphone 2000
android 2000
c# 2000
html 2000
asp.net 2000
mysql 2000
jquery 2000
javascript 2000
css 2000
sql 2000
c++ 2000
c 2000
objective-c 2000
java 2000
php 2000
.net 2000
ios 2000
python 2000
angularjs 2000
ruby-on-rails 2000

Great. Just as we normally would, we will split our data into a training DataFrame and a hold-out testing DataFrame to determine how well our model is performing. We’ll use 75% of our data as a training set.

In [13]:
(trainDF, testDF) = data_df.randomSplit((0.75, 0.25), seed=100)

We’re now going to define a pipeline to clean up our data. However, the first thing we’re going to want to do is remove those HTML tags we see in the posts. As there is no built-in way to do this in PySpark, we’re going to define our own custom Transformer. We’ll call this transformer BsTextExtractor, as it’ll use BeautifulSoup to extract just the text from the HTML.

We define a new class that will be a child class of the built-in Transformer class that has its own user-defined function (udf) that uses BeautifulSoup to extract the text from the post. This output will be a StringType(). This custom Transformer can then be embedded as a step in our Pipeline, creating a new column with just the extracted text.

In [15]:
from bs4 import BeautifulSoup

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

class BsTextExtractor(Transformer, HasInputCol, HasOutputCol):

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None):
        super(BsTextExtractor, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _transform(self, dataset):

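        def f(s):
            # Name the parser explicitly so the result does not depend on
            # which optional parsers (e.g. lxml) are installed on the workers.
            cleaned_post = BeautifulSoup(s, "html.parser").text
            return cleaned_post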

        t = StringType()
        out_col = self.getOutputCol()
        in_col = dataset[self.getInputCol()]
        return dataset.withColumn(out_col, udf(f, t)(in_col))

Let’s quickly test our BsTextExtractor class to make sure it does what we’d like it to i.e. remove HTML tags:

In [17]:
bs_text_extractor = BsTextExtractor(inputCol="post", outputCol="cleaned_post")
text_extracted = bs_text_extractor.transform(data_df)
text_extracted.head().cleaned_post
Out[15]: 'what is causing this behavior in our c# datetime type [test] public void sadness() { var datetime = datetime.utcnow; assert.that(datetime is.equalto(datetime.parse(datetime.tostring()))); } failed : expected: 2011-10-31 06:12:44.000 but was: 2011-10-31 06:12:44.350 i wish to know what is happening behind the scenes in tostring() etc to cause this behavior. edit after seeing jon s answer : [test] public void newsadness() { var datetime = datetime.utcnow; assert.that(datetime is.equalto(datetime.parse(datetime.tostring( o )))); } result : expected: 2011-10-31 12:03:04.161 but was: 2011-10-31 06:33:04.161 same result with capital and small o . i m reading up the docs but still unclear.'

Looks like it works as expected. Now let’s set up our ML pipeline. We set up a number of Transformers and finish up with an Estimator.

Our pipeline looks like this:

  • StringIndexer
    • Convert our tags from string tags to integer labels
  • BsTextExtractor
    • Our custom Transformer to strip HTML tags and keep just the text
  • RegexTokenizer
    • Tokenize our posts into words, keeping only alphanumerical characters and some other select characters (e.g. we want to keep # or + so that any posts that mention c# or c++ maintain these as whole tokens)
  • StopWordsRemover
    • Removes common stop words that are frequently occurring in the English language and would not necessarily provide any additional information when attempting to separate classes
  • CountVectorizer
    • Our TF-IDF (Term Frequency-Inverse Document Frequency) is split into 2 parts here: a TF step (CountVectorizer) and an IDF step (IDF). The CountVectorizer counts how often each vocabulary term occurs in a post, where the vocabulary only includes terms that appear in at least 5 posts of the data it is fitted on (minDF=5). Terms that appear in fewer posts than this are unlikely to generalise (e.g. variable names).
  • IDF
    • Inverse Document Frequency. Combined with the CountVectorizer, this provides a statistic that indicates how important a word is to a document relative to the rest of the corpus. If a word appears regularly in a document and also appears regularly in other documents, it is likely to have little predictive power for classification. However, if a term appears regularly in a document but rarely in other documents, it is likely to be more informative.
    • E.g. if the words “set”, “query” or “dynamic” appear regularly in one class, but also appear regularly across classes, they won’t necessarily provide additional information when trying to classify documents
    • Conversely, the words “npm” or “maven” might appear disproportionately frequently in questions about JavaScript or Java, respectively
  • LogisticRegression
    • Our estimator. A multinomial logistic regression estimator is used as the model to classify documents into one of our given classes.
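
To make the tokenizing and stop word steps above more concrete, here is a minimal sketch in plain Python (no Spark needed). It assumes the RegexTokenizer defaults of lowercasing the text and treating the pattern as a separator, and it uses a tiny hypothetical stop word list rather than the full NLTK one, so treat it as an approximation of the pipeline stages rather than their exact behaviour.

```python
import re

# Same pattern as the RegexTokenizer stage: any run of characters that is
# NOT a digit, lowercase letter, '#', '+' or '_' is treated as a separator.
PATTERN = re.compile(r"[^0-9a-z#+_]+")

# Hypothetical, deliberately tiny stop word list for illustration only.
STOP_WORDS = {"how", "do", "i", "a", "in", "or"}


def tokenize(text):
    """Lowercase the text, split on the separator pattern, drop empty strings."""
    return [tok for tok in PATTERN.split(text.lower()) if tok]


def remove_stop_words(tokens):
    """Keep only the tokens that are not in the stop word list."""
    return [tok for tok in tokens if tok not in STOP_WORDS]


tokens = tokenize("How do I sort a std::vector in C++ or C#?")
filtered = remove_stop_words(tokens)

print(tokens)
print(filtered)
```

Because # and + are not treated as separators, c++ and c# survive as whole tokens, while std::vector is split into std and vector.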
In [19]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression, OneVsRest
from pyspark.ml.feature import IDF, StringIndexer, StopWordsRemover, CountVectorizer, RegexTokenizer
import nltk
from nltk.corpus import stopwords

nltk.download('stopwords')
stop_words = list(set(stopwords.words('english')))

label_stringIdx = StringIndexer(inputCol="tags", outputCol="label")
bs_text_extractor = BsTextExtractor(inputCol="post", outputCol="untagged_post")
regex_tokenizer = RegexTokenizer(inputCol=bs_text_extractor.getOutputCol(), outputCol="words", pattern="[^0-9a-z#+_]+")
stopword_remover = StopWordsRemover(inputCol=regex_tokenizer.getOutputCol(), outputCol="filtered_words").setStopWords(
    stop_words)
count_vectorizer = CountVectorizer(inputCol=stopword_remover.getOutputCol(), outputCol="countFeatures", minDF=5)
idf = IDF(inputCol=count_vectorizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(featuresCol=idf.getOutputCol(), labelCol="label")

pipeline = Pipeline(stages=[
    label_stringIdx,
    bs_text_extractor,
    regex_tokenizer,
    stopword_remover,
    count_vectorizer,
    idf,
    lr])
[nltk_data] Downloading package stopwords to /root/nltk_data…
[nltk_data] Package stopwords is already up-to-date!
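
To illustrate what the CountVectorizer and IDF stages in this pipeline compute, here is a small sketch in plain Python on a hypothetical toy corpus. It uses the smoothed formula given in the Spark ML documentation, idf = log((m + 1) / (df + 1)), where m is the number of documents and df is the number of documents containing the term. The corpus and function names are made up for illustration, and the minDF filtering is left out to keep it short.

```python
import math
from collections import Counter

# Hypothetical toy corpus of already tokenized posts.
docs = [
    ["npm", "install", "error", "query"],
    ["maven", "build", "error", "query"],
    ["npm", "package", "query"],
    ["sql", "query", "error"],
]


def idf(term, corpus):
    """Smoothed inverse document frequency: log((m + 1) / (df + 1))."""
    m = len(corpus)
    df = sum(1 for doc in corpus if term in doc)
    return math.log((m + 1) / (df + 1))


def tfidf(doc, corpus):
    """Raw term counts (the CountVectorizer part) scaled by IDF."""
    counts = Counter(doc)
    return {term: tf * idf(term, corpus) for term, tf in counts.items()}


weights = tfidf(docs[0], docs)
for term, weight in sorted(weights.items(), key=lambda kv: -kv[1]):
    print(f"{term:10s} {weight:.3f}")
```

In this toy corpus "query" appears in every post, so its weight is zero, while "install" appears in only one post and gets the largest weight.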

Now that we’ve defined our pipeline, let’s fit it to our training DataFrame trainDF:

In [21]:
model = pipeline.fit(trainDF)

We’ll evaluate how well our fitted pipeline performs by then transforming our test DataFrame testDF to get predicted classes.

This transformation adds the columns rawPrediction (raw output of the model with values for each class), probability (predicted probability of each class), and prediction (a numeric label index corresponding to an individual class).

In [23]:
predictions = model.transform(testDF)

We can see this by taking a look at the schema for this DataFrame after the prediction columns have been appended.

In [25]:
predictions.printSchema()
root
 |-- post: string (nullable = true)
 |-- tags: string (nullable = true)
 |-- clean_post: string (nullable = true)
 |-- label: double (nullable = false)
 |-- untagged_post: string (nullable = true)
 |-- words: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- filtered_words: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- countFeatures: vector (nullable = true)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)

To evaluate our multiclass classifier we’ll use a MulticlassClassificationEvaluator with the f1 metric. For each class, F1 is the harmonic mean of precision and recall; the evaluator then averages these per-class scores, weighted by the number of true instances of each class. A perfect score is 1.0.

In [27]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="f1")

evaluator.evaluate(predictions)
Out[116]: 0.6567801627875625

Our F1 score here is ~0.66: not bad, but there’s room for improvement.
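
For intuition about what the f1 metric measures, here is a minimal sketch of a support-weighted F1 in plain Python, which is my understanding of how the evaluator combines the per-class scores. The labels and predictions are hypothetical and the function name is my own, so this is an illustration of the metric rather than Spark's implementation.

```python
from collections import Counter


def weighted_f1(y_true, y_pred):
    """Per-class F1, averaged with weights proportional to class support."""
    n = len(y_true)
    support = Counter(y_true)  # number of true instances per class
    total = 0.0
    for cls, count in support.items():
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
        predicted = sum(1 for p in y_pred if p == cls)
        precision = tp / predicted if predicted else 0.0
        recall = tp / count
        if precision + recall:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0.0
        total += f1 * count / n
    return total


# Hypothetical labels and predictions for three classes.
y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 2]

score = weighted_f1(y_true, y_pred)
print(round(score, 4))
```

Here classes 0 and 1 each score an F1 of 0.8 and class 2 scores 1.0, so the weighted average is (3 x 0.8 + 2 x 0.8 + 1 x 1.0) / 6.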

Let’s do some hyperparameter tuning to see if we can nudge that score up a bit. For the most part, our pipeline has stuck to just the default parameters. Here we’ll alter some of these parameters to see if we can improve on our F1 score from before.

We’ll set up a hyperparameter grid and do an exhaustive grid search on these hyperparameters. We start by setting up our hyperparameter grid using the ParamGridBuilder, then we determine their performance using the CrossValidator, which does k-fold cross validation (k=3 in this case).

In [29]:
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

paramGrid = (ParamGridBuilder()
  .addGrid(lr.regParam, [0.01, 0.1, 0.5])
  .addGrid(lr.maxIter, [10, 20, 50])
  .addGrid(lr.elasticNetParam, [0.0, 0.8])
  .build())

crossval = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=paramGrid,
                          evaluator=evaluator,
                          numFolds=3)

With our cross validator set up, we can then fit it to our training data.

Note that this takes a while, as it has to train 54 models: 18 hyperparameter combinations (3 values of regParam x 3 of maxIter x 2 of elasticNetParam), each fitted on 3 folds of data.
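
The arithmetic behind that number can be checked with a few lines of plain Python. This is only a sketch that enumerates the same grid values with itertools; the dictionary is my own stand-in for the ParamGridBuilder output.

```python
from itertools import product

# Same candidate values as the grid passed to ParamGridBuilder.
grid = {
    "regParam": [0.01, 0.1, 0.5],
    "maxIter": [10, 20, 50],
    "elasticNetParam": [0.0, 0.8],
}
num_folds = 3

# Every combination of one value per hyperparameter.
combinations = list(product(*grid.values()))

# Each combination is fitted once per fold during cross validation.
fits_during_cv = len(combinations) * num_folds

print(len(combinations), fits_during_cv)
```

As far as I'm aware, the CrossValidator also refits the best combination once more on the full training set after the search, so the total is one fit higher than the count above.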

In [31]:
model = crossval.fit(trainDF)

We can then make our predictions using the best performing model from our cross validation.

The best performing model significantly outperforms our previous model with no hyperparameter tuning and we’ve brought our F1 score up to ~0.76.

In [33]:
predictions = model.transform(testDF)

evaluator.evaluate(predictions)
Out[120]: 0.7635483181399657

This analysis was done with a relatively simple model, a logistic regression. One-vs-All linear Support Vector Machines often perform well on this task; I’ll leave it to the reader to see if they can improve further on this F1 score.