Explainable Models: Heart Failure Analysis¶
This notebook performs a classification analysis on the Heart Failure Prediction dataset using models simple enough for doctors to apply directly in the clinic, without relying on black-box methods. The goal is to predict whether a patient has heart disease based on clinical and demographic features such as age, blood pressure, cholesterol level, and exercise-related indicators.
Dataset¶
The dataset is designed to predict the presence of heart disease using a mix of numeric and categorical attributes collected from patients. It combines data from five well-known heart disease studies into a single, clean dataset with 918 observations. Predicting heart disease risk is critical for early intervention, improving patient outcomes, and supporting healthcare decision-making.
Use Case Description¶
Our objective is to classify patients as having heart disease or not using logistic regression techniques. This involves:
1. Exploring relationships between patient attributes and heart disease occurrence
2. Identifying statistically significant predictors
3. Validating model assumptions and performance
4. Diagnosing issues such as multicollinearity and calibration
5. Evaluating model robustness using cross-validation, ROC curves, and threshold tuning
Explainable Models¶
For doctors in a hospital setting, simple and highly explainable models are essential because they can be translated into quick, actionable rules for patient screening. These models allow clinicians to make fast decisions without relying on complex computations or black-box algorithms. Some commonly used, interpretable models include:
1. Logistic Regression: Provides clear insights through coefficients and odds ratios, making it easy to understand how each factor influences heart disease risk.
2. Decision Trees: Offer intuitive, rule-based structures that can be visualized and converted into simple “if-then” guidelines for clinical use.
3. Rule-Based Models (e.g., CART or CHAID): Extend decision trees into sets of human-readable rules for rapid triage.
4. Naïve Bayes: Based on conditional probabilities; easy to interpret and explain as likelihoods.
5. Simple k-Nearest Neighbors (k-NN): While not inherently rule-based, its logic ("You look most like these N past patients who had/didn't have heart disease.") is easy to explain in clinical terms.
Dataset card: https://www.kaggle.com/datasets/fedesoriano/heart-failure-prediction
Fallback CSV mirror: https://raw.githubusercontent.com/benbobyabraham/heart_failure_prediction_dataset_kaggle/main/heart.csv
Importing packages
import os, sys, warnings, json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV, cross_val_score, learning_curve, validation_curve
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve, auc,
confusion_matrix, classification_report, precision_recall_curve, average_precision_score,
brier_score_loss, RocCurveDisplay, PrecisionRecallDisplay)
from sklearn.linear_model import LogisticRegression
from sklearn.dummy import DummyClassifier
from sklearn.calibration import CalibratedClassifierCV, CalibrationDisplay
from sklearn.tree import export_text
import joblib
import statsmodels.api as sm
import statsmodels.formula.api as smf
warnings.filterwarnings('ignore')
np.random.seed(42)
sns.set(style='whitegrid', context='notebook')
We try to read heart.csv from the Data folder. If not present, we fetch a public read‑only mirror of the Kaggle CSV from GitHub.
from pathlib import Path
import io, urllib.request
DATA_LOCAL = Path('../Data/heart.csv')
DATA_URL = 'https://raw.githubusercontent.com/benbobyabraham/heart_failure_prediction_dataset_kaggle/main/heart.csv'
if DATA_LOCAL.exists():
    df = pd.read_csv(DATA_LOCAL)
else:
    try:
        with urllib.request.urlopen(DATA_URL) as resp:
            df = pd.read_csv(io.BytesIO(resp.read()))
    except Exception as e:
        raise FileNotFoundError('Could not find heart.csv locally or fetch from the fallback URL. '
                                'Please download from Kaggle and place heart.csv next to this notebook.') from e
print(df.shape)
df.head()
(918, 12)
| | Age | Sex | ChestPainType | RestingBP | Cholesterol | FastingBS | RestingECG | MaxHR | ExerciseAngina | Oldpeak | ST_Slope | HeartDisease |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 40 | M | ATA | 140 | 289 | 0 | Normal | 172 | N | 0.0 | Up | 0 |
| 1 | 49 | F | NAP | 160 | 180 | 0 | Normal | 156 | N | 1.0 | Flat | 1 |
| 2 | 37 | M | ATA | 130 | 283 | 0 | ST | 98 | N | 0.0 | Up | 0 |
| 3 | 48 | F | ASY | 138 | 214 | 0 | Normal | 108 | Y | 1.5 | Flat | 1 |
| 4 | 54 | M | NAP | 150 | 195 | 0 | Normal | 122 | N | 0.0 | Up | 0 |
EDA (Exploratory Data Analysis)¶
Data attributes and description
| Column Name | Description |
|---|---|
| Age | Age of the patient (years) |
| Sex | Gender of the patient (M/F) |
| ChestPainType | Type of chest pain (ATA, NAP, ASY, TA) |
| RestingBP | Resting blood pressure (mm Hg) |
| Cholesterol | Serum cholesterol (mg/dl) |
| FastingBS | Fasting blood sugar (>120 mg/dl: 1, else: 0) |
| RestingECG | Resting electrocardiogram results |
| MaxHR | Maximum heart rate achieved |
| ExerciseAngina | Exercise-induced angina (Y/N) |
| Oldpeak | ST depression induced by exercise relative to rest |
| ST_Slope | Slope of the peak exercise ST segment |
| HeartDisease | Target variable (1 = disease present, 0 = no disease) |
Feature types: Mix of numeric (Age, RestingBP, Cholesterol, MaxHR, Oldpeak) and categorical (Sex, ChestPainType, RestingECG, ExerciseAngina, ST_Slope), plus a binary variable FastingBS.
Unit notes: e.g., RestingBP in mmHg, Cholesterol in mg/dl; some sources note 0 values may indicate missing/not measured.
df['FastingBS'] = df.FastingBS.astype(str)
df['HeartDisease'] = df.HeartDisease.astype(str)
df.describe()
| | Age | RestingBP | Cholesterol | MaxHR | Oldpeak |
|---|---|---|---|---|---|
| count | 918.000000 | 918.000000 | 918.000000 | 918.000000 | 918.000000 |
| mean | 53.510893 | 132.396514 | 198.799564 | 136.809368 | 0.887364 |
| std | 9.432617 | 18.514154 | 109.384145 | 25.460334 | 1.066570 |
| min | 28.000000 | 0.000000 | 0.000000 | 60.000000 | -2.600000 |
| 25% | 47.000000 | 120.000000 | 173.250000 | 120.000000 | 0.000000 |
| 50% | 54.000000 | 130.000000 | 223.000000 | 138.000000 | 0.600000 |
| 75% | 60.000000 | 140.000000 | 267.000000 | 156.000000 | 1.500000 |
| max | 77.000000 | 200.000000 | 603.000000 | 202.000000 | 6.200000 |
from ml_plottings import univariate_analysis, classification_bivariate_analysis, plot_correlation_matrix
univariate_analysis(df)
[Univariate plots for each feature: count plots for Sex, ChestPainType, FastingBS, RestingECG, ExerciseAngina, ST_Slope, and HeartDisease; distribution plots for Age, RestingBP, Cholesterol, MaxHR, and Oldpeak]
classification_bivariate_analysis(df, 'HeartDisease')
[Bivariate plots of each feature against HeartDisease]
Defining the categorical and continuous variables
num_cols = ['Age', 'RestingBP', 'Cholesterol', 'MaxHR', 'Oldpeak']
cat_cols = ['Sex', 'ChestPainType', 'RestingECG', 'ExerciseAngina', 'ST_Slope']
binary_cols = ['FastingBS']
plot_correlation_matrix(df, num_cols+binary_cols)
[Correlation matrix heatmap of the numeric and binary features]
Key Insights from Exploratory Data Analysis
1. Univariate Analysis
   - Target Distribution: About 55% of patients have heart disease, while 45% do not.
   - Demographics:
     - Approximately 80% of patients are male.
     - Age ranges from 28 to 77 years, roughly normally distributed.
   - Categorical Features:
     - ChestPainType: Over 50% report ASY (asymptomatic) chest pain.
     - RestingECG: More than 60% show normal ECG results.
     - FastingBS: The majority have normal fasting blood sugar.
     - ST_Slope: Most patients have a flat ST slope.
   - Numeric Features:
     - RestingBP: Concentrated between 120–140 mmHg, right-skewed; the value of 0 is clearly erroneous.
     - Cholesterol: About 172 zero values (≈19%), likely unrecorded measurements.
     - MaxHR: Roughly normal distribution, concentrated between 120–155 bpm.
     - Oldpeak: Around 375 values are zero, a few are negative, and the rest are positive.
2. Bivariate Analysis
   - Age vs HeartDisease: The probability of heart disease increases with age.
   - MaxHR vs HeartDisease: Patients with heart disease tend to have a lower maximum heart rate.
   - Sex vs HeartDisease: Males have a higher chance of heart disease.
   - ChestPainType vs HeartDisease: Patients with ASY chest pain have a higher risk.
   - FastingBS vs HeartDisease: Elevated fasting blood sugar is associated with higher risk.
   - ExerciseAngina vs HeartDisease: Exercise-induced angina substantially increases risk.
   - ST_Slope vs HeartDisease: Patients with an upward ST slope have a lower chance of heart disease.
Feature Engineering and Cleaning¶
The zero values in Cholesterol and RestingBP represent missing measurements, so we recode them as missing (NaN).
df.loc[df.Cholesterol==0, 'Cholesterol'] = None
df.loc[df.RestingBP==0, 'RestingBP'] = None
Preprocessing & Pipelines¶
Before training a machine learning model, it’s essential to address missing data and ensure features are on comparable scales. These steps improve model stability and performance.
1. Handling Missing Values
- Mean/Median Imputation: Replace missing numeric values with the mean or median of the column. Median is often preferred for skewed distributions.
- Mode Imputation: For categorical variables, fill missing values with the most frequent category.
- Advanced Techniques: Use algorithms like KNN imputation or model-based imputation for more accurate estimates when data is not missing completely at random.
2. Feature Scaling
- Standardization (Z-score): Transforms features to have zero mean and unit variance. Common for models like logistic regression and SVM.
- Min-Max Scaling: Rescales features to a fixed range (usually [0, 1]). Useful for algorithms sensitive to absolute values, such as neural networks.
- Robust Scaling: Uses the median and interquartile range, making it less sensitive to outliers. (A short sketch of these alternative transformers follows this list.)
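A minimal sketch of those alternatives (hedged: the pipeline we actually build below uses median/mode imputation and standardization; KNNImputer, MinMaxScaler, and RobustScaler are drop-in substitutes with the same fit/transform interface):
from sklearn.impute import KNNImputer
from sklearn.preprocessing import MinMaxScaler, RobustScaler
# KNN imputation: estimate each missing value from the 5 most similar rows
knn_imputer = KNNImputer(n_neighbors=5)
# Min-max scaling to [0, 1]; robust scaling centers on the median and scales by the IQR
minmax_scaler = MinMaxScaler()
robust_scaler = RobustScaler()
# Any of these can replace the corresponding step in the numeric pipeline defined below.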
A Pipeline in scikit-learn lets us combine these steps into a single, reproducible workflow. It also makes the workflow modular and reusable for training, tuning, and deployment. Steps in the pipeline:
1. Imputation
- Numeric Features: Replace missing values with the median (robust to outliers).
- Categorical Features: Fill missing values with the most frequent category.
2. Scaling
- Apply Standardization (Z-score) to numeric features so they have zero mean and unit variance.
- This is essential for models like logistic regression that are sensitive to feature scales.
3. Encoding
- Use One-Hot Encoding for categorical variables to convert them into numeric form.
- Set handle_unknown='ignore' to safely handle unseen categories during inference.
feature_cols = num_cols + cat_cols + binary_cols
X = df[feature_cols]
y = df['HeartDisease']
numeric_transformer = Pipeline(steps=[
('imputer', SimpleImputer(strategy='median')),
('scaler', StandardScaler())
])
categorical_transformer = Pipeline(steps=[
('imputer', SimpleImputer(strategy='most_frequent')),
('onehot', OneHotEncoder(handle_unknown='ignore'))
])
preprocess = ColumnTransformer(
transformers=[
('num', numeric_transformer, num_cols+binary_cols),
('cat', categorical_transformer, cat_cols)
]
)
Train/Test Split¶
We divide our data into two parts: training data (80%) and testing data (20%). The training set is used to fit the model, while the test set evaluates how well the model generalizes to unseen data. This helps prevent overfitting and gives us a realistic estimate of model performance. We use Stratified sampling to preserve class ratio.
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, stratify=y, random_state=29
)
X_train.shape, X_test.shape
((734, 11), (184, 11))
# Helper to get post‑preprocessing feature names
def get_feature_names(ct):
    """
    Return feature names from a fitted ColumnTransformer `ct`.
    Works when the inner transformers expose `get_feature_names_out`.
    """
    output_features = []
    for name, trans, cols in ct.transformers_:
        if name == 'remainder' and trans == 'drop':
            continue
        # Pipeline with a final step having get_feature_names_out (e.g., OneHotEncoder)
        if hasattr(trans, 'named_steps') and hasattr(trans.named_steps.get('onehot', None), 'get_feature_names_out'):
            ohe = trans.named_steps['onehot']
            feats = list(ohe.get_feature_names_out(cols))
            output_features.extend(feats)
        # Direct transformer exposing get_feature_names_out (rare for numeric here)
        elif hasattr(trans, 'get_feature_names_out'):
            feats = list(trans.get_feature_names_out(cols))
            output_features.extend(feats)
        else:
            # Fall back to original column names
            output_features.extend(list(cols))
    return output_features
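Note: in recent scikit-learn releases (≥ 1.0) the fitted ColumnTransformer can produce these names directly, although they carry a num__/cat__ prefix, unlike the helper above:
# Equivalent built-in (scikit-learn >= 1.0); names look like 'num__Age', 'cat__Sex_F'
# final_feature_names = list(preprocess.get_feature_names_out())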
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,
roc_auc_score, average_precision_score,
confusion_matrix, classification_report)
# Helper function to evaluate the model accuracy
def evaluate_classifier(name, model, X_te=X_test, y_te=y_test, prob_pos=None):
    y_pred = model.predict(X_te)
    y_pred = y_pred.astype(int)
    y_te = np.asarray(y_te).astype(int)  # use the passed labels, not the global y_test
    if prob_pos is None:
        if hasattr(model, 'predict_proba'):
            prob_pos = model.predict_proba(X_te)[:, 1]
        elif hasattr(model, 'decision_function'):
            scores = model.decision_function(X_te)
            smin, smax = scores.min(), scores.max()
            prob_pos = (scores - smin) / (smax - smin + 1e-12)
        else:
            prob_pos = np.zeros_like(y_te, dtype=float)
    acc = accuracy_score(y_te, y_pred)
    prec = precision_score(y_te, y_pred, zero_division=0)
    rec = recall_score(y_te, y_pred, zero_division=0)
    f1 = f1_score(y_te, y_pred, zero_division=0)
    try:
        roc = roc_auc_score(y_te, prob_pos)
    except Exception:
        roc = np.nan
    pr_auc = average_precision_score(y_te, prob_pos)
    print(f"=== {name} ===")
    print(f"Accuracy: {acc:.3f}\nPrecision: {prec:.3f}\nRecall: {rec:.3f}\nF1: {f1:.3f}")
    print(f"ROC-AUC: {roc:.3f}\nPR-AUC: {pr_auc:.3f}")
    print("Confusion matrix:\n", confusion_matrix(y_te, y_pred))
    print("Classification report:\n", classification_report(y_te, y_pred, zero_division=0))
    return dict(model=name, accuracy=acc, precision=prec, recall=rec, f1=f1, roc_auc=roc, pr_auc=pr_auc)
Baseline model¶
A baseline model is a simple reference point used to evaluate whether a more complex model adds real predictive value. For classification tasks, this often means using a Dummy Classifier that predicts the majority class or random guesses based on class distribution. It doesn’t use any features, but sets a minimum performance benchmark, helping us understand if our logistic regression model truly improves upon naive predictions.
baseline = Pipeline(steps=[('prep', preprocess), ('mdl', DummyClassifier(strategy='most_frequent'))])
baseline.fit(X_train, y_train)
yp = baseline.predict(X_test)
yp_prob = baseline.predict_proba(X_test)[:,1]
Key Classification Metrics¶
- Accuracy
- Proportion of correct predictions:
- \[\text{Accuracy} = \frac{\text{TP + TN}}{\text{Total Samples}}\]
- Simple but can be misleading if classes are imbalanced.
- Precision
- Of all predicted positives, how many are actually positive:
- \[\text{Precision} = \frac{\text{TP}}{\text{TP + FP}}\]
- Recall (Sensitivity)
- Of all actual positives, how many did we correctly predict:
- \[\text{Recall} = \frac{\text{TP}}{\text{TP + FN}}\]
- F1 Score
- Harmonic mean of precision and recall:
- \[ F1 = 2 \times \frac{\text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}}\]
- ROC-AUC
- Measures ranking quality across thresholds; higher is better.
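As a quick worked example, these formulas can be checked by hand against the baseline confusion matrix reported below (TN = 0, FP = 82, FN = 0, TP = 102):
# Worked example: metrics by hand from the baseline confusion matrix
TN, FP, FN, TP = 0, 82, 0, 102
accuracy = (TP + TN) / (TP + TN + FP + FN)           # 102/184 ≈ 0.554
precision = TP / (TP + FP)                           # 102/184 ≈ 0.554
recall = TP / (TP + FN)                              # 1.0
f1 = 2 * precision * recall / (precision + recall)   # ≈ 0.713
print(accuracy, precision, recall, f1)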
# Get the baseline metrics
baseline_metrics = evaluate_classifier("baseline model", baseline, X_test, y_test)
preprocess.fit(X_train, y_train)
# Get the feature names
final_feature_names = get_feature_names(preprocess)
final_feature_names
=== baseline model ===
Accuracy: 0.554
Precision: 0.554
Recall: 1.000
F1: 0.713
ROC-AUC: 0.500
PR-AUC: 0.554
Confusion matrix:
[[ 0 82]
[ 0 102]]
Classification report:
precision recall f1-score support
0 0.00 0.00 0.00 82
1 0.55 1.00 0.71 102
accuracy 0.55 184
macro avg 0.28 0.50 0.36 184
weighted avg 0.31 0.55 0.40 184
['Age',
'RestingBP',
'Cholesterol',
'MaxHR',
'Oldpeak',
'FastingBS',
'Sex_F',
'Sex_M',
'ChestPainType_ASY',
'ChestPainType_ATA',
'ChestPainType_NAP',
'ChestPainType_TA',
'RestingECG_LVH',
'RestingECG_Normal',
'RestingECG_ST',
'ExerciseAngina_N',
'ExerciseAngina_Y',
'ST_Slope_Down',
'ST_Slope_Flat',
'ST_Slope_Up']
y_test_int = np.asarray(y_test).astype(int)
yp_int = np.asarray(yp).astype(int)
print('Baseline accuracy:', accuracy_score(y_test_int, yp_int))
print('Baseline F1:', f1_score(y_test_int, yp_int, zero_division=0))
print("Precision:", precision_score(y_test_int, yp_int, pos_label=1))
print("Recall:", recall_score(y_test_int, yp_int, pos_label=1))
print("F1 Score:", f1_score(y_test_int, yp_int, pos_label=1))
print("ROC-AUC:", roc_auc_score(y_test_int, yp_prob))
Baseline accuracy: 0.5543478260869565
Baseline F1: 0.7132867132867133
Precision: 0.5543478260869565
Recall: 1.0
F1 Score: 0.7132867132867133
ROC-AUC: 0.5
Logistic Regression¶
Logistic Regression is a generalized linear model used for binary classification. It models the probability of the positive class as:
$$ P(y=1|x) = \sigma(w^T x + b), \quad \text{where } \sigma(z) = \frac{1}{1 + e^{-z}}$$ Training is done via maximum likelihood estimation, minimizing log-loss. Coefficients represent log-odds, and exponentiating them gives odds ratios. Regularization options include L2 (ridge), L1 (lasso) for sparsity, and Elastic-Net for a mix.
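For example, the statsmodels fit later in this notebook estimates a coefficient of about 0.41 for Oldpeak, so a one-unit increase in Oldpeak multiplies the odds of heart disease by roughly $e^{0.41} \approx 1.51$, holding the other predictors constant.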
Key Assumptions
1. Correct Model Specification: The log-odds of the outcome are a linear combination of predictors.
2. Independence of Observations: Each observation is independent of others.
3. Limited Multicollinearity: Predictors should not be highly correlated.
4. Sufficient Events per Variable: Enough positive cases relative to the number of predictors to avoid overfitting.
5. No Perfect Separation: Predictors should not perfectly predict the outcome (can cause convergence issues).
6. Linearity in the Logit for Numeric Predictors: Continuous variables should have a linear relationship with the log-odds (often checked via transformations or splines).
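A minimal sketch for eyeballing assumption 6 on a single numeric predictor (hedged: the eight quantile bins are an illustrative choice):
# Empirical-logit check of linearity for Age: bin the predictor, then inspect log-odds per bin
tmp = df[['Age', 'HeartDisease']].copy()
tmp['HeartDisease'] = tmp['HeartDisease'].astype(int)
tmp['age_bin'] = pd.qcut(tmp['Age'], q=8, duplicates='drop')
rates = tmp.groupby('age_bin', observed=True)['HeartDisease'].mean().clip(0.01, 0.99)
print(np.log(rates / (1 - rates)))  # an approximately linear trend across bins supports assumption 6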
from sklearn.linear_model import LogisticRegression
logreg = Pipeline(steps=[('prep', preprocess),
('clf', LogisticRegression(max_iter=1000, solver='lbfgs', C=1.0, random_state=42))])
logreg.fit(X_train, y_train)
lr_metrics = evaluate_classifier("Logistic Regression", logreg)
=== Logistic Regression ===
Accuracy: 0.842
Precision: 0.861
Recall: 0.853
F1: 0.857
ROC-AUC: 0.909
PR-AUC: 0.923
Confusion matrix:
[[68 14]
[15 87]]
Classification report:
precision recall f1-score support
0 0.82 0.83 0.82 82
1 0.86 0.85 0.86 102
accuracy 0.84 184
macro avg 0.84 0.84 0.84 184
weighted avg 0.84 0.84 0.84 184
The accuracy is 84%, a clear improvement over the baseline model, and precision and recall have also improved. To interpret the logistic regression, we can exponentiate the coefficients to obtain odds ratios for each variable.
param_grid_lr = [
{# L2 with lbfgs/newton-cg/sag/saga
'clf__solver': ['lbfgs','newton-cg','sag','saga'],
'clf__penalty': ['l2'],
'clf__C': [0.01, 0.1, 1.0, 10.0],
'clf__class_weight': [None, 'balanced']
},
{# L1 with liblinear or saga
'clf__solver': ['liblinear','saga'],
'clf__penalty': ['l1'],
'clf__C': [0.01, 0.1, 1.0, 10.0],
'clf__class_weight': [None, 'balanced']
},
{# Elastic-Net with saga
'clf__solver': ['saga'],
'clf__penalty': ['elasticnet'],
'clf__l1_ratio': [0.2, 0.5, 0.8],
'clf__C': [0.01, 0.1, 1.0, 10.0],
'clf__class_weight': [None, 'balanced']
}
]
base_lr = Pipeline(steps=[
('prep', preprocess),
('clf', LogisticRegression())
])
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=29)
logreg = GridSearchCV(base_lr, param_grid=param_grid_lr, cv=cv, scoring='roc_auc', n_jobs=-1, refit=True, verbose=1)
logreg.fit(X_train, y_train)
print('Best AUC (CV):', logreg.best_score_)
print('Best params:', logreg.best_params_)
lr_metrics = evaluate_classifier("Logistic Regression (with hyper parameter tuning)", logreg.best_estimator_)
Fitting 5 folds for each of 72 candidates, totalling 360 fits
Best AUC (CV): 0.9280405051136758
Best params: {'clf__C': 0.1, 'clf__class_weight': 'balanced', 'clf__l1_ratio': 0.5, 'clf__penalty': 'elasticnet', 'clf__solver': 'saga'}
=== Logistic Regression (with hyper parameter tuning) ===
Accuracy: 0.842
Precision: 0.869
Recall: 0.843
F1: 0.856
ROC-AUC: 0.909
PR-AUC: 0.924
Confusion matrix:
[[69 13]
[16 86]]
Classification report:
precision recall f1-score support
0 0.81 0.84 0.83 82
1 0.87 0.84 0.86 102
accuracy 0.84 184
macro avg 0.84 0.84 0.84 184
weighted avg 0.84 0.84 0.84 184
df
| | Age | Sex | ChestPainType | RestingBP | Cholesterol | FastingBS | RestingECG | MaxHR | ExerciseAngina | Oldpeak | ST_Slope | HeartDisease |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 40 | M | ATA | 140.0 | 289.0 | 0 | Normal | 172 | N | 0.0 | Up | 0 |
| 1 | 49 | F | NAP | 160.0 | 180.0 | 0 | Normal | 156 | N | 1.0 | Flat | 1 |
| 2 | 37 | M | ATA | 130.0 | 283.0 | 0 | ST | 98 | N | 0.0 | Up | 0 |
| 3 | 48 | F | ASY | 138.0 | 214.0 | 0 | Normal | 108 | Y | 1.5 | Flat | 1 |
| 4 | 54 | M | NAP | 150.0 | 195.0 | 0 | Normal | 122 | N | 0.0 | Up | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 913 | 45 | M | TA | 110.0 | 264.0 | 0 | Normal | 132 | N | 1.2 | Flat | 1 |
| 914 | 68 | M | ASY | 144.0 | 193.0 | 1 | Normal | 141 | N | 3.4 | Flat | 1 |
| 915 | 57 | M | ASY | 130.0 | 131.0 | 0 | Normal | 115 | Y | 1.2 | Flat | 1 |
| 916 | 57 | F | ATA | 130.0 | 236.0 | 0 | LVH | 174 | N | 0.0 | Flat | 1 |
| 917 | 38 | M | NAP | 138.0 | 175.0 | 0 | Normal | 173 | N | 0.0 | Up | 0 |
918 rows × 12 columns
# Build formula: categorical vars wrapped in C()
formula = (
'HeartDisease ~ Age + RestingBP + Cholesterol + MaxHR + Oldpeak + FastingBS '
'+ C(Sex) + C(ChestPainType) + C(RestingECG) + C(ExerciseAngina) + C(ST_Slope)'
)
df['HeartDisease'] = df.HeartDisease.astype(int)
sm_model = smf.logit(formula=formula, data=df.dropna()).fit(disp=False)
sm_model.summary()
| | | | |
|---|---|---|---|
| Dep. Variable: | HeartDisease | No. Observations: | 746 |
| Model: | Logit | Df Residuals: | 730 |
| Method: | MLE | Df Model: | 15 |
| Date: | Tue, 28 Oct 2025 | Pseudo R-squ.: | 0.5317 |
| Time: | 14:12:34 | Log-Likelihood: | -241.79 |
| converged: | True | LL-Null: | -516.31 |
| Covariance Type: | nonrobust | LLR p-value: | 2.319e-107 |
| | coef | std err | z | P>\|z\| | [0.025 | 0.975] |
|---|---|---|---|---|---|---|
| Intercept | -5.4373 | 1.763 | -3.085 | 0.002 | -8.892 | -1.983 |
| FastingBS[T.1] | 0.2924 | 0.331 | 0.883 | 0.377 | -0.357 | 0.941 |
| C(Sex)[T.M] | 1.8655 | 0.313 | 5.952 | 0.000 | 1.251 | 2.480 |
| C(ChestPainType)[T.ATA] | -1.6732 | 0.354 | -4.721 | 0.000 | -2.368 | -0.979 |
| C(ChestPainType)[T.NAP] | -1.5730 | 0.303 | -5.192 | 0.000 | -2.167 | -0.979 |
| C(ChestPainType)[T.TA] | -1.6333 | 0.484 | -3.376 | 0.001 | -2.582 | -0.685 |
| C(RestingECG)[T.Normal] | -0.2298 | 0.284 | -0.809 | 0.419 | -0.787 | 0.327 |
| C(RestingECG)[T.ST] | -0.1746 | 0.394 | -0.443 | 0.658 | -0.947 | 0.598 |
| C(ExerciseAngina)[T.Y] | 0.9074 | 0.267 | 3.397 | 0.001 | 0.384 | 1.431 |
| C(ST_Slope)[T.Flat] | 1.3038 | 0.520 | 2.509 | 0.012 | 0.285 | 2.323 |
| C(ST_Slope)[T.Up] | -1.2100 | 0.566 | -2.140 | 0.032 | -2.318 | -0.102 |
| Age | 0.0314 | 0.015 | 2.119 | 0.034 | 0.002 | 0.060 |
| RestingBP | 0.0118 | 0.007 | 1.614 | 0.107 | -0.003 | 0.026 |
| Cholesterol | 0.0025 | 0.002 | 1.262 | 0.207 | -0.001 | 0.006 |
| MaxHR | 0.0006 | 0.006 | 0.100 | 0.920 | -0.011 | 0.012 |
| Oldpeak | 0.4108 | 0.141 | 2.921 | 0.003 | 0.135 | 0.687 |
# Coefficients -> odds ratios
coefs = logreg.best_estimator_.named_steps['clf'].coef_.ravel()
odds = np.exp(coefs)
coef_df = pd.DataFrame({'feature': final_feature_names, 'coef': coefs, 'odds_ratio': odds})
coef_df = coef_df.sort_values('odds_ratio', ascending=False)
print("Top risk‑increasing features (odds ratio > 1):")
print(coef_df.head(10).round(3))
print("Protective features (lowest odds ratios):")
print(coef_df.tail(10).round(3))
# Plain‑English explanation
print("Plain‑English explanation (Logistic Regression):")
for _, r in coef_df.head(3).iterrows():
    print(f"• Higher '{r['feature']}' increases the odds of heart disease by ~{r['odds_ratio']:.2f}× (holding others constant).")
for _, r in coef_df.tail(3).sort_values('odds_ratio').iterrows():
    print(f"• Higher '{r['feature']}' is associated with ~{r['odds_ratio']:.2f}× odds (lower than baseline), suggesting protection.")
Top risk‑increasing features (odds ratio > 1):
feature coef odds_ratio
8 ChestPainType_ASY 1.120 3.064
18 ST_Slope_Flat 1.027 2.794
5 FastingBS 0.505 1.656
7 Sex_M 0.484 1.622
4 Oldpeak 0.422 1.524
16 ExerciseAngina_Y 0.310 1.364
0 Age 0.158 1.171
12 RestingECG_LVH 0.000 1.000
17 ST_Slope_Down 0.000 1.000
14 RestingECG_ST 0.000 1.000
Protective features (lowest odds ratios):
feature coef odds_ratio
13 RestingECG_Normal 0.000 1.000
10 ChestPainType_NAP 0.000 1.000
11 ChestPainType_TA 0.000 1.000
1 RestingBP 0.000 1.000
2 Cholesterol 0.000 1.000
9 ChestPainType_ATA -0.162 0.850
3 MaxHR -0.261 0.770
15 ExerciseAngina_N -0.314 0.731
6 Sex_F -0.487 0.615
19 ST_Slope_Up -0.746 0.474
Plain‑English explanation (Logistic Regression):
• Higher 'ChestPainType_ASY' increases the odds of heart disease by ~3.06× (holding others constant).
• Higher 'ST_Slope_Flat' increases the odds of heart disease by ~2.79× (holding others constant).
• Higher 'FastingBS' increases the odds of heart disease by ~1.66× (holding others constant).
• Higher 'ST_Slope_Up' is associated with ~0.47× odds (lower than baseline), suggesting protection.
• Higher 'Sex_F' is associated with ~0.61× odds (lower than baseline), suggesting protection.
• Higher 'ExerciseAngina_N' is associated with ~0.73× odds (lower than baseline), suggesting protection.
# Plot top positive and negative effects
topN = 15
sdf = coef_df.loc[np.r_[coef_df['coef'].nlargest(topN).index,
                        coef_df['coef'].nsmallest(topN).index]].sort_values('odds_ratio')
plt.figure(figsize=(8,10))
sns.barplot(y='feature', x='coef', data=sdf, palette=['green' if v<0 else 'red' for v in sdf['coef']])
plt.axvline(0, color='k', lw=1)
plt.title('Top +/- coefficients (log-odds)')
plt.tight_layout()
plt.show()
[Bar plot of the top positive and negative coefficients (log-odds)]
Decision Tree (Shallow, Rule-Based)¶
Decision Trees are intuitive models that split data based on feature thresholds, forming a set of "if-then" rules. In clinical settings, shallow trees (limited depth) are preferred for their simplicity and interpretability.
This section trains a decision tree with depth=3 and evaluates its performance. The tree structure is visualized, and human-readable rules are extracted to support clinical decision-making.
Assumptions & notes: This model is non‑parametric, meaning that the tree structure cannot be recreated solely from optimal metrics. Explainability improves when the tree is shallow (limited max_depth, larger min_samples_leaf). Post‑training, we can prune using cost‑complexity pruning (ccp_alpha) to remove weak branches and keep the ruleset compact.
from sklearn.tree import DecisionTreeClassifier, export_text, plot_tree
tree_shallow = Pipeline(steps=[('prep', preprocess),
('clf', DecisionTreeClassifier(max_depth=3, min_samples_leaf=30, random_state=42))])
tree_shallow.fit(X_train, y_train)
dt_metrics = evaluate_classifier("Decision Tree (depth=3)", tree_shallow)
=== Decision Tree (depth=3) ===
Accuracy: 0.821
Precision: 0.871
Recall: 0.794
F1: 0.831
ROC-AUC: 0.897
PR-AUC: 0.891
Confusion matrix:
[[70 12]
[21 81]]
Classification report:
precision recall f1-score support
0 0.77 0.85 0.81 82
1 0.87 0.79 0.83 102
accuracy 0.82 184
macro avg 0.82 0.82 0.82 184
weighted avg 0.83 0.82 0.82 184
Optimising the parameters using grid search and cross validation.
param_grid_dt = [{
'clf__criterion': ['gini', 'entropy', 'log_loss'], # splitting criteria
'clf__splitter': ['best', 'random'],
# keep trees shallow/compact for rules clinicians can use
'clf__max_depth': [2, 3, 4],
'clf__min_samples_split': [5, 10],
'clf__min_samples_leaf': [5, 10],
# feature sampling can help generalization; None is simplest
'clf__max_features': [None, 'sqrt', 'log2'],
# class balance
'clf__class_weight': [None, 'balanced'],
# CART post-pruning – small positive values prune weak branches
'clf__ccp_alpha': [0.0, 1e-4, 5e-4, 1e-3]
}]
base_dt = Pipeline(steps=[('prep', preprocess),
('clf', DecisionTreeClassifier(random_state=29))])
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=29)
dt_grid = GridSearchCV(base_dt, param_grid=param_grid_dt, cv=cv,
scoring='roc_auc', n_jobs=-1, refit=True, verbose=1)
dt_grid.fit(X_train, y_train)
print('Best AUC (CV):', dt_grid.best_score_)
print('Best params:', dt_grid.best_params_)
Fitting 5 folds for each of 1728 candidates, totalling 8640 fits
Best AUC (CV): 0.9081265647932314
Best params: {'clf__ccp_alpha': 0.0, 'clf__class_weight': None, 'clf__criterion': 'gini', 'clf__max_depth': 4, 'clf__max_features': None, 'clf__min_samples_leaf': 5, 'clf__min_samples_split': 5, 'clf__splitter': 'random'}
dt_metrics = evaluate_classifier("Decision Tree (tuned)", dt_grid.best_estimator_)
print("Human‑readable rules (Decision Tree, (hyperparameter tuned):")
print(export_text(dt_grid.best_estimator_.named_steps['clf'], feature_names=final_feature_names))
# Optional visualization (will display when run in Jupyter); only the top 3 levels are drawn
plt.figure(figsize=(14, 8))
plot_tree(dt_grid.best_estimator_.named_steps['clf'], feature_names=final_feature_names,
          class_names=['No HD', 'HD'], filled=True, rounded=True, max_depth=3)
plt.title('Decision Tree (tuned, top 3 levels)'); plt.show()
print("Plain‑English explanation (Decision Tree):")
print("• The top split isolates the strongest single risk condition; subsequent splits form compact if‑then rules.")
print("• Leaves summarize risk with class proportions—ideal for quick triage guidelines.")
=== Decision Tree (tuned) ===
Accuracy: 0.821
Precision: 0.822
Recall: 0.863
F1: 0.842
ROC-AUC: 0.880
PR-AUC: 0.859
Confusion matrix:
[[63 19]
[14 88]]
Classification report:
precision recall f1-score support
0 0.82 0.77 0.79 82
1 0.82 0.86 0.84 102
accuracy 0.82 184
macro avg 0.82 0.82 0.82 184
weighted avg 0.82 0.82 0.82 184
Human‑readable rules (Decision Tree, hyperparameter tuned):
|--- ST_Slope_Up <= 0.91
| |--- ChestPainType_ASY <= 0.79
| | |--- Sex_F <= 0.91
| | | |--- ST_Slope_Flat <= 0.61
| | | | |--- class: 0
| | | |--- ST_Slope_Flat > 0.61
| | | | |--- class: 1
| | |--- Sex_F > 0.91
| | | |--- Oldpeak <= 0.61
| | | | |--- class: 0
| | | |--- Oldpeak > 0.61
| | | | |--- class: 0
| |--- ChestPainType_ASY > 0.79
| | |--- Sex_M <= 0.04
| | | |--- FastingBS <= 0.57
| | | | |--- class: 1
| | | |--- FastingBS > 0.57
| | | | |--- class: 1
| | |--- Sex_M > 0.04
| | | |--- Cholesterol <= 0.79
| | | | |--- class: 1
| | | |--- Cholesterol > 0.79
| | | | |--- class: 1
|--- ST_Slope_Up > 0.91
| |--- ChestPainType_ASY <= 0.93
| | |--- FastingBS <= 0.71
| | | |--- RestingECG_LVH <= 0.97
| | | | |--- class: 0
| | | |--- RestingECG_LVH > 0.97
| | | | |--- class: 0
| | |--- FastingBS > 0.71
| | | |--- Oldpeak <= -0.18
| | | | |--- class: 0
| | | |--- Oldpeak > -0.18
| | | | |--- class: 1
| |--- ChestPainType_ASY > 0.93
| | |--- ExerciseAngina_N <= 0.11
| | | |--- Oldpeak <= 0.17
| | | | |--- class: 1
| | | |--- Oldpeak > 0.17
| | | | |--- class: 1
| | |--- ExerciseAngina_N > 0.11
| | | |--- FastingBS <= 0.97
| | | | |--- class: 0
| | | |--- FastingBS > 0.97
| | | | |--- class: 1
[Decision tree visualization (tuned, top 3 levels)]
Plain‑English explanation (Decision Tree):
• The top split isolates the strongest single risk condition; subsequent splits form compact if‑then rules.
• Leaves summarize risk with class proportions—ideal for quick triage guidelines.
CART (Decision Tree with Cost-Complexity Pruning)¶
CART applies pruning to reduce overfitting and simplify the tree. Cost-complexity pruning removes branches that contribute little to predictive power, resulting in a more generalizable and compact model.
This section fits a full tree, derives pruning paths, and selects an optimal ccp_alpha to build a pruned tree. The final rules are extracted and explained.
Assumptions & notes: Non‑parametric. Balancing bias–variance via pruning improves generalization.
# Fit full tree on engineered features to obtain pruning path
Xtr_eng = preprocess.fit_transform(X_train)
from sklearn.tree import DecisionTreeClassifier
full_tree = DecisionTreeClassifier(random_state=42)
full_tree.fit(Xtr_eng, y_train)
path = full_tree.cost_complexity_pruning_path(Xtr_eng, y_train)
ccp_alphas = path.ccp_alphas
# Choose a moderate alpha (e.g., median of unique alphas) for simplicity
alpha = float(np.median(np.unique(ccp_alphas)))
cart = Pipeline(steps=[('prep', preprocess),
('clf', DecisionTreeClassifier(random_state=42, ccp_alpha=alpha))])
cart.fit(X_train, y_train)
cart_metrics = evaluate_classifier(f"CART (pruned, ccp_alpha={alpha:.5f})", cart)
=== CART (pruned, ccp_alpha=0.00242) ===
Accuracy: 0.837
Precision: 0.853
Recall: 0.853
F1: 0.853
ROC-AUC: 0.820
PR-AUC: 0.817
Confusion matrix:
[[67 15]
[15 87]]
Classification report:
precision recall f1-score support
0 0.82 0.82 0.82 82
1 0.85 0.85 0.85 102
accuracy 0.84 184
macro avg 0.84 0.84 0.84 184
weighted avg 0.84 0.84 0.84 184
Optimising the parameters using grid search and cross validation.
# Derive candidate alphas from the pruning path (on engineered features)
preprocess.fit(X_train, y_train)
Xtr_eng = preprocess.transform(X_train)
clf_full = DecisionTreeClassifier(random_state=29)
clf_full.fit(Xtr_eng, y_train)
path = clf_full.cost_complexity_pruning_path(Xtr_eng, y_train)
ccp_alphas = np.unique(path.ccp_alphas)
# Build a simple grid over alphas, keeping depth small
param_grid_cart = [{
'clf__ccp_alpha': list(ccp_alphas[::max(1, len(ccp_alphas)//10)]), # sample ~10 alphas
'clf__max_depth': [2, 3, 4],
'clf__min_samples_leaf': [5, 10],
'clf__criterion': ['gini', 'entropy']
}]
base_cart = Pipeline(steps=[('prep', preprocess),
('clf', DecisionTreeClassifier(random_state=29))])
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=29)
cart_grid = GridSearchCV(base_cart, param_grid=param_grid_cart, cv=cv,
scoring='roc_auc', n_jobs=-1, refit=True, verbose=1)
cart_grid.fit(X_train, y_train)
print('Best AUC (CV):', cart_grid.best_score_)
print('Best params:', cart_grid.best_params_)
Fitting 5 folds for each of 144 candidates, totalling 720 fits
Best AUC (CV): 0.897052808678825
Best params: {'clf__ccp_alpha': 0.0, 'clf__criterion': 'entropy', 'clf__max_depth': 4, 'clf__min_samples_leaf': 5}
cart_metrics = evaluate_classifier("CART (pruned, tuned)", cart_grid.best_estimator_)
print("Rules (CART, pruned):")
print(export_text(cart_grid.best_estimator_.named_steps['clf'], feature_names=final_feature_names))
print("Plain‑English explanation (CART):")
print("• Pruning removes branches that add little generalizable signal, keeping rules concise.")
print("• The final tree provides a short list of if‑then criteria suitable for checklists.")
=== CART (pruned, tuned) ===
Accuracy: 0.804
Precision: 0.780
Recall: 0.902
F1: 0.836
ROC-AUC: 0.873
PR-AUC: 0.857
Confusion matrix:
[[56 26]
[10 92]]
Classification report:
precision recall f1-score support
0 0.85 0.68 0.76 82
1 0.78 0.90 0.84 102
accuracy 0.80 184
macro avg 0.81 0.79 0.80 184
weighted avg 0.81 0.80 0.80 184
Rules (CART, pruned):
|--- ST_Slope_Up <= 0.50
| |--- ChestPainType_ASY <= 0.50
| | |--- MaxHR <= -0.00
| | | |--- Sex_M <= 0.50
| | | | |--- class: 0
| | | |--- Sex_M > 0.50
| | | | |--- class: 1
| | |--- MaxHR > -0.00
| | | |--- ST_Slope_Flat <= 0.50
| | | | |--- class: 0
| | | |--- ST_Slope_Flat > 0.50
| | | | |--- class: 1
| |--- ChestPainType_ASY > 0.50
| | |--- Sex_M <= 0.50
| | | |--- FastingBS <= 0.63
| | | | |--- class: 1
| | | |--- FastingBS > 0.63
| | | | |--- class: 1
| | |--- Sex_M > 0.50
| | | |--- Oldpeak <= -0.70
| | | | |--- class: 1
| | | |--- Oldpeak > -0.70
| | | | |--- class: 1
|--- ST_Slope_Up > 0.50
| |--- ChestPainType_ASY <= 0.50
| | |--- Oldpeak <= 1.03
| | | |--- Cholesterol <= -0.10
| | | | |--- class: 0
| | | |--- Cholesterol > -0.10
| | | | |--- class: 0
| | |--- Oldpeak > 1.03
| | | |--- class: 1
| |--- ChestPainType_ASY > 0.50
| | |--- Oldpeak <= -0.41
| | | |--- FastingBS <= 0.63
| | | | |--- class: 0
| | | |--- FastingBS > 0.63
| | | | |--- class: 1
| | |--- Oldpeak > -0.41
| | | |--- ExerciseAngina_Y <= 0.50
| | | | |--- class: 1
| | | |--- ExerciseAngina_Y > 0.50
| | | | |--- class: 1
Plain‑English explanation (CART):
• Pruning removes branches that add little generalizable signal, keeping rules concise.
• The final tree provides a short list of if‑then criteria suitable for checklists.
Naive Bayes (Gaussian)¶
Naive Bayes is a probabilistic classifier based on Bayes' theorem, assuming feature independence. It models continuous features using Gaussian distributions.
This section trains and tunes a Gaussian Naive Bayes model, evaluates its performance, and analyzes feature-wise class means to understand which features contribute most to prediction.
Assumptions & notes: Independence across features given the class. Stability is often improved by tuning var_smoothing (which adds a small value to variances). While it has While it has fewer tunable parameters than most models, it is great for fast, interpretable baselines.
from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import FunctionTransformer
# Make engineered features dense for GaussianNB
to_dense = FunctionTransformer(lambda X: X.toarray() if hasattr(X, 'toarray') else X)
nb = Pipeline(steps=[('prep', preprocess),
('dense', to_dense),
('clf', GaussianNB())])
nb.fit(X_train, y_train)
nb_metrics = evaluate_classifier("Naive Bayes (Gaussian)", nb)
=== Naive Bayes (Gaussian) ===
Accuracy: 0.837
Precision: 0.867
Recall: 0.833
F1: 0.850
ROC-AUC: 0.909
PR-AUC: 0.923
Confusion matrix:
[[69 13]
[17 85]]
Classification report:
precision recall f1-score support
0 0.80 0.84 0.82 82
1 0.87 0.83 0.85 102
accuracy 0.84 184
macro avg 0.83 0.84 0.84 184
weighted avg 0.84 0.84 0.84 184
Optimising the parameters using grid search and cross validation.
param_grid_nb = [{
# var_smoothing adds a fraction of the largest variance for numerical stability
'clf__var_smoothing': np.logspace(-12, -7, 10),
'clf__priors': [None] # keep data-driven class priors
}]
base_nb = Pipeline(steps=[('prep', preprocess),
('dense', to_dense),
('clf', GaussianNB())])
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=29)
nb_grid = GridSearchCV(base_nb, param_grid=param_grid_nb, cv=cv,
scoring='roc_auc', n_jobs=-1, refit=True, verbose=1)
nb_grid.fit(X_train, y_train)
print('Best AUC (CV):', nb_grid.best_score_)
print('Best params:', nb_grid.best_params_)
Fitting 5 folds for each of 10 candidates, totalling 50 fits
Best AUC (CV): 0.9183289499820665
Best params: {'clf__priors': None, 'clf__var_smoothing': 1e-12}
nb_metrics = evaluate_classifier("Naive Bayes (Gaussian, tuned)", nb_grid.best_estimator_)
# Class means per engineered feature
gnb = nb_grid.best_estimator_.named_steps['clf']
Xtr_engineered = nb_grid.best_estimator_.named_steps['dense'].\
transform(nb_grid.best_estimator_.named_steps['prep'].transform(X_train))
means_0 = gnb.theta_[0]; means_1 = gnb.theta_[1]
nb_df = pd.DataFrame({'feature': final_feature_names,
'mean_no_disease': means_0,
'mean_disease': means_1,
'diff(disease - no_disease)': means_1 - means_0})
nb_df = nb_df.sort_values('diff(disease - no_disease)', ascending=False)
print("Plain‑English explanation (Naive Bayes):")
print("• A feature value more typical under the disease class shifts the posterior toward heart disease; the opposite shifts toward no disease.")
print("Top features with largest class‑mean differences (NB view):")
nb_df.head(10).round(3)
=== Naive Bayes (Gaussian, tuned) ===
Accuracy: 0.837
Precision: 0.867
Recall: 0.833
F1: 0.850
ROC-AUC: 0.909
PR-AUC: 0.923
Confusion matrix:
[[69 13]
[17 85]]
Classification report:
precision recall f1-score support
0 0.80 0.84 0.82 82
1 0.87 0.83 0.85 102
accuracy 0.84 184
macro avg 0.83 0.84 0.84 184
weighted avg 0.84 0.84 0.84 184
Plain‑English explanation (Naive Bayes):
• A feature value more typical under the disease class shifts the posterior toward heart disease; the opposite shifts toward no disease.
Top features with largest class‑mean differences (NB view):
| | feature | mean_no_disease | mean_disease | diff(disease - no_disease) |
|---|---|---|---|---|
| 4 | Oldpeak | -0.465 | 0.376 | 0.841 |
| 0 | Age | -0.334 | 0.270 | 0.604 |
| 5 | FastingBS | -0.320 | 0.259 | 0.579 |
| 18 | ST_Slope_Flat | 0.186 | 0.759 | 0.573 |
| 8 | ChestPainType_ASY | 0.253 | 0.773 | 0.520 |
| 16 | ExerciseAngina_Y | 0.143 | 0.618 | 0.475 |
| 1 | RestingBP | -0.132 | 0.107 | 0.239 |
| 7 | Sex_M | 0.655 | 0.894 | 0.239 |
| 2 | Cholesterol | -0.079 | 0.064 | 0.144 |
| 14 | RestingECG_ST | 0.149 | 0.227 | 0.077 |
k-Nearest Neighbors (k-NN)¶
k-NN is a non-parametric, instance-based learning algorithm. It predicts the class of a sample based on the majority vote of its k nearest neighbors.
This section trains and tunes a k-NN model, evaluates its performance, and inspects the neighbors of a sample to explain the prediction in clinical terms.
Assumptions & notes: KNN is sensitive to scaling and the distance metric. Key hyperparameters include n_neighbors, weights (uniform vs distance), and metric (Minkowski with p=½).
from sklearn.neighbors import KNeighborsClassifier
knn = Pipeline(steps=[('prep', preprocess),
('clf', KNeighborsClassifier(n_neighbors=5, weights='distance'))])
knn.fit(X_train, y_train)
knn_metrics = evaluate_classifier("k-NN (k=5, distance weights)", knn)
=== k-NN (k=5, distance weights) ===
Accuracy: 0.859
Precision: 0.873
Recall: 0.873
F1: 0.873
ROC-AUC: 0.890
PR-AUC: 0.886
Confusion matrix:
[[69 13]
[13 89]]
Classification report:
precision recall f1-score support
0 0.84 0.84 0.84 82
1 0.87 0.87 0.87 102
accuracy 0.86 184
macro avg 0.86 0.86 0.86 184
weighted avg 0.86 0.86 0.86 184
param_grid_knn = [{
'clf__n_neighbors': [3, 5, 7, 9, 11],
'clf__weights': ['uniform', 'distance'],
# Minkowski metric with p=1 (Manhattan) or p=2 (Euclidean)
'clf__metric': ['minkowski'],
'clf__p': [1, 2]
}]
base_knn = Pipeline(steps=[('prep', preprocess),
('clf', KNeighborsClassifier())])
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=29)
knn_grid = GridSearchCV(base_knn, param_grid=param_grid_knn, cv=cv,
scoring='roc_auc', n_jobs=-1, refit=True, verbose=1)
knn_grid.fit(X_train, y_train)
print('Best AUC (CV):', knn_grid.best_score_)
print('Best params:', knn_grid.best_params_)
Fitting 5 folds for each of 20 candidates, totalling 100 fits
Best AUC (CV): 0.9204453821797994
Best params: {'clf__metric': 'minkowski', 'clf__n_neighbors': 11, 'clf__p': 1, 'clf__weights': 'distance'}
Visualising the tuned model's predictions on an Age vs MaxHR scatter plot.
df_knn = df.copy()
df_knn['class'] = knn_grid.predict(df)
# df_knn.plot.scatter('RestingBP', 'Cholesterol', 'class')
plt.figure(figsize=(8, 6))
sns.scatterplot(data=df_knn, x='Age', y='MaxHR', hue='class', s=50, alpha=0.75)
plt.title('Scatter plot showing the incidence of heart disease')
plt.grid(True)
plt.show()
[Scatter plot of Age vs MaxHR colored by predicted class]
knn_metrics = evaluate_classifier("k-NN (tuned)", knn_grid.best_estimator_)
# Inspect neighbors for one patient
idx = 0
x_query = X_test.iloc[[idx]]
y_true = y_test.iloc[idx]
proba = knn_grid.predict_proba(x_query)[:,1][0]
pred = knn_grid.predict(x_query)[0]
print("Plain‑English explanation (k‑NN):")
print("• The prediction is driven by the outcomes of the five most similar patients; closer neighbors carry higher weight.")
print(f"Query patient index #{x_query.index[0]} | True={y_true}, Pred={pred}, Prob(HD)={proba:.3f}")
print("Patient details:")
print(x_query)
# Show actual neighbors from engineered space
Xtr_eng2 = preprocess.transform(X_train)
Xte_eng2 = preprocess.transform(x_query)
base_knn = KNeighborsClassifier(n_neighbors=5, weights='distance').fit(Xtr_eng2, y_train)
dist, ind = base_knn.kneighbors(Xte_eng2, n_neighbors=5, return_distance=True)
neighbors = pd.DataFrame({'neighbor_train_idx': y_train.index.values[ind[0]],
'distance': dist[0],
'label': y_train.iloc[ind[0]].values}).sort_values('distance')
print("Nearest neighbors (closer = more influence):")
# Join neighbor rows back to their original records via the training index
neighbors.merge(df, left_on='neighbor_train_idx', right_index=True)
=== k-NN (tuned) ===
Accuracy: 0.875
Precision: 0.876
Recall: 0.902
F1: 0.889
ROC-AUC: 0.906
PR-AUC: 0.905
Confusion matrix:
[[69 13]
[10 92]]
Classification report:
precision recall f1-score support
0 0.87 0.84 0.86 82
1 0.88 0.90 0.89 102
accuracy 0.88 184
macro avg 0.87 0.87 0.87 184
weighted avg 0.87 0.88 0.87 184
Plain‑English explanation (k‑NN):
• The prediction is driven by the outcomes of the k=11 most similar patients; closer neighbors carry higher weight (we inspect the 5 closest below).
Query patient index #95 | True=1, Pred=1, Prob(HD)=0.868
Patient details:
Age RestingBP Cholesterol MaxHR Oldpeak Sex ChestPainType RestingECG \
95 58 130.0 263.0 140 2.0 M ASY Normal
ExerciseAngina ST_Slope FastingBS
95 Y Flat 0
Nearest neighbors (closer = more influence):
| | neighbor_train_idx | distance | label | Age | Sex | ChestPainType | RestingBP | Cholesterol | FastingBS | RestingECG | MaxHR | ExerciseAngina | Oldpeak | ST_Slope | HeartDisease |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 562 | 0.806738 | 0 | 40 | M | ATA | 140.0 | 289.0 | 0 | Normal | 172 | N | 0.0 | Up | 0 |
| 1 | 74 | 0.926989 | 1 | 49 | F | NAP | 160.0 | 180.0 | 0 | Normal | 156 | N | 1.0 | Flat | 1 |
| 2 | 586 | 1.016484 | 1 | 37 | M | ATA | 130.0 | 283.0 | 0 | ST | 98 | N | 0.0 | Up | 0 |
| 3 | 501 | 1.032806 | 1 | 48 | F | ASY | 138.0 | 214.0 | 0 | Normal | 108 | Y | 1.5 | Flat | 1 |
| 4 | 368 | 1.053968 | 1 | 54 | M | NAP | 150.0 | 195.0 | 0 | Normal | 122 | N | 0.0 | Up | 0 |
summary = pd.DataFrame([lr_metrics, dt_metrics, cart_metrics, nb_metrics, knn_metrics]).set_index('model').sort_values('roc_auc', ascending=False)
summary.sort_values('accuracy', ascending=False).round(3)
| model | accuracy | precision | recall | f1 | roc_auc | pr_auc |
|---|---|---|---|---|---|---|
| k-NN (tuned) | 0.875 | 0.876 | 0.902 | 0.889 | 0.906 | 0.905 |
| Logistic Regression (with hyper parameter tuning) | 0.842 | 0.869 | 0.843 | 0.856 | 0.909 | 0.924 |
| Naive Bayes (Gaussian, tuned) | 0.837 | 0.867 | 0.833 | 0.850 | 0.909 | 0.923 |
| CART (pruned, ccp_alpha=0.00242) | 0.837 | 0.853 | 0.853 | 0.853 | 0.820 | 0.817 |
| Decision Tree (tuned) | 0.821 | 0.822 | 0.863 | 0.842 | 0.880 | 0.859 |
Diagnostics of the Best Model¶
To ensure robustness and reliability, we perform diagnostics on the best-performing model (k-NN). This includes checking multicollinearity using VIF and plotting learning curves to assess bias-variance tradeoff.
These diagnostics help validate model assumptions and guide further improvements.
best_model = knn_grid.best_estimator_
from statsmodels.stats.outliers_influence import variance_inflation_factor
X_num = df[num_cols].dropna()
X_num_const = sm.add_constant(X_num)
vif = pd.DataFrame({
'feature': ['const'] + list(X_num.columns),
'VIF': [variance_inflation_factor(X_num_const.values, i) for i in range(X_num_const.shape[1])]
})
vif
| | feature | VIF |
|---|---|---|
| 0 | const | 160.072777 |
| 1 | Age | 1.277888 |
| 2 | RestingBP | 1.098974 |
| 3 | Cholesterol | 1.011748 |
| 4 | MaxHR | 1.205888 |
| 5 | Oldpeak | 1.142342 |
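All predictor VIFs are close to 1, far below the common rule-of-thumb cutoffs of 5–10, so multicollinearity among the numeric features is not a concern here; the large VIF on the constant term is an artifact of including the intercept and can be ignored.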
ROC-AUC Curve¶
The ROC (Receiver Operating Characteristic) curve is a graphical representation of a classifier's performance across different threshold values. It plots the True Positive Rate (Sensitivity) against the False Positive Rate (1 - Specificity).
AUC (Area Under the Curve) quantifies the overall ability of the model to discriminate between positive and negative classes. A higher AUC indicates better performance.
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
# Predict probabilities
y_probs = best_model.predict_proba(X_test)[:, 1]
# Compute ROC curve and AUC
fpr, tpr, thresholds = roc_curve(y_test.astype(int), y_probs)
roc_auc = auc(fpr, tpr)
# Plot ROC curve
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc='lower right')
plt.show();
[ROC curve (AUC ≈ 0.91)]
An AUC of ≈0.91 indicates good discriminative performance.
Sensitivity-Specificity Curve¶
This curve shows how sensitivity and specificity vary with different classification thresholds. It helps in selecting a threshold that balances both metrics.
# Compute sensitivity and specificity
sensitivity = tpr
specificity = 1 - fpr
# Plot Sensitivity-Specificity curve (drop the first threshold, which roc_curve sets above the max score)
plt.figure()
plt.plot(thresholds[1:], sensitivity[1:], label='Sensitivity')
plt.plot(thresholds[1:], specificity[1:], label='Specificity')
plt.xlabel('Threshold')
plt.title('Sensitivity vs Specificity')
plt.legend()
plt.show();
[Sensitivity and specificity vs classification threshold]
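One common rule for choosing an operating point is Youden's J statistic (sensitivity + specificity − 1); a minimal sketch using the arrays computed above:
# Threshold maximizing Youden's J = sensitivity + specificity - 1 (equivalently tpr - fpr)
j_scores = sensitivity + specificity - 1
best_idx = np.argmax(j_scores)
print(f"Best threshold by Youden's J: {thresholds[best_idx]:.3f} "
      f"(sensitivity={sensitivity[best_idx]:.3f}, specificity={specificity[best_idx]:.3f})")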
Learning curves¶
Learning curves are a valuable diagnostic tool to understand how a model's performance evolves with increasing training data. They help assess whether a model suffers from high bias (underfitting) or high variance (overfitting), and whether adding more data could improve performance.
# Learning curve for best estimator
train_sizes, train_scores, val_scores = learning_curve(best_model, X_train, y_train, cv=5, scoring='roc_auc', n_jobs=-1,
train_sizes=np.linspace(0.1, 1.0, 5), random_state=29)
plt.figure(figsize=(7,5))
plt.plot(train_sizes, train_scores.mean(axis=1), 'o-', label='Train AUC')
plt.plot(train_sizes, val_scores.mean(axis=1), 'o-', label='CV AUC')
plt.xlabel('Training examples')
plt.ylim(0, 1.01)
plt.ylabel('ROC-AUC')
plt.title('Learning Curve')
plt.legend()
plt.show()
[Learning curve: train vs cross-validated ROC-AUC]
Ethics and robustness¶
- Medical use requires validation: consult clinicians; understand labeling & sampling biases.
- Test fairness across groups (e.g., Sex) and ensure no proxy discrimination.
- Perform temporal/center splits if applicable; assess drift.
- Consider decision support (calibrated probabilities + cost‑aware thresholds) rather than hard automation.
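As a minimal sketch of that last point (hedged: CalibratedClassifierCV was imported earlier but never used; the isotonic method and 5-fold CV here are illustrative choices):
# Wrap the tuned k-NN pipeline in a probability calibrator and compare reliability curves
calibrated = CalibratedClassifierCV(knn_grid.best_estimator_, method='isotonic', cv=5)
calibrated.fit(X_train, y_train)
fig, ax = plt.subplots(figsize=(6, 5))
CalibrationDisplay.from_estimator(best_model, X_test, y_test, n_bins=10, name='k-NN (tuned)', ax=ax)
CalibrationDisplay.from_estimator(calibrated, X_test, y_test, n_bins=10, name='k-NN (calibrated)', ax=ax)
plt.title('Reliability curves before and after calibration')
plt.show()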