Bank Customer Segmentation: Data Science Portfolio Project
Table of Contents
- Project Overview
- EDA
- Feature Engineering
- Customer Segmentation
- Demographic Insight
- Supervised ML Prediction
- Conclusion
Project Overview
Banks face the challenge of understanding, managing, and retaining millions of customers. This project applies data-driven segmentation techniques to classify bank customers according to profitability and behavior, enabling targeted marketing and business strategies that maximize revenue and retention. I use a real-world dataset (Kaggle: 1M+ transactions, 800K+ customers, India) with rich demographics, account balances, and transactional histories. My workflow combines unsupervised clustering (K-Means) with demographic analysis and supervised machine learning (LightGBM).
Project repo: GitHub - BankCustomerSegmentation
EDA
As a first step, I loaded the 1-million+ transaction records, filtered out rows with missing key fields (DOB, gender, location, balances), standardised the date formats (birth date, transaction date), and dropped the rare ambiguous gender values. To get a quick sense of the data, I plotted a few graphs that illustrate the patterns found in different features.
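A minimal sketch of this loading and cleaning step, assuming the column names from the Kaggle dataset (the file name, date parsing options, and gender filter are my assumptions):
# Load and clean the raw transaction data (illustrative; file name assumed)
import pandas as pd

df = pd.read_csv('bank_transactions.csv')
df = df.dropna(subset=['CustomerDOB', 'CustGender', 'CustLocation', 'CustAccountBalance'])
df['CustomerDOB'] = pd.to_datetime(df['CustomerDOB'], dayfirst=True, errors='coerce')
df['TransactionDate'] = pd.to_datetime(df['TransactionDate'], dayfirst=True, errors='coerce')
# Keep only the two well-populated gender codes; a handful of other values are dropped
df = df[df['CustGender'].isin(['M', 'F'])]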
1. Gender
# Distribution of Customer Gender
import seaborn as sns
import matplotlib.pyplot as plt

sns.countplot(data=df, x='CustGender')
plt.title('Customer Gender Distribution')
plt.show()
The data shows a clear gender imbalance: male customers outnumber female customers by roughly 2.5 to 1.
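The ratio can be checked directly from the value counts; a quick sketch:
# Quantify the gender imbalance
gender_counts = df['CustGender'].value_counts()
print(gender_counts)
print(f"Male-to-female ratio: {gender_counts['M'] / gender_counts['F']:.2f}")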
2. Locations
print(f"Customers are from {df['CustLocation'].nunique()} unique locations.")
location_counts = df['CustLocation'].value_counts()
total_customers = len(df)
# Cumulative share of customers covered by the most common locations
cumulative_pct = location_counts.cumsum() / total_customers * 100
# Print coverage for the top 10, 15, 20 locations
for n in [10, 15, 20]:
    print(f"Top {n} locations cover: {cumulative_pct.iloc[n-1]:.2f}% of customers")
Customers are from 8157 unique locations.
Top 10 locations cover: 52.24% of customers
Top 15 locations cover: 59.14% of customers
Top 20 locations cover: 62.93% of customers
Because there is a large number of unique locations in the customer records, I decided to keep the top 15 locations, which already cover about 60% of customers, and bin the rest into an 'Other' category.
# Compute and plot top 15 locations by count
top_15_locations = df['CustLocation'].value_counts().nlargest(15).index
df['Top15Location'] = df['CustLocation'].where(df['CustLocation'].isin(top_15_locations), 'Other')
top_15_location_counts = df['Top15Location'].value_counts().sort_values()
plt.figure(figsize=(10,6))
top_15_location_counts.plot(kind='barh', color='skyblue')
plt.title('Top 15 Customer Locations plus Other')
plt.xlabel('Count')
plt.ylabel('CustLocation')
plt.tight_layout()
plt.show()
3. Account Balance and Transaction Amounts
Account balances and transaction amounts were found to be heavily skewed, which is very common in real-world financial data. A log transformation was applied to both to better visualise their distributions.
# Histogram of Customer Account Balance (log-transformed)
import numpy as np

df['log_balance'] = np.log1p(df['CustAccountBalance'])  # log1p already adds 1
plt.figure(figsize=(8, 5))
sns.histplot(df['log_balance'], bins=50, kde=True)
plt.title('Log-Transformed Customer Account Balance Distribution')
plt.xlabel('Log(1 + Balance) (INR)')
plt.ylabel('Count')
plt.show()
# Histogram of Transaction Amounts (log-transformed)
df['log_amount'] = np.log1p(df['TransactionAmount (INR)'])  # log1p already adds 1
plt.figure(figsize=(8, 5))
sns.histplot(df['log_amount'], bins=50, kde=True)
plt.title('Log-Transformed Transaction Amount Distribution')
plt.xlabel('Log(1 + Amount) (INR)')
plt.ylabel('Count')
plt.show()
4. Age
# Derive customer age from date of birth
if 'CustomerDOB' in df.columns:
    current_date = pd.to_datetime('today').normalize()
    # DOBs parsed into the future (two-digit year ambiguity) are shifted back a century
    df.loc[df['CustomerDOB'] > current_date, 'CustomerDOB'] -= pd.DateOffset(years=100)
    df['Age'] = (current_date - df['CustomerDOB']).dt.days // 365
    sns.histplot(df['Age'], bins=20)
    plt.title('Customer Age Distribution')
    plt.xlabel('Age')
    plt.show()
Age was computed as the difference between the current date and each customer's DOB, and then plotted as a histogram.
5. EDA Summary:
- Visualized distributions: account balance and transaction amounts (log-transform to handle skew).
- Grouped the top 15 locations (covering ~60% of all customers) and assigned the rest to 'Other' for analysis clarity.
- Filtered out ages <18 and >90 for robust modeling (see the sketch below).
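The age filter mentioned above is not shown in the EDA cells; a minimal sketch of how it might be applied (the 18–90 bounds come from the summary, the implementation is mine):
# Keep only plausible ages for modelling (illustrative implementation of the stated filter)
df = df[(df['Age'] >= 18) & (df['Age'] <= 90)]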
Feature Engineering
After EDA, I proceeded to feature engineering before applying clustering to group customers into 4 segments. The features include aggregated customer activity: total transaction value, frequency, and recency (an RFM-style summary). I also generated interaction features such as age bins, age-bin × location, gender × location, and age-bin × gender combinations, plus location-based GDP per head as an attribute.
# Group transactions by CustomerID to get aggregate behavioral and value metrics
customer_report = df.groupby('CustomerID').agg({
'TransactionAmount (INR)': ['sum', 'mean', 'count'],
'CustAccountBalance': 'last',
'TransactionDate': ['min', 'max']
})
customer_report.columns = ['TotalTransSum', 'AvgTransAmount', 'TransCount', 'EndBalance', 'FirstTrans', 'LastTrans']
customer_report['Log_TotalTransSum'] = np.log1p(customer_report['TotalTransSum'])
customer_report['Log_AvgTransAmount'] = np.log1p(customer_report['AvgTransAmount'])
customer_report['Log_EndBalance'] = np.log1p(customer_report['EndBalance'])
customer_report['RecencyDays'] = (pd.to_datetime('today') - customer_report['LastTrans']).dt.days
customer_report.reset_index(inplace=True)
features = ['Log_TotalTransSum', 'Log_AvgTransAmount', 'TransCount', 'Log_EndBalance', 'RecencyDays']
# Scale features robustly to outliers
from sklearn.preprocessing import RobustScaler

scaler = RobustScaler()
X_scaled = scaler.fit_transform(customer_report[features])
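The interaction features mentioned above are built from the per-customer demographics; a hedged sketch of how they might look (bin edges and column names are my assumptions, not the project's exact definitions):
# Illustrative interaction features (bin edges and names are assumptions)
demo = df.groupby('CustomerID').agg(
    {'Age': 'last', 'CustGender': 'last', 'Top15Location': 'last'}
).reset_index()
demo['AgeBin'] = pd.cut(demo['Age'], bins=[17, 25, 35, 50, 65, 90],
                        labels=['18-25', '26-35', '36-50', '51-65', '66-90'])
demo['AgeBin_Location'] = demo['AgeBin'].astype(str) + '_' + demo['Top15Location']
demo['Gender_Location'] = demo['CustGender'] + '_' + demo['Top15Location']
demo['AgeBin_Gender'] = demo['AgeBin'].astype(str) + '_' + demo['CustGender']
customer_report = customer_report.merge(demo, on='CustomerID', how='left')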
Customer Segmentation
Now to the main task: using K-means clustering to assign customers to 4 segments. I chose K-means with 4 segments based on a recent systematic review of algorithmic customer segmentation, which found this to be the most common configuration when machine learning is used to group customers. For more details, see Salminen, J., Mustak, M., Sufyan, M. et al.
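Although k=4 is taken from the literature, it can be sanity-checked on this data with a quick silhouette sweep; a minimal sketch, run on a random sample since silhouette scoring on 800K+ customers is expensive (sample size and k range are my choices):
# Optional sanity check of k with silhouette scores on a random sample (illustrative)
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import numpy as np

rng = np.random.default_rng(42)
sample_idx = rng.choice(len(X_scaled), size=20000, replace=False)
X_sample = X_scaled[sample_idx]
for k in range(2, 7):
    labels = KMeans(n_clusters=k, random_state=42, n_init=10).fit_predict(X_sample)
    print(k, round(silhouette_score(X_sample, labels), 3))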
from sklearn.cluster import KMeans

kmeans = KMeans(n_clusters=4, random_state=42)
customer_report['Segment'] = kmeans.fit_predict(X_scaled)
segment_profile = customer_report.groupby('Segment').agg({
'CustomerID': 'count',
'TotalTransSum': 'sum',
'AvgTransAmount': 'mean',
'EndBalance': 'mean'
}).rename(columns={'CustomerID': 'NumOfCustomers'})
# Calculate percentage revenue/profit per segment
total_revenue = segment_profile['TotalTransSum'].sum()
segment_profile['RevenuePct'] = 100 * segment_profile['TotalTransSum'] / total_revenue
print(segment_profile)
| Segment | NumOfCustomers | TotalTransSum | AvgTransAmount | EndBalance | RevenuePct |
|---|---|---|---|---|---|
| 0 | 231385 | 1.146041e+09 | 4193.875822 | 228902.678497 | 80.224432 |
| 1 | 359981 | 1.857673e+08 | 461.943159 | 77088.697022 | 13.003959 |
| 2 | 148792 | 9.890596e+06 | 64.467939 | 46192.802503 | 0.692355 |
| 3 | 98428 | 8.684481e+07 | 724.731447 | 261.786160 | 6.079254 |
I further explored the composition of these segments by looking at their customer counts and revenue contributions.
segment_counts = customer_report['Segment'].value_counts(normalize=True) * 100
percent_df = segment_counts.reset_index()
percent_df.columns = ['Segment', 'Percent']
sns.barplot(x='Segment', y='Percent', data=percent_df)
plt.title('Percentage of Customers in Each Segment')
plt.xlabel('Segment')
plt.ylabel('Percent of Customers')
plt.show()
segment_profile['RevenuePct'].plot(kind='bar', color='skyblue')
plt.title('Revenue Contribution by Segment')
plt.xlabel('Segment')
plt.ylabel('Percentage of Total Revenue (%)')
plt.show()
According to the results, these are the identified groups:
- Segment 0: Premier Clients (top 20%, contribute 80%+ of revenue)
- Segment 1: Mass Market (largest share of clients, steady mid-level value)
- Segment 2: Dormant/Low Value (least value, disengaged)
- Segment 3: Transactional/Emerging (active, low average balances)
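For downstream reporting it helps to attach these names to the raw cluster IDs. Note that K-means IDs are not stable across runs, so the mapping below reflects this particular run:
# Map this run's cluster IDs to business-friendly segment names
segment_names = {
    0: 'Premier Clients',
    1: 'Mass Market',
    2: 'Dormant/Low Value',
    3: 'Transactional/Emerging',
}
customer_report['SegmentName'] = customer_report['Segment'].map(segment_names)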
Demographic Insight
To understand whether there is a significant relationship between customers' demographic background and the segment they fall into, chi-square and Kruskal-Wallis tests were applied.
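The cells below operate on customer_with_demo, a per-customer frame that joins the segment labels with demographic attributes. A minimal sketch of how it could be built (the frame name matches the code below; the construction is my assumption):
# Join cluster labels with per-customer demographics (frame name matches the cells below)
demo_cols = df.groupby('CustomerID').agg(
    {'CustGender': 'last', 'Top15Location': 'last', 'Age': 'last'}
).reset_index()
customer_with_demo = customer_report[['CustomerID', 'Segment']].merge(demo_cols, on='CustomerID', how='left')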
# Create observed contingency table
from scipy.stats import chi2_contingency

contingency = pd.crosstab(customer_with_demo['Segment'], customer_with_demo['CustGender'])
chi2, p, dof, expected = chi2_contingency(contingency)
# Calculate standardized residuals
residuals = (contingency - expected) / np.sqrt(expected)
# Plot heatmap of standardized residuals
plt.figure(figsize=(8, 6))
sns.heatmap(residuals, annot=True, cmap='coolwarm', center=0)
plt.title('Standardized Residuals for Gender vs. Segment')
plt.xlabel('Gender')
plt.ylabel('Segment')
plt.show()
# Build observed contingency table for Segment vs TopLocation
contingency_loc = pd.crosstab(customer_with_demo['Segment'], customer_with_demo['Top15Location'])
# Chi-square test and expected counts
chi2, p, dof, expected_loc = chi2_contingency(contingency_loc)
# Calculate standardized residuals
residuals_loc = (contingency_loc - expected_loc) / np.sqrt(expected_loc)
# Plot heatmap of standardized residuals
plt.figure(figsize=(18, 6))
sns.heatmap(residuals_loc, annot=True, fmt=".1f", cmap="coolwarm", center=0, cbar_kws={'label': 'Std. Residual'})
plt.title("Standardized Residuals for Top Locations vs. Segment")
plt.xlabel("Top Location")
plt.ylabel("Segment")
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()
plt.figure(figsize=(10,6))
sns.boxplot(x='Segment', y='Age', data=customer_with_demo, showfliers=False)
plt.title('Age Distribution by Segment')
plt.xlabel('Segment')
plt.ylabel('Age')
plt.show()
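The Kruskal-Wallis test mentioned above is not shown in the cells; a minimal sketch of how age could be compared across segments:
# Kruskal-Wallis test: does the age distribution differ across segments? (illustrative)
from scipy.stats import kruskal

age_by_segment = [grp['Age'].dropna() for _, grp in customer_with_demo.groupby('Segment')]
h_stat, p_value = kruskal(*age_by_segment)
print(f"Kruskal-Wallis H = {h_stat:.1f}, p = {p_value:.3g}")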
Summary
- Significant patterns found: Segment 0 skewed female and metropolitan, Segment 2/3 skewed male, strong location impact (Mumbai/New Delhi clusters for Premier clients).
Supervised ML Prediction
Given the significant relationship between customers' background and their segment, I moved on to train a LightGBM classifier to predict whether a customer is likely to fall into the high-value Segment 0.
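The training data uses the demographic attributes (gender, location, age) with a binary target for Segment 0. A minimal sketch of how the features, target, and split might be built (the one-hot encoding and split parameters are my assumptions):
# Build demographic features, binary target, and train/validation split (illustrative)
from sklearn.model_selection import train_test_split

X = pd.get_dummies(customer_with_demo[['CustGender', 'Top15Location', 'Age']],
                   columns=['CustGender', 'Top15Location'])
y = customer_with_demo['Segment'] == 0  # True for Premier Clients
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)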
# Train LightGBM classifier
import lightgbm as lgb
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report

baseline_model = lgb.LGBMClassifier(
    objective='binary',
    n_estimators=1000,
    random_state=42
)
# Fit with early stopping + logging (LightGBM 4.x uses callbacks)
baseline_model.fit(
    X_train, y_train,
    eval_set=[(X_val, y_val)],
    callbacks=[lgb.early_stopping(50), lgb.log_evaluation(50)]
)
# Predict and evaluate
baseline_proba = baseline_model.predict_proba(X_val)
baseline_y_pred = np.argmax(baseline_proba, axis=1)
print("Baseline Macro F1:", f1_score(y_val, baseline_y_pred, average='macro'))
print(classification_report(y_val, baseline_y_pred))
Baseline Macro F1: 0.4897450189570368
precision recall f1-score support
False 0.71 0.97 0.82 136987
True 0.55 0.09 0.16 59503
accuracy 0.70 196490
macro avg 0.63 0.53 0.49 196490
weighted avg 0.66 0.70 0.62 196490
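# Inverse-frequency class weights (same heuristic as scikit-learn's 'balanced' class_weight)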
classes, counts = np.unique(y_train, return_counts=True)
weights = {int(c): len(y_train)/(len(classes)*cnt) for c, cnt in zip(classes, counts)}
# Train LightGBM classifier with class weights
added_cw_model = lgb.LGBMClassifier(
    class_weight=weights,
    objective='binary',
    n_estimators=1000,
    random_state=42
)
# Fit with early stopping + logging (LightGBM 4.x uses callbacks)
added_cw_model.fit(
    X_train, y_train,
    eval_set=[(X_val, y_val)],
    callbacks=[lgb.early_stopping(50), lgb.log_evaluation(50)]
)
# Predict and evaluate
added_cw_model_proba = added_cw_model.predict_proba(X_val)
added_cw_model_y_pred = np.argmax(added_cw_model_proba, axis=1)
print("After adding class weight Macro F1:", f1_score(y_val, added_cw_model_y_pred, average='macro'))
print(classification_report(y_val, added_cw_model_y_pred))
After adding class weight Macro F1: 0.5857263380582927
precision recall f1-score support
False 0.77 0.63 0.70 136987
True 0.40 0.58 0.48 59503
accuracy 0.61 196490
macro avg 0.59 0.60 0.59 196490
weighted avg 0.66 0.61 0.63 196490
metrics_before = [
precision_score(y_val, baseline_y_pred),
recall_score(y_val, baseline_y_pred),
f1_score(y_val, baseline_y_pred)
]
metrics_after = [
precision_score(y_val, added_cw_model_y_pred),
recall_score(y_val, added_cw_model_y_pred),
f1_score(y_val, added_cw_model_y_pred)
]
metrics = ['Precision', 'Recall', 'F1']
x = np.arange(len(metrics))
width = 0.35
fig, ax = plt.subplots(figsize=(7,5))
ax.bar(x - width/2, metrics_before, width, label='Before')
ax.bar(x + width/2, metrics_after, width, label='After', color='orange')
ax.set_xticks(x)
ax.set_xticklabels(metrics)
ax.set_ylabel('Score')
ax.set_ylim(0, 1)
ax.set_title('Key Metrics Before and After Class Weights')
ax.legend()
plt.tight_layout()
plt.show()
This bar plot shows the difference between the LightGBM model before and after adding class weights. Macro F1 increased from 0.49 to 0.59, and recall for the top-value class (Segment 0) improved substantially from 0.09 to 0.58, making the model much better at catching valuable customers. This can support marketing and retention strategies that target this group by predicting a customer's likelihood of being a top-value client from a few simple background attributes (gender, location, and age).
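As a usage sketch, the class-weighted model can score a new customer directly from those attributes (the example row is hypothetical, and the encoding assumes the one-hot features sketched earlier):
# Score a hypothetical new customer's probability of being a Premier (Segment 0) client
new_customer = pd.DataFrame([{'CustGender': 'F', 'Top15Location': 'MUMBAI', 'Age': 34}])
new_customer_encoded = pd.get_dummies(new_customer, columns=['CustGender', 'Top15Location'])
new_customer_encoded = new_customer_encoded.reindex(columns=X_train.columns, fill_value=0)
prob_premier = added_cw_model.predict_proba(new_customer_encoded)[0, 1]
print(f"Probability of being a Premier client: {prob_premier:.2f}")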
Conclusion
Summary
- Premier Clients generate 80%+ of revenue from only 20% of the base—critical for retention focus!
- Segment demographics underpin the business case for personalized offerings.
- Location-based targeting revealed: Mumbai/New Delhi dominate Premier segment opportunity.
- Class weighting in ML models dramatically improves recall for strategic segment identification.
Business & Technical Impact
- Enables targeted engagement: VIP campaigns, mass-market upselling, dormant customer reactivation
- Facilitates resource allocation: banks can prioritize segments and regions that drive the most business impact
- Provides a repeatable framework for clustering and prediction in other domains (retail, insurance, fintech)
Tech Stack
- Python (3.9+), Jupyter Notebook
- Pandas, Numpy, Matplotlib, Seaborn
- scikit-learn, LightGBM, SciPy
Source & Next Steps
- Full notebook and code at GitHub
- Next steps: extend segmentation to temporal analysis, deploy as real-time dashboard, and apply transfer learning on related banking datasets.
