Bank Customer Churn Prediction using ANN, TabNet, FT-Transformer and Autoencoder #1088
Open
kapoorraaghav wants to merge 5 commits into
Our team will soon review your PR. Thanks @kapoorraaghav :)
Author: Hi @abhisheks008, any updates?
Pull Request for DL-Simplified 💡
Issue Title : #924
Closes: #issue number that will be closed through this PR
Describe the add-ons or changes you've made 📃
Give a clear description of what you have added or modified
Type of change ☑️
Added a complete end-to-end deep learning project for predicting bank customer churn. Here's what's included:
EDA: class distribution, feature distributions by churn status, correlation heatmap, churn rate by country and gender. The dataset has about a 20/80 class imbalance, which is handled using SMOTE.
ANN: 3-layer feedforward network (128, 64, 32) with BatchNorm and Dropout, used as a baseline. EarlyStopping and ReduceLROnPlateau callbacks are included.
TabNet: built from scratch in pure TensorFlow, with no external library. It uses sequential attention across 3 steps to select important features at each step.
FT-Transformer: a custom FeatureTokenizer layer embeds each scalar feature into a d-dimensional token. Then, 2 transformer blocks with MultiHeadAttention capture feature interactions.
Autoencoder + Classifier: an unsupervised pre-training stage where the encoder compresses input to 16 dimensions, followed by a supervised classifier trained on the compressed representations.
SHAP: KernelExplainer on the best-performing model for feature contribution analysis.
Keras Tuner: RandomSearch over ANN hyperparameters (hidden units, dropout rate, learning rate) with 10 trials.
Streamlit Web App: a dark-themed interactive demo that takes customer details as input. It runs all 4 models simultaneously and shows per-model churn probability, ensemble average, a risk badge, bar chart comparison, and model performance table.
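The ensemble step described for the Streamlit app can be sketched roughly as follows. This is a minimal, framework-free illustration rather than the code in this PR: the function names `ensemble_average` and `risk_badge`, the badge labels, and the 0.3 / 0.6 cut-offs are all assumptions made for the example, and the per-model probabilities are made-up inputs, not results.

```python
from statistics import mean

# Hypothetical cut-offs for the risk badge; the PR does not state the real ones.
LOW_RISK_MAX = 0.3
HIGH_RISK_MIN = 0.6


def ensemble_average(probabilities):
    """Return the unweighted mean of per-model churn probabilities."""
    if not probabilities:
        raise ValueError("need at least one model probability")
    for name, p in probabilities.items():
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"{name} returned {p}, expected a value in [0, 1]")
    return mean(probabilities.values())


def risk_badge(probability):
    """Map a churn probability to a coarse label (assumed thresholds)."""
    if probability < LOW_RISK_MAX:
        return "LOW"
    if probability < HIGH_RISK_MIN:
        return "MEDIUM"
    return "HIGH"


# Illustrative inputs only; in the app these would come from the four models.
per_model = {
    "ANN": 0.25,
    "TabNet": 0.5,
    "FT-Transformer": 0.75,
    "Autoencoder + Classifier": 0.5,
}
average = ensemble_average(per_model)
print(f"ensemble average: {average:.2f}, badge: {risk_badge(average)}")
```

An unweighted mean treats all four models as equally reliable; weighting by validation ROC-AUC would be one alternative if the models differ noticeably in quality.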
How Has This Been Tested? ⚙️
Describe how it has been tested
Notebook was run end to end on the Churn Modelling dataset (10,000 rows) in a conda environment with Python 3.9, TensorFlow 2.10, Keras 2.10.
All 4 models trained successfully with EarlyStopping — no overfitting observed.
SMOTE applied successfully — class balance confirmed after resampling.
Model weights saved and reloaded correctly — predictions verified on test set.
Streamlit app tested locally — all 4 models load and return predictions without errors.
ROC-AUC verified on held-out test set (80/20 stratified split) for all models.
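For reviewers who want to sanity-check the evaluation setup, here is a small standard-library sketch of an 80/20 stratified split and a ROC-AUC computation. It is an illustration under assumptions, not the notebook's code: `stratified_split` and `roc_auc` are hypothetical helper names, the labels are toy data mimicking the roughly 20/80 imbalance, and the notebook may well use scikit-learn's `train_test_split(..., stratify=y)` and `roc_auc_score` instead.

```python
import random


def stratified_split(labels, test_fraction=0.2, seed=42):
    """Return (train_idx, test_idx), keeping each class's share in both parts."""
    rng = random.Random(seed)
    by_class = {}
    for index, label in enumerate(labels):
        by_class.setdefault(label, []).append(index)
    train_idx, test_idx = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_test = round(len(indices) * test_fraction)
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return sorted(train_idx), sorted(test_idx)


def roc_auc(y_true, scores):
    """ROC-AUC as the chance a positive outranks a negative (ties count half)."""
    positives = [s for y, s in zip(y_true, scores) if y == 1]
    negatives = [s for y, s in zip(y_true, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("ROC-AUC needs at least one positive and one negative")
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0
        for p in positives
        for n in negatives
    )
    return wins / (len(positives) * len(negatives))


# Toy labels: 20 churners, 80 non-churners (not the real dataset).
labels = [1] * 20 + [0] * 80
train_idx, test_idx = stratified_split(labels)
test_labels = [labels[i] for i in test_idx]
print(f"test size: {len(test_idx)}, churners in test: {sum(test_labels)}")

# Tiny hand-checkable example: 3 of the 4 positive/negative pairs are ordered correctly.
print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))
```

The pairwise AUC here is quadratic in the number of samples, which is fine for a check like this but slower than a rank-based implementation on the full test set.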
Checklist: ☑️