Predicting Sparkify Customer Churn Using Spark

Dilorom
Nov 23, 2020
Photo by Joanna Kosinska on Unsplash

***This post was written a year ago, and was forgotten in a draft version.***

Problem Statement

The goal of the Sparkify project is to predict customer churn from a data set of user activity on the platform. Sparkify is a fictitious digital music streaming company, similar to Spotify. There are paid and free users on the platform. Users listen to songs, favorite songs, add songs to a playlist, add friends, and take a range of other actions, such as visiting the settings page or upgrading, downgrading, or cancelling the service.

Remember the last time you tried to cancel a service you had subscribed to? Sometimes, annoyingly, it goes like this.

Service: “Are you sure?”

You: “Yes, I am!”

Service: “Press OK to confirm”

You: “OK”

Service: “There is a 75% discount campaign going on. Would you like to use the discount?”

You: “No, thank you. Proceed with cancel.”

Service: “Since you have been such a wonderful customer, we would like to waive 50% of your subscription fee.”

You: “No, please cancel!”

Service: “How about 1 month free, 2 months 50% off?”

You: screaming!!!

Well, that is one unpleasant experience to be in.

Customer churn is when a customer leaves the service. Churn is bad for business, because gaining a new customer is much more expensive than retaining an existing one. To prevent churn, and to avoid a desperate situation like the one above, it is important for a business to predict which customers are likely to leave. That gives the business a chance to use discounts and other benefits to persuade a customer to stay before they have decided to leave. In general, customers who have stayed longer are less likely to leave a service. This holds in the Sparkify case study, as the data exploration below shows.

Why use Spark?

A parallel goal of the project is to use Spark to create a scalable infrastructure. Platforms like Sparkify generate a huge amount of data, and analyzing it on a single local machine is not viable. For big data we use distributed computing systems. Spark is a big data tool that uses distributed computing to process large amounts of data. It comes with its own syntax and way of thinking, which take some getting used to, especially if you are used to Python and its scikit-learn library. For this project I used a mini data set so that I could run Spark in local mode. In local mode I was not able to use distributed computing. However, it is good for practicing Spark syntax and for prototyping the project, so that it can later be scaled to a Spark cluster with the full data set.

Steps I took in this project:

1. Load data into Spark.

2. Explore and clean data.

3. Extract features.

4. Build models.

5. Predict churn.

The problem in this project is a binary classification problem: identifying each customer as churn or not churn. Since I defined churn and provided labeled examples of churned users, it is a supervised learning problem. Therefore, supervised classification techniques are a good fit. Specifically, I used Logistic Regression, Decision Tree, Random Forest and Gradient Boosted Tree models, described below.

Loading data

I created a SparkSession, which provides access to the SparkContext, then loaded and read the mini data set in JSON format.

Figure 1 by author

Explore and clean data

Figure 2 by author

There are 18 fields in the data set. Except for itemInSession, registration, sessionId, status and ts, the fields are stored as strings. However, that does not mean they are all categorical. There are 286,500 records in the data set.

Figure 3 by author

Nine columns have missing values, and the counts follow a pattern: each of these nine columns has either 8346 or 58392 nulls.

Figure 4 by author

Although the userId field has no nulls, it does contain corrupt values. The 8346 records with an empty-string userId belong to people visiting the Sparkify web site. They can visit most pages, but not the NextSong page, where only registered users can play a song. We will drop the records with an empty userId, since they do not help in predicting churn.

Figure 5 by author

There are 22 distinct pages a user can visit.

Figure 6 by author
Figure 7 by author

The most frequently visited pages are NextSong, Thumbs Up, Home, Add to Playlist and Add Friend. The least frequently visited pages are Error, Submit Upgrade, Submit Downgrade, Cancel and Cancellation Confirmation. The differences are easier to see with a scaled count.

Figure 8 by author
Figure 9 by author

Not surprisingly, NextSong, the page where songs are played, is the most visited page. Cancellation Confirmation is one of the least visited pages.

Figure 10 by author

There are free users and paid users. There are four times more paid accounts than free accounts.

I defined churn as a visit to the Cancellation Confirmation page, because that is where a user truly breaks up with the service.

I also converted the time columns into timestamp format and looked into how long users have been with the service.

Figure 11 by author

It looks like churned users have been on the platform for less time than not churned users.

I also looked into the distribution by gender.

Figure 12 by author

There does not seem to be a big difference in churn by gender, but men churn slightly more than women.

Extract features

I picked the following features:

1. Session Counts — number of times a user logs into the service.

2. Account Ages — duration of time since a user signed up for the service.

3. Weekly Song Counts — number of songs a user played on a weekly basis.

4. Gender — gender of a user.

5. Artist Count per User — number of artists, whose songs a user listened to.

Pre-processing steps

To prepare the data for processing, I removed the records with an empty userId. There were no duplicate rows, so I did not need to drop any. I converted the timestamp fields to a proper data type to be able to use them. To extract the different features, I created a separate data frame for each. For example, for the Weekly Song Counts feature, I converted the timestamp to a date, extracted the week from it, pivoted the table so that each week became a column, and filled null values with zeros.

Modeling

The classes are imbalanced: 52 users churned and 173 did not.
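To put that balance in perspective, here is a small calculation using only the two counts above. It shows the churn rate and the accuracy that a trivial model would reach by always predicting "not churn", which is a useful baseline when reading the accuracy figures later on.

```python
churned = 52
not_churned = 173
total = churned + not_churned

# Share of users who churned.
churn_rate = churned / total

# Accuracy of a trivial model that always predicts "not churn".
majority_baseline_accuracy = not_churned / total

print(f"users: {total}")
print(f"churn rate: {churn_rate:.3f}")
print(f"majority-class baseline accuracy: {majority_baseline_accuracy:.3f}")
```

Because such a baseline is already right about three times out of four, accuracy alone says little here.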

Figure 13 by author

I split the data set into train and test sets at an 80/20 ratio.

Predicting churn is a supervised machine learning problem. To solve a problem with a supervised technique, we build a model and feed it training data matched with the correct outputs. The model learns from the patterns in the training data. Then we give the model new data that it has never seen before (the test data) and evaluate its performance on that data. Below I use Logistic Regression, Decision Tree, Random Forest and Gradient Boosted Tree techniques to see which one performs best.

Logistic Regression is a supervised classification algorithm. It predicts the probability that a given data entry belongs to the category numbered ‘1’, in this case churn. Ideally both precision and recall would be 1, but that is rarely achieved.

Decision Tree is a supervised machine learning algorithm in which the data is repeatedly split according to the value of some feature.

Random Forest is a supervised machine learning algorithm that belongs to the family of ensemble learning methods. It trains multiple Decision Tree classification models in parallel and combines their results, for example by averaging or voting. It is also a bagging algorithm: each tree is trained on a random sample of the data, and the results are averaged. This technique is good at dampening the effect of outliers and at reducing over-fitting.

Gradient Boosted Tree Classifier is a supervised machine learning algorithm that uses boosting and builds its prediction sequentially. After each run, it looks at the errors, tries to learn patterns from them, and improves on them in the next run. It is also based on decision trees, hence Tree in the name.

Metrics evaluation

Figure 14 by author

The type of evaluation to use for a model depends on the type of problem, the kind of model (supervised vs. unsupervised), and the implementation chosen. The problem at hand is a classification problem: predicting churn or not churn for each user. This problem is mostly concerned with accuracy and F1-score. By contrast, problems in the drug industry would be more concerned with specificity (the proportion of actual negative cases that are correctly identified). Since the data set is small and the classes are imbalanced (52 churned users vs. 173 not churned), I focus particularly on the F1-score. For all four metrics I use, a higher score indicates better model performance.

  • Accuracy: the proportion of all predictions that are correct.
  • Precision: the proportion of predicted positive cases that are actually positive.
  • Recall: the proportion of actual positive cases that are correctly identified.
  • F1-score: the harmonic mean of precision and recall. It is high only when both are high, so it penalizes a model that does well on one and poorly on the other, which makes it useful for small, imbalanced data sets.
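To make the four definitions concrete, here is a small plain-Python sketch that computes them from true and predicted labels. The function name and the ten example labels are made up for illustration; this is not the evaluator used in the project.

```python
def classification_metrics(y_true, y_pred):
    """Compute accuracy, precision, recall and F1 for binary labels (1 = churn)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # churn, predicted churn
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)  # stayed, predicted stay
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # stayed, predicted churn
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # churn, predicted stay

    accuracy = (tp + tn) / len(pairs)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall)
        else 0.0
    )
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Hypothetical labels for ten users: four churned, six did not.
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 1, 0, 0, 0, 0, 0]

metrics = classification_metrics(y_true, y_pred)
print(metrics)
```

In this example one churned user is missed and one loyal user is flagged, so precision, recall and F1 all come out at 0.75 while accuracy is 0.8.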

Among the four models, Random Forest performed the best, followed by Gradient Boosted Tree, Decision Tree and Logistic Regression.

Figure 15 by author

The Logistic Regression model did fairly well on accuracy, but scored very low on F1. I dropped this model, as its performance was not acceptable.

Figure 16 by author

The Decision Tree model performed better than the Logistic Regression model. Its score is higher in both accuracy and F1. However, an F1-score of 0.57 is still not good enough to accept the model.

Figure 17 by author

The Random Forest model scored high on all four metrics. An accuracy of 0.91 is good, and an F1-score of 0.8 is good as well. This model looks like something we can use, unless another model performs better or we manage to improve it.

Figure 18 by author

The Gradient Boosted Tree model performed better than the Decision Tree model, but worse than the Random Forest model. An accuracy of 0.85 is OK, but an F1-score of 0.67 is still on the lower side. I expected the Gradient Boosted Tree model to outperform Random Forest, as it is generally expected to reach higher accuracy (Source). However, the characteristics of the data may have played a role in the different outcome. The Gradient Boosted Tree model learns from the errors of each previous run, while the Random Forest model averages the results of decision trees trained in parallel on random samples.

Model Evaluation and Validation

A model's robustness is measured by how much its performance changes when the input features are changed. The output label of a robust model (in this case, churn) would not change dramatically when the input variables are changed. This suggests that, even with the full data set, the performance of a robust model would not differ dramatically.

On top of the evaluation metrics above, I used a cross validation framework. Cross validation is not a metric, but a technique for evaluating model performance. It splits the training data into equal parts called folds, trains the model on all but one fold, validates it on the fold that was held out, and repeats until every fold has been used for validation once. In this case, I used 3-fold cross validation. I tuned hyper parameters to see if I could get better performance from the Random Forest model. However, the refined model did not improve on the original Random Forest model.

Model improvement

I used ParamGridBuilder() to conduct a grid search-based model selection. It lets you add a parameter with multiple candidate values, and overwrites the parameter if it already exists in the grid. I ran it on the Random Forest model, since that model was the winner out of the four I trained. I tuned the maximum depth of the tree over [7, 7], the minimum instances per node over [1, 3] and the number of trees over [30, 40]. My model performance improved. Accuracy went up from 0.85 to 0.88, precision increased from 0.71 to 0.83, recall stayed the same and F1-score improved from 0.67 to 0.71. I think this shows that the Random Forest model is robust and the right choice out of the four models I ran.
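A grid search tries every combination of the listed values. The plain-Python sketch below enumerates that Cartesian product for the three value lists above, to show how many candidates a grid of this shape contains. The dictionary keys mirror Spark's RandomForestClassifier parameter names, but the enumeration itself is only an illustration and is not Spark code.

```python
from itertools import product

# Value lists as given in the text.
param_grid = {
    "maxDepth": [7, 7],
    "minInstancesPerNode": [1, 3],
    "numTrees": [30, 40],
}

# Cartesian product of all value lists: one dict per candidate setting.
names = list(param_grid)
combinations = [
    dict(zip(names, values)) for values in product(*param_grid.values())
]

# Repeated values in a list produce repeated candidates.
distinct = {tuple(sorted(c.items())) for c in combinations}

# With 3-fold cross validation, each distinct candidate is trained 3 times.
num_folds = 3
fits_for_distinct = len(distinct) * num_folds

print(f"grid entries: {len(combinations)}")
print(f"distinct settings: {len(distinct)}")
print(f"model fits for distinct settings with {num_folds} folds: {fits_for_distinct}")
```

Since both maximum depth values are 7, only four of the eight grid entries are actually different; widening that list (say, to two distinct depths) would double the search space.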

Figure 19 by author

To improve the model, I could have tried different sets and ranges of parameters. When running the project in Spark cluster mode, I would like to try improving the Gradient Boosted Tree model to see if it results in better performance. (I tried improving the Gradient Boosted Tree model in this project, but training took a very long time and I was not able to experiment enough to find the best results, so I dropped it from the project for now.)

Another potential improvement is to try different machine learning techniques such as Support Vector Machine (SVM), Naive Bayes, and AdaBoost. SVM is a good fit for this type of problem, because it looks for the hyperplane that best separates the classes (in this case, two). Naive Bayes classification applies a simplified version of Bayes’ Theorem to every observation based on its features. It treats every feature as independent, and largely because of this it trains quickly. This might come in handy when running the project on the full data set. AdaBoost (Adaptive Boosting) classification combines multiple weak classifiers to build a strong classifier. It progressively learns from the weak classifiers’ errors to improve on the next run. AdaBoost often results in higher accuracy, and it might be a good fit for the Sparkify problem with the full data set.

Conclusion

The final model with the best performance is the Random Forest model. It reaches an accuracy of 0.91, which is good, a precision of 0.86, which is good enough, a recall of 0.76, which is not bad, and an F1-score of 0.80, which is good. This means that the model classifies 91% of all users correctly, whether as churn (true positive) or not churn (true negative). It performs on the lower side on sensitivity, or recall, which is the proportion of actual positive cases that are correctly identified. For the business, this means that not all users who are about to leave Sparkify will be predicted as churn users, so Sparkify will miss the opportunity to take proactive action on those users.

Important features for predicting churn include (whether high or low) visits to the NextSong page (where songs are played), advertisements run, and Add to Playlist, in descending order. The full list is below:
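As a quick consistency check on these numbers, the F1-score can be recomputed from the reported precision and recall using the harmonic mean. Here P and R are shorthand introduced for precision and recall:

```latex
F_1 = \frac{2PR}{P + R}
    = \frac{2 \times 0.86 \times 0.76}{0.86 + 0.76}
    = \frac{1.3072}{1.62}
    \approx 0.807
```

This is in line with the reported 0.80, given that the precision and recall values are themselves rounded to two decimals.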

Figure 20 by author

Using these insights, Sparkify may try different ways of dealing with advertisements. For example, it might reduce the number of advertisements played for users who are predicted to churn. It might also try advertisements of different kinds and durations and see if that makes a difference. For NextSong, assuming that a low number of visits to the NextSong page is associated with churn (which became apparent from the EDA process), it may use recommendation engines to find user-specific songs to recommend. If it is already using such systems, it may focus on the diversity and novelty of the recommendations.

The difficult parts of the project were feature extraction and interpreting model metrics. I also faced complications while trying to improve my model. My initial hyper parameter tuning on the Random Forest model did not result in improved performance. On my second attempt at refinement, the model improved. I also tried tuning parameters for the Gradient Boosted Tree model initially. However, model training took a very long time and did not allow me to experiment with different sets and ranges of parameters. Therefore, I postponed refinement of the Gradient Boosted Tree model until I run the project on a Spark cluster. Overall, this was a good introduction to Spark and to using machine learning techniques to predict customer churn. With each step and iteration, I learned a lot. I look forward to running this project on a Spark cluster and learning from the experience.
