
Explainable models: Heart Failure Analysis

This notebook performs a classification analysis on the Heart Failure Prediction dataset using only interpretable models, so that clinicians can apply the results directly without relying on black-box models. The goal is to predict whether a patient has heart disease from clinical and demographic features such as age, blood pressure, cholesterol level, and exercise-related indicators.

Dataset

The dataset is designed to predict the presence of heart disease using a mix of numeric and categorical attributes collected from patients. It combines data from five well-known heart disease studies into a single, clean dataset with 918 observations. Predicting heart disease risk is critical for early intervention, improving patient outcomes, and supporting healthcare decision-making.

Use Case Description

Our objective is to classify patients as having heart disease or not using logistic regression techniques. This involves:
1. Exploring relationships between patient attributes and heart disease occurrence
2. Identifying statistically significant predictors
3. Validating model assumptions and performance
4. Diagnosing issues such as multicollinearity and calibration
5. Evaluating model robustness using cross-validation, ROC curves, and threshold tuning

Explainable Models

For doctors in a hospital setting, simple and highly explainable models are essential because they can be translated into quick, actionable rules for patient screening. These models allow clinicians to make fast decisions without relying on complex computations or black-box algorithms. Some commonly used, interpretable models include:
1. Logistic Regression: Provides clear insights through coefficients and odds ratios, making it easy to understand how each factor influences heart disease risk.
2. Decision Trees: Offer intuitive, rule-based structures that can be visualized and converted into simple “if-then” guidelines for clinical use.
3. Rule-Based Models (e.g., rules derived from CART or CHAID trees): Turn decision trees into sets of human-readable rules for rapid triage.
4. Naïve Bayes: Based on conditional probabilities; easy to interpret and explain as likelihoods.
5. Simple k-Nearest Neighbors (k-NN): While not inherently rule-based, its logic ("You look most like these N past patients who had or didn't have heart disease.") is easy to explain in clinical terms.

Dataset card: https://www.kaggle.com/datasets/fedesoriano/heart-failure-prediction
Fallback CSV mirror: https://raw.githubusercontent.com/benbobyabraham/heart_failure_prediction_dataset_kaggle/main/heart.csv

Importing packages

import os, sys, warnings, json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV, cross_val_score, learning_curve, validation_curve
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve, auc,
                             confusion_matrix, classification_report, precision_recall_curve, average_precision_score,
                             brier_score_loss, RocCurveDisplay, PrecisionRecallDisplay)
from sklearn.linear_model import LogisticRegression
from sklearn.dummy import DummyClassifier
from sklearn.calibration import CalibratedClassifierCV, CalibrationDisplay
from sklearn.tree import export_text
import joblib

import statsmodels.api as sm
import statsmodels.formula.api as smf

warnings.filterwarnings('ignore')
np.random.seed(42)
sns.set(style='whitegrid', context='notebook')

We try to read heart.csv from the Data folder. If not present, we fetch a public read‑only mirror of the Kaggle CSV from GitHub.

from pathlib import Path
import io, urllib.request

DATA_LOCAL = Path('../Data/heart.csv')
DATA_URL = 'https://raw.githubusercontent.com/benbobyabraham/heart_failure_prediction_dataset_kaggle/main/heart.csv'

if DATA_LOCAL.exists():
    df = pd.read_csv(DATA_LOCAL)
else:
    try:
        with urllib.request.urlopen(DATA_URL) as resp:
            df = pd.read_csv(io.BytesIO(resp.read()))
    except Exception as e:
        # Chain the original error so the underlying cause stays visible in the traceback
        raise FileNotFoundError('Could not find heart.csv locally or fetch it from the fallback URL. Please download it from Kaggle and save it as ../Data/heart.csv.') from e

print(df.shape)
df.head()
(918, 12)
Age Sex ChestPainType RestingBP Cholesterol FastingBS RestingECG MaxHR ExerciseAngina Oldpeak ST_Slope HeartDisease
0 40 M ATA 140 289 0 Normal 172 N 0.0 Up 0
1 49 F NAP 160 180 0 Normal 156 N 1.0 Flat 1
2 37 M ATA 130 283 0 ST 98 N 0.0 Up 0
3 48 F ASY 138 214 0 Normal 108 Y 1.5 Flat 1
4 54 M NAP 150 195 0 Normal 122 N 0.0 Up 0

EDA (Exploratory Data Analysis)

Data attributes and description

Column Name Description
Age Age of the patient (years)
Sex Gender of the patient (M/F)
ChestPainType Type of chest pain (ATA, NAP, ASY, TA)
RestingBP Resting blood pressure (mm Hg)
Cholesterol Serum cholesterol (mg/dl)
FastingBS Fasting blood sugar (>120 mg/dl: 1, else: 0)
RestingECG Resting electrocardiogram results
MaxHR Maximum heart rate achieved
ExerciseAngina Exercise-induced angina (Y/N)
Oldpeak ST depression induced by exercise relative to rest
ST_Slope Slope of the peak exercise ST segment
HeartDisease Target variable (1 = disease present, 0 = no disease)

Feature types: Mix of numeric (Age, RestingBP, Cholesterol, MaxHR, Oldpeak) and categorical (Sex, ChestPainType, RestingECG, ExerciseAngina, ST_Slope), plus a binary variable FastingBS.
Unit notes: e.g., RestingBP in mmHg, Cholesterol in mg/dl; some sources note 0 values may indicate missing/not measured.
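
As a minimal sketch of how the zero-as-missing convention could be checked, the snippet below counts exact zeros per column. The toy rows and the helper name `count_zero_placeholders` are illustrative assumptions and are not part of the dataset or the notebook; in practice the notebook's own `df` would be passed in.

```python
import pandas as pd

def count_zero_placeholders(frame, columns):
    """Return the number of exact-zero entries in each of the given columns."""
    return {col: int((frame[col] == 0).sum()) for col in columns}

# Hypothetical toy rows; the real notebook would pass `df` instead.
toy = pd.DataFrame({
    "RestingBP":   [140, 0, 130, 138],
    "Cholesterol": [289, 0, 0, 214],
})

zero_counts = count_zero_placeholders(toy, ["RestingBP", "Cholesterol"])
print(zero_counts)
```

A non-zero count for a measurement that cannot physiologically be zero is a hint that the value encodes "not measured".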

df['FastingBS'] = df.FastingBS.astype(str)
df['HeartDisease'] = df.HeartDisease.astype(str)
df.describe()
Age RestingBP Cholesterol MaxHR Oldpeak
count 918.000000 918.000000 918.000000 918.000000 918.000000
mean 53.510893 132.396514 198.799564 136.809368 0.887364
std 9.432617 18.514154 109.384145 25.460334 1.066570
min 28.000000 0.000000 0.000000 60.000000 -2.600000
25% 47.000000 120.000000 173.250000 120.000000 0.000000
50% 54.000000 130.000000 223.000000 138.000000 0.600000
75% 60.000000 140.000000 267.000000 156.000000 1.500000
max 77.000000 200.000000 603.000000 202.000000 6.200000
from ml_plottings import univariate_analysis, classification_bivariate_analysis, plot_correlation_matrix
univariate_analysis(df)
[Figures: univariate distribution plots for Sex, ChestPainType, FastingBS, RestingECG, ExerciseAngina, ST_Slope, HeartDisease, Age, RestingBP, Cholesterol, MaxHR, and Oldpeak]
classification_bivariate_analysis(df, 'HeartDisease')

[Figures: bivariate plots of each feature against HeartDisease]

Defining the categorical and continuous variables

num_cols = ['Age', 'RestingBP', 'Cholesterol', 'MaxHR', 'Oldpeak']
cat_cols = ['Sex', 'ChestPainType', 'RestingECG', 'ExerciseAngina', 'ST_Slope']
binary_cols = ['FastingBS']
plot_correlation_matrix(df, num_cols+binary_cols)

[Figure: correlation matrix of the numeric and binary features]

Key Insights from Exploratory Data Analysis
1. Univariate Analysis
- Target Distribution: About 55% of patients have heart disease, while 45% do not.
- Demographics:
  - Approximately 80% of patients are male.
  - Age ranges from 28 to 77 years and is roughly normally distributed.
- Categorical Features:
  - ChestPainType: Over 50% report ASY (asymptomatic) chest pain.
  - RestingECG: More than 60% show normal ECG results.
  - FastingBS: Majority have normal fasting blood sugar.
  - ST_Slope: Most patients have a flat ST slope.
- Numeric Features:
  - RestingBP: Mostly between 120 and 140 mmHg (the interquartile range), slightly right-skewed; a value of 0 is not a plausible reading and is treated as missing.
  - Cholesterol: About 172 values are recorded as 0 (≈19% of the 918 rows), likely unrecorded measurements.
  - MaxHR: Roughly normal distribution, concentrated between 120–155 bpm.
  - Oldpeak: Around 375 values are zero, a few are below zero, and the rest are positive (median 0.6).
2. Bivariate Analysis
- Age vs HeartDisease: Probability of heart disease increases with age.
- MaxHR vs HeartDisease: Patients with heart disease tend to have lower maximum heart rate.
- Sex vs HeartDisease: Males have a higher chance of heart disease.
- ChestPainType vs HeartDisease: Patients with ASY chest pain have a higher risk.
- FastingBS vs HeartDisease: Elevated fasting blood sugar is associated with higher risk.
- ExerciseAngina vs HeartDisease: Presence of exercise-induced angina is associated with markedly higher risk.
- ST_Slope vs HeartDisease: Patients with ST slope = Up have a lower chance of heart disease.

Feature Engineering and Cleaning

Zero values in Cholesterol and RestingBP are not physiologically plausible and indicate missing measurements, so we replace them with NaN.

df.loc[df.Cholesterol==0, 'Cholesterol'] = np.nan
df.loc[df.RestingBP==0, 'RestingBP'] = np.nan

Preprocessing & Pipelines

Before training a machine learning model, it’s essential to address missing data and ensure features are on comparable scales. These steps improve model stability and performance.
1. Handling Missing Values
- Mean/Median Imputation: Replace missing numeric values with the mean or median of the column. Median is often preferred for skewed distributions.
- Mode Imputation: For categorical variables, fill missing values with the most frequent category.
- Advanced Techniques: Use algorithms like KNN imputation or model-based imputation for more accurate estimates when data is not missing completely at random.
2. Feature Scaling
- Standardization (Z-score): Transforms features to have zero mean and unit variance. Common for models like logistic regression and SVM.
- Min-Max Scaling: Rescales features to a fixed range (usually [0, 1]). Useful for algorithms sensitive to absolute values, such as neural networks.
- Robust Scaling: Uses median and interquartile range, making it less sensitive to outliers.
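
The three scaling options above can be compared side by side. This is a minimal sketch on a hypothetical single feature with one extreme value (603 mimics the Cholesterol maximum); the array and variable names are illustrative assumptions, not notebook data.

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler, RobustScaler, StandardScaler

# Hypothetical single feature with one extreme value.
x = np.array([[180.0], [200.0], [220.0], [240.0], [603.0]])

z = StandardScaler().fit_transform(x)   # zero mean, unit variance
mm = MinMaxScaler().fit_transform(x)    # rescaled to [0, 1]
rb = RobustScaler().fit_transform(x)    # (x - median) / interquartile range

for name, arr in [("standard", z), ("min-max", mm), ("robust", rb)]:
    print(name, np.round(arr.ravel(), 2))
```

With min-max scaling the outlier compresses the other four values into the lower part of the range, while robust scaling keeps them evenly spaced around the median.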

A Pipeline in scikit-learn allows us to combine these steps into a single, reproducible workflow. It also makes the workflow modular and reusable for training, tuning, and deployment.

Steps in the Pipeline:
1. Imputation
- Numeric Features: Replace missing values with the median (robust to outliers).
- Categorical Features: Fill missing values with the most frequent category.
2. Scaling
- Apply Standardization (Z-score) to numeric features so they have zero mean and unit variance.
- This matters for regularized logistic regression, because the penalty treats all coefficients on the same scale.
3. Encoding
- Use One-Hot Encoding for categorical variables to convert them into numeric form.
- Set handle_unknown='ignore' to safely handle unseen categories during inference.

feature_cols = num_cols + cat_cols + binary_cols
X = df[feature_cols]
y = df['HeartDisease']

numeric_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler())
])

categorical_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='most_frequent')),
    ('onehot', OneHotEncoder(handle_unknown='ignore'))
])

preprocess = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, num_cols+binary_cols),
        ('cat', categorical_transformer, cat_cols)
    ]
)

Train/Test split

We divide our data into two parts: training data (80%) and testing data (20%). The training set is used to fit the model, while the test set evaluates how well the model generalizes to unseen data. This helps detect overfitting and gives us a realistic estimate of model performance. Validation for hyperparameter tuning is handled later by cross-validation on the training set. We use stratified sampling to preserve the class ratio.

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=29
)
X_train.shape, X_test.shape
((734, 11), (184, 11))
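
To illustrate what `stratify` does, here is a minimal sketch on hypothetical labels with the same 55/45 imbalance reported in the EDA. The `*_demo` names and the synthetic arrays are assumptions for illustration only.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical labels: 55 positives and 45 negatives.
y_demo = np.array([1] * 55 + [0] * 45)
X_demo = np.arange(len(y_demo)).reshape(-1, 1)

_, _, y_tr_demo, y_te_demo = train_test_split(
    X_demo, y_demo, test_size=0.2, stratify=y_demo, random_state=0
)

# Both splits should carry (almost exactly) the original positive rate.
print("train positive rate:", y_tr_demo.mean())
print("test positive rate:", y_te_demo.mean())
```

The same check on the notebook's own split would compare the class proportions of `y_train` and `y_test`.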
# Helper to get post‑preprocessing feature names
def get_feature_names(ct):
    """
    Return feature names from a fitted ColumnTransformer `ct`.
    Works when the inner transformers expose `get_feature_names_out`.
    """
    output_features = []
    for name, trans, cols in ct.transformers_:
        if name == 'remainder' and trans == 'drop':
            continue

        # Pipeline with a final step having get_feature_names_out (e.g., OneHotEncoder)
        if hasattr(trans, 'named_steps') and hasattr(trans.named_steps.get('onehot', None), 'get_feature_names_out'):
            ohe = trans.named_steps['onehot']
            feats = list(ohe.get_feature_names_out(cols))
            output_features.extend(feats)
        # Direct transformer exposing get_feature_names_out (rare for numeric here)
        elif hasattr(trans, 'get_feature_names_out'):
            feats = list(trans.get_feature_names_out(cols))
            output_features.extend(feats)
        else:
            # Fall back to original column names
            output_features.extend(list(cols))

    return output_features


from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,
                             roc_auc_score, average_precision_score,
                             confusion_matrix, classification_report)

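# Helper function to evaluate a fitted classifier on a held-out set
def evaluate_classifier(name, model, X_te=X_test, y_te=y_test, prob_pos=None):
    y_pred = model.predict(X_te)
    y_pred = np.asarray(y_pred).astype(int)
    # Use the y_te argument (the original read the global y_test and ignored it)
    y_te = np.asarray(y_te).astype(int)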
    if prob_pos is None:
        if hasattr(model, 'predict_proba'):
            prob_pos = model.predict_proba(X_te)[:,1]
        elif hasattr(model, 'decision_function'):
            scores = model.decision_function(X_te)
            smin, smax = scores.min(), scores.max()
            prob_pos = (scores - smin) / (smax - smin + 1e-12)
        else:
            prob_pos = np.zeros_like(y_te, dtype=float)
    acc  = accuracy_score(y_te, y_pred)
    prec = precision_score(y_te, y_pred, zero_division=0)
    rec  = recall_score(y_te, y_pred, zero_division=0)
    f1   = f1_score(y_te, y_pred, zero_division=0)
    try:
        roc  = roc_auc_score(y_te, prob_pos)
    except Exception:
        roc = np.nan
    pr_auc = average_precision_score(y_te, prob_pos)
    print(f"=== {name} ===")
    print(f"Accuracy: {acc:.3f}\nPrecision: {prec:.3f}\nRecall: {rec:.3f}\nF1: {f1:.3f}")
    print(f"ROC-AUC: {roc:.3f}\nPR-AUC: {pr_auc:.3f}")
    print("Confusion matrix:\n", confusion_matrix(y_te, y_pred))
    print("Classification report:\n", classification_report(y_te, y_pred, zero_division=0))
    return dict(model=name, accuracy=acc, precision=prec, recall=rec, f1=f1, roc_auc=roc, pr_auc=pr_auc)

Baseline model

A baseline model is a simple reference point used to evaluate whether a more complex model adds real predictive value. For classification tasks, this often means using a Dummy Classifier that predicts the majority class or random guesses based on class distribution. It doesn’t use any features, but sets a minimum performance benchmark, helping us understand if our logistic regression model truly improves upon naive predictions.

baseline = Pipeline(steps=[('prep', preprocess), ('mdl', DummyClassifier(strategy='most_frequent'))])
baseline.fit(X_train, y_train)
yp = baseline.predict(X_test)
yp_prob = baseline.predict_proba(X_test)[:,1]

Key Classification Metrics

  1. Accuracy
    • Proportion of correct predictions:
    • \[\text{Accuracy} = \frac{\text{TP + TN}}{\text{Total Samples}}\]
    • Simple but can be misleading if classes are imbalanced.
  2. Precision
    • Of all predicted positives, how many are actually positive:
    • \[\text{Precision} = \frac{\text{TP}}{\text{TP + FP}}\]
  3. Recall (Sensitivity)
    • Of all actual positives, how many did we correctly predict:
    • \[\text{Recall} = \frac{\text{TP}}{\text{TP + FN}}\]
  4. F1 Score
    • Harmonic mean of precision and recall:
    • \[ F1 = 2 \times \frac{\text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}}\]
  5. ROC-AUC
    • Measures ranking quality across thresholds; higher is better.
# Evaluate the baseline on the test set
baseline_metrics = evaluate_classifier("baseline model", baseline, X_test, y_test)
preprocess.fit(X_train, y_train)

# Get the feature names
final_feature_names = get_feature_names(preprocess)
final_feature_names
=== baseline model ===
Accuracy: 0.554
Precision: 0.554
Recall: 1.000
F1: 0.713
ROC-AUC: 0.500
PR-AUC: 0.554
Confusion matrix:
 [[  0  82]
 [  0 102]]
Classification report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        82
           1       0.55      1.00      0.71       102

    accuracy                           0.55       184
   macro avg       0.28      0.50      0.36       184
weighted avg       0.31      0.55      0.40       184






['Age',
 'RestingBP',
 'Cholesterol',
 'MaxHR',
 'Oldpeak',
 'FastingBS',
 'Sex_F',
 'Sex_M',
 'ChestPainType_ASY',
 'ChestPainType_ATA',
 'ChestPainType_NAP',
 'ChestPainType_TA',
 'RestingECG_LVH',
 'RestingECG_Normal',
 'RestingECG_ST',
 'ExerciseAngina_N',
 'ExerciseAngina_Y',
 'ST_Slope_Down',
 'ST_Slope_Flat',
 'ST_Slope_Up']
y_test_int = np.asarray(y_test).astype(int)
yp_int     = np.asarray(yp).astype(int)
print('Baseline accuracy:', accuracy_score(y_test_int, yp_int))
print('Baseline F1:', f1_score(y_test_int, yp_int, zero_division=0))
print("Precision:", precision_score(y_test_int, yp_int, pos_label=1))
print("Recall:", recall_score(y_test_int, yp_int, pos_label=1))
print("F1 Score:", f1_score(y_test_int, yp_int, pos_label=1))
print("ROC-AUC:", roc_auc_score(y_test_int, yp_prob))
Baseline accuracy: 0.5543478260869565
Baseline F1: 0.7132867132867133
Precision: 0.5543478260869565
Recall: 1.0
F1 Score: 0.7132867132867133
ROC-AUC: 0.5
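
The baseline numbers can be reproduced by hand from the metric definitions. This is a small worked sketch that plugs the counts from the baseline confusion matrix printed above into the formulas; only the variable names are introduced here.

```python
# Confusion-matrix counts for the baseline, read from the output above:
# rows are actual classes (0, 1), columns are predicted classes (0, 1).
tn, fp, fn, tp = 0, 82, 0, 102

accuracy = (tp + tn) / (tp + tn + fp + fn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)

print(f"accuracy={accuracy:.3f} precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")
```

Because the baseline predicts every patient as positive, recall is trivially 1.0 and precision equals the positive-class prevalence in the test set.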

Logistic Regression

Logistic Regression is a generalized linear model used for binary classification. It models the probability of the positive class as:
$$ P(y=1|x) = \sigma(w^T x + b), \quad \text{where } \sigma(z) = \frac{1}{1 + e^{-z}} $$
Training is done via maximum likelihood estimation, which is equivalent to minimizing log-loss. Each coefficient is the change in log-odds per unit increase in its feature, and exponentiating it gives an odds ratio. Regularization options include L2 (ridge), L1 (lasso) for sparsity, and Elastic-Net for a mix.
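
To make the link between log-odds, probability, and odds concrete, here is a minimal sketch of the logistic function. The three scores are hypothetical values of the linear term and were chosen only for illustration.

```python
import math

def sigmoid(z):
    """Logistic function mapping a log-odds value to a probability."""
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical linear scores w^T x + b.
for z in (-2.0, 0.0, 2.0):
    p = sigmoid(z)
    odds = p / (1.0 - p)
    print(f"log-odds={z:+.1f}  probability={p:.3f}  odds={odds:.3f}")
```

A score of 0 corresponds to a probability of 0.5 (even odds), and taking the logarithm of the odds recovers the original score.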

Key Assumptions
1. Correct Model Specification: The log-odds of the outcome are a linear combination of predictors.
2. Independence of Observations: Each observation is independent of others.
3. Limited Multicollinearity: Predictors should not be highly correlated.
4. Sufficient Events per Variable: Enough positive cases relative to the number of predictors to avoid overfitting.
5. No Perfect Separation: Predictors should not perfectly predict the outcome (can cause convergence issues).
6. Linearity in the Logit for Numeric Predictors: Continuous variables should have a linear relationship with the log-odds (often checked via transformations or splines).
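
One common way to check the multicollinearity assumption is the variance inflation factor (VIF), which regresses each predictor on the others and reports 1 / (1 - R²). The sketch below is a NumPy-only illustration on synthetic data; the function name `vif` and the generated columns are assumptions and are not part of the notebook.

```python
import numpy as np

def vif(X):
    """Variance inflation factor for each column of a 2-D numeric array."""
    X = np.asarray(X, dtype=float)
    out = []
    for j in range(X.shape[1]):
        target = X[:, j]
        others = np.delete(X, j, axis=1)
        design = np.column_stack([np.ones(len(X)), others])  # add intercept
        beta, *_ = np.linalg.lstsq(design, target, rcond=None)
        resid = target - design @ beta
        r2 = 1.0 - resid.var() / target.var()
        out.append(1.0 / (1.0 - r2))
    return np.array(out)

# Synthetic data: x3 is nearly a copy of x1, x2 is independent.
rng = np.random.default_rng(0)
x1 = rng.normal(size=500)
x2 = rng.normal(size=500)
x3 = x1 + 0.1 * rng.normal(size=500)
vifs = vif(np.column_stack([x1, x2, x3]))
print(np.round(vifs, 1))
```

Values near 1 indicate little collinearity; a frequently quoted rule of thumb treats values above roughly 5 to 10 as a warning sign.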

from sklearn.linear_model import LogisticRegression

logreg = Pipeline(steps=[('prep', preprocess),
                        ('clf',  LogisticRegression(max_iter=1000, solver='lbfgs', C=1.0, random_state=42))])
logreg.fit(X_train, y_train)
lr_metrics = evaluate_classifier("Logistic Regression", logreg)
=== Logistic Regression ===
Accuracy: 0.842
Precision: 0.861
Recall: 0.853
F1: 0.857
ROC-AUC: 0.909
PR-AUC: 0.923
Confusion matrix:
 [[68 14]
 [15 87]]
Classification report:
               precision    recall  f1-score   support

           0       0.82      0.83      0.82        82
           1       0.86      0.85      0.86       102

    accuracy                           0.84       184
   macro avg       0.84      0.84      0.84       184
weighted avg       0.84      0.84      0.84       184

Accuracy is 84%, well above the baseline's 55%, and ROC-AUC rises from 0.500 to 0.909. Precision improves from 0.554 to 0.861. Recall falls from 1.000 to 0.853, but the baseline's perfect recall is trivial because it labels every patient as positive. To explain the logistic regression result, we exponentiate the fitted coefficients to obtain odds ratios.

param_grid_lr = [
    {# L2 with lbfgs/newton-cg/sag/saga
     'clf__solver': ['lbfgs','newton-cg','sag','saga'],
     'clf__penalty': ['l2'],
     'clf__C': [0.01, 0.1, 1.0, 10.0],
     'clf__class_weight': [None, 'balanced']
    },
    {# L1 with liblinear or saga
     'clf__solver': ['liblinear','saga'],
     'clf__penalty': ['l1'],
     'clf__C': [0.01, 0.1, 1.0, 10.0],
     'clf__class_weight': [None, 'balanced']
    },
    {# Elastic-Net with saga
     'clf__solver': ['saga'],
     'clf__penalty': ['elasticnet'],
     'clf__l1_ratio': [0.2, 0.5, 0.8],
     'clf__C': [0.01, 0.1, 1.0, 10.0],
     'clf__class_weight': [None, 'balanced']
    }
]
base_lr = Pipeline(steps=[
    ('prep', preprocess),
    ('clf', LogisticRegression())
])
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=29)
logreg = GridSearchCV(base_lr, param_grid=param_grid_lr, cv=cv, scoring='roc_auc', n_jobs=-1, refit=True, verbose=1)
logreg.fit(X_train, y_train)
print('Best AUC (CV):', logreg.best_score_)
print('Best params:', logreg.best_params_)

lr_metrics = evaluate_classifier("Logistic Regression (with hyper parameter tuning)", logreg.best_estimator_)
Fitting 5 folds for each of 72 candidates, totalling 360 fits
Best AUC (CV): 0.9280405051136758
Best params: {'clf__C': 0.1, 'clf__class_weight': 'balanced', 'clf__l1_ratio': 0.5, 'clf__penalty': 'elasticnet', 'clf__solver': 'saga'}
=== Logistic Regression (with hyper parameter tuning) ===
Accuracy: 0.842
Precision: 0.869
Recall: 0.843
F1: 0.856
ROC-AUC: 0.909
PR-AUC: 0.924
Confusion matrix:
 [[69 13]
 [16 86]]
Classification report:
               precision    recall  f1-score   support

           0       0.81      0.84      0.83        82
           1       0.87      0.84      0.86       102

    accuracy                           0.84       184
   macro avg       0.84      0.84      0.84       184
weighted avg       0.84      0.84      0.84       184
df
Age Sex ChestPainType RestingBP Cholesterol FastingBS RestingECG MaxHR ExerciseAngina Oldpeak ST_Slope HeartDisease
0 40 M ATA 140.0 289.0 0 Normal 172 N 0.0 Up 0
1 49 F NAP 160.0 180.0 0 Normal 156 N 1.0 Flat 1
2 37 M ATA 130.0 283.0 0 ST 98 N 0.0 Up 0
3 48 F ASY 138.0 214.0 0 Normal 108 Y 1.5 Flat 1
4 54 M NAP 150.0 195.0 0 Normal 122 N 0.0 Up 0
... ... ... ... ... ... ... ... ... ... ... ... ...
913 45 M TA 110.0 264.0 0 Normal 132 N 1.2 Flat 1
914 68 M ASY 144.0 193.0 1 Normal 141 N 3.4 Flat 1
915 57 M ASY 130.0 131.0 0 Normal 115 Y 1.2 Flat 1
916 57 F ATA 130.0 236.0 0 LVH 174 N 0.0 Flat 1
917 38 M NAP 138.0 175.0 0 Normal 173 N 0.0 Up 0

918 rows × 12 columns

# Build formula: categorical vars wrapped in C()

formula = (
    'HeartDisease ~ Age + RestingBP + Cholesterol + MaxHR + Oldpeak + FastingBS '
    '+ C(Sex) + C(ChestPainType) + C(RestingECG) + C(ExerciseAngina) + C(ST_Slope)'
)
df['HeartDisease'] = df.HeartDisease.astype(int)
sm_model = smf.logit(formula=formula, data=df.dropna()).fit(disp=False)
sm_model.summary()
Logit Regression Results
Dep. Variable: HeartDisease No. Observations: 746
Model: Logit Df Residuals: 730
Method: MLE Df Model: 15
Date: Tue, 28 Oct 2025 Pseudo R-squ.: 0.5317
Time: 14:12:34 Log-Likelihood: -241.79
converged: True LL-Null: -516.31
Covariance Type: nonrobust LLR p-value: 2.319e-107
coef std err z P>|z| [0.025 0.975]
Intercept -5.4373 1.763 -3.085 0.002 -8.892 -1.983
FastingBS[T.1] 0.2924 0.331 0.883 0.377 -0.357 0.941
C(Sex)[T.M] 1.8655 0.313 5.952 0.000 1.251 2.480
C(ChestPainType)[T.ATA] -1.6732 0.354 -4.721 0.000 -2.368 -0.979
C(ChestPainType)[T.NAP] -1.5730 0.303 -5.192 0.000 -2.167 -0.979
C(ChestPainType)[T.TA] -1.6333 0.484 -3.376 0.001 -2.582 -0.685
C(RestingECG)[T.Normal] -0.2298 0.284 -0.809 0.419 -0.787 0.327
C(RestingECG)[T.ST] -0.1746 0.394 -0.443 0.658 -0.947 0.598
C(ExerciseAngina)[T.Y] 0.9074 0.267 3.397 0.001 0.384 1.431
C(ST_Slope)[T.Flat] 1.3038 0.520 2.509 0.012 0.285 2.323
C(ST_Slope)[T.Up] -1.2100 0.566 -2.140 0.032 -2.318 -0.102
Age 0.0314 0.015 2.119 0.034 0.002 0.060
RestingBP 0.0118 0.007 1.614 0.107 -0.003 0.026
Cholesterol 0.0025 0.002 1.262 0.207 -0.001 0.006
MaxHR 0.0006 0.006 0.100 0.920 -0.011 0.012
Oldpeak 0.4108 0.141 2.921 0.003 0.135 0.687
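
The coefficients in the summary are on the log-odds scale. As a minimal sketch, exponentiating a coefficient and its confidence bounds gives an odds ratio with a 95% interval. The three rows below are copied from the table above; because this statsmodels fit uses raw units, the Oldpeak ratio is per one unit of Oldpeak.

```python
import math

# (coef, ci_low, ci_high) copied from the summary table above.
rows = {
    "C(Sex)[T.M]": (1.8655, 1.251, 2.480),
    "C(ExerciseAngina)[T.Y]": (0.9074, 0.384, 1.431),
    "Oldpeak": (0.4108, 0.135, 0.687),
}

odds_ratios = {
    name: tuple(math.exp(v) for v in vals) for name, vals in rows.items()
}
for name, (or_, lo, hi) in odds_ratios.items():
    print(f"{name}: OR={or_:.2f} (95% CI {lo:.2f} to {hi:.2f})")
```

An interval that excludes 1 corresponds to a coefficient whose confidence interval excludes 0 in the table.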
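# Coefficients -> odds ratios
# Note: numeric columns (and FastingBS) were standardized in the pipeline, so their
# odds ratios are per one-standard-deviation increase, not per raw unit.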
coefs = logreg.best_estimator_.named_steps['clf'].coef_.ravel()
odds  = np.exp(coefs)
coef_df = pd.DataFrame({'feature': final_feature_names, 'coef': coefs, 'odds_ratio': odds})
coef_df = coef_df.sort_values('odds_ratio', ascending=False)

print("Top risk‑increasing features (odds ratio > 1):")
print(coef_df.head(10).round(3))
print("Protective features (lowest odds ratios):")
print(coef_df.tail(10).round(3))

# Plain‑English explanation
print("Plain‑English explanation (Logistic Regression):")
for _, r in coef_df.head(3).iterrows():
    print(f"• Higher '{r['feature']}' increases the odds of heart disease by ~{r['odds_ratio']:.2f}× (holding others constant).")
for _, r in coef_df.tail(3).sort_values('odds_ratio').iterrows():
    print(f"• Higher '{r['feature']}' is associated with ~{r['odds_ratio']:.2f}× odds (lower than baseline), suggesting protection.")
Top risk‑increasing features (odds ratio > 1):
              feature   coef  odds_ratio
8   ChestPainType_ASY  1.120       3.064
18      ST_Slope_Flat  1.027       2.794
5           FastingBS  0.505       1.656
7               Sex_M  0.484       1.622
4             Oldpeak  0.422       1.524
16   ExerciseAngina_Y  0.310       1.364
0                 Age  0.158       1.171
12     RestingECG_LVH  0.000       1.000
17      ST_Slope_Down  0.000       1.000
14      RestingECG_ST  0.000       1.000
Protective features (lowest odds ratios):
              feature   coef  odds_ratio
13  RestingECG_Normal  0.000       1.000
10  ChestPainType_NAP  0.000       1.000
11   ChestPainType_TA  0.000       1.000
1           RestingBP  0.000       1.000
2         Cholesterol  0.000       1.000
9   ChestPainType_ATA -0.162       0.850
3               MaxHR -0.261       0.770
15   ExerciseAngina_N -0.314       0.731
6               Sex_F -0.487       0.615
19        ST_Slope_Up -0.746       0.474
Plain‑English explanation (Logistic Regression):
• Higher 'ChestPainType_ASY' increases the odds of heart disease by ~3.06× (holding others constant).
• Higher 'ST_Slope_Flat' increases the odds of heart disease by ~2.79× (holding others constant).
• Higher 'FastingBS' increases the odds of heart disease by ~1.66× (holding others constant).
• Higher 'ST_Slope_Up' is associated with ~0.47× odds (lower than baseline), suggesting protection.
• Higher 'Sex_F' is associated with ~0.61× odds (lower than baseline), suggesting protection.
• Higher 'ExerciseAngina_N' is associated with ~0.73× odds (lower than baseline), suggesting protection.
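
Because the pipeline standardizes numeric features, the Age odds ratio above is per one standard deviation of age. The sketch below converts it to a per-year and per-decade ratio. It assumes the Age standard deviation from `df.describe()` (9.43, computed on the full data) is close to the one the scaler learned on the training split, so the result is an approximation.

```python
import math

# Standardized coefficient for Age from the table above, and the Age standard
# deviation from df.describe(). The full-data SD is an approximation here:
# the scaler was fit on the training split only.
coef_per_sd = 0.158
age_sd = 9.43

or_per_sd = math.exp(coef_per_sd)
or_per_year = math.exp(coef_per_sd / age_sd)
or_per_decade = math.exp(10 * coef_per_sd / age_sd)

print(f"per SD: {or_per_sd:.3f}, per year: {or_per_year:.3f}, per decade: {or_per_decade:.3f}")
```

Reporting the ratio per decade is usually easier to communicate clinically than a per-standard-deviation figure.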
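# Plot top positive and negative effects
topN = 15
# Select rows by index label (.loc, not .iloc) and de-duplicate:
# with 20 features, the top 15 and bottom 15 overlap.
idx = coef_df['coef'].nlargest(topN).index.union(coef_df['coef'].nsmallest(topN).index)
sdf = coef_df.loc[idx].sort_values('odds_ratio')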

plt.figure(figsize=(8,10))
sns.barplot(y='feature', x='coef', data=sdf, palette=['green' if v<0 else 'red' for v in sdf['coef']])
plt.axvline(0, color='k', lw=1)
plt.title('Top +/- coefficients (log-odds)')
plt.tight_layout()
plt.show()

[Figure: horizontal bar chart of the top +/- coefficients (log-odds)]

Decision Tree (Shallow, Rule-Based)

Decision Trees are intuitive models that split data based on feature thresholds, forming a set of "if-then" rules. In clinical settings, shallow trees (limited depth) are preferred for their simplicity and interpretability.

This section trains a decision tree with max_depth=3 and min_samples_leaf=30 and evaluates its performance. A tuned tree is then visualized, and human-readable rules are extracted to support clinical decision-making.

Assumptions & notes: A decision tree is non‑parametric: it assumes no fixed functional form, and its splits and thresholds are learned entirely from the data. Explainability improves when the tree is shallow (limited max_depth, larger min_samples_leaf). We can also apply cost‑complexity pruning (ccp_alpha) to remove weak branches and keep the ruleset compact.
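
To show how `ccp_alpha` trades tree size for simplicity, here is a minimal sketch using scikit-learn's `cost_complexity_pruning_path` on synthetic data. The `*_demo` arrays come from `make_classification` and are stand-ins, not the heart-disease features.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data; the notebook would use its preprocessed training set.
X_demo, y_demo = make_classification(n_samples=300, n_features=6, random_state=0)

# Candidate alphas at which successive subtrees are pruned away.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_demo, y_demo)
alphas = np.clip(path.ccp_alphas, 0.0, None)  # guard against tiny negative round-off

# Refit at the smallest, a middle, and the largest alpha and count leaves.
picked = [alphas[0], alphas[len(alphas) // 2], alphas[-1]]
leaves = [
    DecisionTreeClassifier(random_state=0, ccp_alpha=a).fit(X_demo, y_demo).get_n_leaves()
    for a in picked
]
for a, n in zip(picked, leaves):
    print(f"ccp_alpha={a:.4f} -> {n} leaves")
```

Larger alphas prune more aggressively, so the leaf count shrinks as alpha grows; a grid over a few small alphas, as used in the tuning step, samples the low end of this path.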

from sklearn.tree import DecisionTreeClassifier, export_text, plot_tree

tree_shallow = Pipeline(steps=[('prep', preprocess),
                              ('clf',  DecisionTreeClassifier(max_depth=3, min_samples_leaf=30, random_state=42))])
tree_shallow.fit(X_train, y_train)
dt_metrics = evaluate_classifier("Decision Tree (depth=3)", tree_shallow)
=== Decision Tree (depth=3) ===
Accuracy: 0.821
Precision: 0.871
Recall: 0.794
F1: 0.831
ROC-AUC: 0.897
PR-AUC: 0.891
Confusion matrix:
 [[70 12]
 [21 81]]
Classification report:
               precision    recall  f1-score   support

           0       0.77      0.85      0.81        82
           1       0.87      0.79      0.83       102

    accuracy                           0.82       184
   macro avg       0.82      0.82      0.82       184
weighted avg       0.83      0.82      0.82       184

Optimising the parameters using grid search and cross validation.

param_grid_dt = [{
    'clf__criterion': ['gini', 'entropy', 'log_loss'],  # splitting criteria
    'clf__splitter': ['best', 'random'],

    # keep trees shallow/compact for rules clinicians can use
    'clf__max_depth': [2, 3, 4],
    'clf__min_samples_split': [5, 10],
    'clf__min_samples_leaf': [5, 10],

    # feature sampling can help generalization; None is simplest
    'clf__max_features': [None, 'sqrt', 'log2'],

    # class balance
    'clf__class_weight': [None, 'balanced'],

    # CART post-pruning – small positive values prune weak branches
    'clf__ccp_alpha': [0.0, 1e-4, 5e-4, 1e-3]
}]

base_dt = Pipeline(steps=[('prep', preprocess),
                          ('clf', DecisionTreeClassifier(random_state=29))])

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=29)
dt_grid = GridSearchCV(base_dt, param_grid=param_grid_dt, cv=cv,
                       scoring='roc_auc', n_jobs=-1, refit=True, verbose=1)

dt_grid.fit(X_train, y_train)
print('Best AUC (CV):', dt_grid.best_score_)
print('Best params:', dt_grid.best_params_)
Fitting 5 folds for each of 1728 candidates, totalling 8640 fits
Best AUC (CV): 0.9081265647932314
Best params: {'clf__ccp_alpha': 0.0, 'clf__class_weight': None, 'clf__criterion': 'gini', 'clf__max_depth': 4, 'clf__max_features': None, 'clf__min_samples_leaf': 5, 'clf__min_samples_split': 5, 'clf__splitter': 'random'}
dt_metrics = evaluate_classifier("Decision Tree (tuned)", dt_grid.best_estimator_)

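print("Human‑readable rules (Decision Tree, hyperparameter tuned):")
# Thresholds are in the preprocessed feature space: numeric columns are standardized, one-hot columns are 0/1
print(export_text(dt_grid.best_estimator_.named_steps['clf'], feature_names=final_feature_names))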

# Optional visualization (will display when run in Jupyter); only the top 3 levels are drawn
plt.figure(figsize=(14,8))
plot_tree(dt_grid.best_estimator_.named_steps['clf'], feature_names=final_feature_names,
          class_names=['No HD','HD'], filled=True, rounded=True, max_depth=3)
plt.title('Decision Tree (tuned, top 3 levels shown)'); plt.show()

print("Plain‑English explanation (Decision Tree):")
print("• The top split isolates the strongest single risk condition; subsequent splits form compact if‑then rules.")
print("• Leaves summarize risk with class proportions—ideal for quick triage guidelines.")
=== Decision Tree (tuned) ===
Accuracy: 0.821
Precision: 0.822
Recall: 0.863
F1: 0.842
ROC-AUC: 0.880
PR-AUC: 0.859
Confusion matrix:
 [[63 19]
 [14 88]]
Classification report:
               precision    recall  f1-score   support

           0       0.82      0.77      0.79        82
           1       0.82      0.86      0.84       102

    accuracy                           0.82       184
   macro avg       0.82      0.82      0.82       184
weighted avg       0.82      0.82      0.82       184

Human‑readable rules (Decision Tree, hyperparameter tuned):
|--- ST_Slope_Up <= 0.91
|   |--- ChestPainType_ASY <= 0.79
|   |   |--- Sex_F <= 0.91
|   |   |   |--- ST_Slope_Flat <= 0.61
|   |   |   |   |--- class: 0
|   |   |   |--- ST_Slope_Flat >  0.61
|   |   |   |   |--- class: 1
|   |   |--- Sex_F >  0.91
|   |   |   |--- Oldpeak <= 0.61
|   |   |   |   |--- class: 0
|   |   |   |--- Oldpeak >  0.61
|   |   |   |   |--- class: 0
|   |--- ChestPainType_ASY >  0.79
|   |   |--- Sex_M <= 0.04
|   |   |   |--- FastingBS <= 0.57
|   |   |   |   |--- class: 1
|   |   |   |--- FastingBS >  0.57
|   |   |   |   |--- class: 1
|   |   |--- Sex_M >  0.04
|   |   |   |--- Cholesterol <= 0.79
|   |   |   |   |--- class: 1
|   |   |   |--- Cholesterol >  0.79
|   |   |   |   |--- class: 1
|--- ST_Slope_Up >  0.91
|   |--- ChestPainType_ASY <= 0.93
|   |   |--- FastingBS <= 0.71
|   |   |   |--- RestingECG_LVH <= 0.97
|   |   |   |   |--- class: 0
|   |   |   |--- RestingECG_LVH >  0.97
|   |   |   |   |--- class: 0
|   |   |--- FastingBS >  0.71
|   |   |   |--- Oldpeak <= -0.18
|   |   |   |   |--- class: 0
|   |   |   |--- Oldpeak >  -0.18
|   |   |   |   |--- class: 1
|   |--- ChestPainType_ASY >  0.93
|   |   |--- ExerciseAngina_N <= 0.11
|   |   |   |--- Oldpeak <= 0.17
|   |   |   |   |--- class: 1
|   |   |   |--- Oldpeak >  0.17
|   |   |   |   |--- class: 1
|   |   |--- ExerciseAngina_N >  0.11
|   |   |   |--- FastingBS <= 0.97
|   |   |   |   |--- class: 0
|   |   |   |--- FastingBS >  0.97
|   |   |   |   |--- class: 1

Figure: Decision Tree (depth=3)

Plain‑English explanation (Decision Tree):
• The top split isolates the strongest single risk condition; subsequent splits form compact if‑then rules.
• Leaves summarize risk with class proportions—ideal for quick triage guidelines.
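
The thresholds in the printed rules appear to be on the preprocessed scale (for example, Oldpeak <= 0.61), which is hard to apply directly in a clinic. The sketch below is a minimal illustration, assuming the numeric columns were standardized with scikit-learn's StandardScaler; it converts such a cut-off back to raw units. The helper to_original_units and the Oldpeak values are hypothetical and only demonstrate the arithmetic.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

def to_original_units(threshold_z, mean, scale):
    """Map a threshold on a standardized feature back to raw units."""
    return threshold_z * scale + mean

# Hypothetical Oldpeak readings, for illustration only (not the notebook's data)
oldpeak = np.array([[0.0], [0.5], [1.0], [1.5], [2.0], [3.0]])
scaler = StandardScaler().fit(oldpeak)

# A rule such as "Oldpeak <= 0.61" is expressed in standard deviations from the mean
raw_cut = to_original_units(0.61, scaler.mean_[0], scaler.scale_[0])
print(f"Oldpeak <= 0.61 (scaled) corresponds to Oldpeak <= {raw_cut:.2f} in raw units")
```

With the notebook's own fitted scaler, the same formula would use that scaler's mean_ and scale_ entries for the relevant column. One-hot columns need no conversion: any threshold strictly between 0 and 1 simply separates "category absent" from "category present".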

CART (Decision Tree with Cost-Complexity Pruning)

CART applies pruning to reduce overfitting and simplify the tree. Cost-complexity pruning removes branches that contribute little to predictive power, resulting in a more generalizable and compact model.

This section fits a full tree, derives pruning paths, and selects an optimal ccp_alpha to build a pruned tree. The final rules are extracted and explained.
Assumptions & notes: Decision trees are non‑parametric. Pruning gives up a little training fit in exchange for lower variance, which usually improves generalization.
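
To make the effect of ccp_alpha concrete, the following sketch refits a tree at each alpha on the pruning path and records how many leaves survive. It is a hedged illustration on synthetic data from make_classification (an assumption, standing in for the engineered training matrix); the names X_demo, y_demo and leaf_counts are not part of the notebook.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the engineered training data (illustrative assumption)
X_demo, y_demo = make_classification(n_samples=300, n_features=8, random_state=0)

full = DecisionTreeClassifier(random_state=0).fit(X_demo, y_demo)
path = full.cost_complexity_pruning_path(X_demo, y_demo)

# Defensive clip: ccp_alpha must be non-negative
alphas = np.clip(path.ccp_alphas, 0.0, None)

leaf_counts = []
for a in alphas:
    pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=a).fit(X_demo, y_demo)
    leaf_counts.append(pruned.get_n_leaves())

# Larger alpha -> fewer leaves -> shorter rule list
for a, n in list(zip(alphas, leaf_counts))[:5]:
    print(f"ccp_alpha={a:.4f}  leaves={n}")
print(f"largest alpha={alphas[-1]:.4f}  leaves={leaf_counts[-1]}")
```

Reading the leaf counts next to a validation score is one way to pick an alpha that keeps the rule list short without giving up much discrimination.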

from sklearn.tree import DecisionTreeClassifier

# Fit full tree on engineered features to obtain pruning path
Xtr_eng = preprocess.fit_transform(X_train)
full_tree = DecisionTreeClassifier(random_state=42)
full_tree.fit(Xtr_eng, y_train)

# Effective alphas at which successive subtrees get pruned away
path = full_tree.cost_complexity_pruning_path(Xtr_eng, y_train)
ccp_alphas = path.ccp_alphas

# Choose a moderate alpha (the median of the unique alphas) as a simple heuristic
alpha = float(np.median(np.unique(ccp_alphas)))
cart = Pipeline(steps=[('prep', preprocess),
                      ('clf',  DecisionTreeClassifier(random_state=42, ccp_alpha=alpha))])
cart.fit(X_train, y_train)
cart_metrics = evaluate_classifier(f"CART (pruned, ccp_alpha={alpha:.5f})", cart)
=== CART (pruned, ccp_alpha=0.00242) ===
Accuracy: 0.837
Precision: 0.853
Recall: 0.853
F1: 0.853
ROC-AUC: 0.820
PR-AUC: 0.817
Confusion matrix:
 [[67 15]
 [15 87]]
Classification report:
               precision    recall  f1-score   support

           0       0.82      0.82      0.82        82
           1       0.85      0.85      0.85       102

    accuracy                           0.84       184
   macro avg       0.84      0.84      0.84       184
weighted avg       0.84      0.84      0.84       184

Optimising the parameters using grid search and cross validation.

# Derive candidate alphas from the pruning path (on engineered features)
preprocess.fit(X_train, y_train)
Xtr_eng = preprocess.transform(X_train)

clf_full = DecisionTreeClassifier(random_state=29)
clf_full.fit(Xtr_eng, y_train)
path = clf_full.cost_complexity_pruning_path(Xtr_eng, y_train)
ccp_alphas = np.unique(path.ccp_alphas)

# Build a simple grid over alphas, keeping depth small
param_grid_cart = [{
    'clf__ccp_alpha': list(ccp_alphas[::max(1, len(ccp_alphas)//10)]),  # sample ~10 alphas
    'clf__max_depth': [2, 3, 4],
    'clf__min_samples_leaf': [5, 10],
    'clf__criterion': ['gini', 'entropy']
}]

base_cart = Pipeline(steps=[('prep', preprocess),
                            ('clf', DecisionTreeClassifier(random_state=29))])

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=29)
cart_grid = GridSearchCV(base_cart, param_grid=param_grid_cart, cv=cv,
                         scoring='roc_auc', n_jobs=-1, refit=True, verbose=1)

cart_grid.fit(X_train, y_train)
print('Best AUC (CV):', cart_grid.best_score_)
print('Best params:', cart_grid.best_params_)
Fitting 5 folds for each of 144 candidates, totalling 720 fits
Best AUC (CV): 0.897052808678825
Best params: {'clf__ccp_alpha': 0.0, 'clf__criterion': 'entropy', 'clf__max_depth': 4, 'clf__min_samples_leaf': 5}
cart_metrics = evaluate_classifier("CART (pruned, tuned)", cart_grid.best_estimator_)
print("Rules (CART, pruned):")
print(export_text(cart_grid.best_estimator_.named_steps['clf'], feature_names=final_feature_names))

print("Plain‑English explanation (CART):")
print("• Pruning removes branches that add little generalizable signal, keeping rules concise.")
print("• The final tree provides a short list of if‑then criteria suitable for checklists.")
=== CART (pruned, tuned) ===
Accuracy: 0.804
Precision: 0.780
Recall: 0.902
F1: 0.836
ROC-AUC: 0.873
PR-AUC: 0.857
Confusion matrix:
 [[56 26]
 [10 92]]
Classification report:
               precision    recall  f1-score   support

           0       0.85      0.68      0.76        82
           1       0.78      0.90      0.84       102

    accuracy                           0.80       184
   macro avg       0.81      0.79      0.80       184
weighted avg       0.81      0.80      0.80       184

Rules (CART, pruned):
|--- ST_Slope_Up <= 0.50
|   |--- ChestPainType_ASY <= 0.50
|   |   |--- MaxHR <= -0.00
|   |   |   |--- Sex_M <= 0.50
|   |   |   |   |--- class: 0
|   |   |   |--- Sex_M >  0.50
|   |   |   |   |--- class: 1
|   |   |--- MaxHR >  -0.00
|   |   |   |--- ST_Slope_Flat <= 0.50
|   |   |   |   |--- class: 0
|   |   |   |--- ST_Slope_Flat >  0.50
|   |   |   |   |--- class: 1
|   |--- ChestPainType_ASY >  0.50
|   |   |--- Sex_M <= 0.50
|   |   |   |--- FastingBS <= 0.63
|   |   |   |   |--- class: 1
|   |   |   |--- FastingBS >  0.63
|   |   |   |   |--- class: 1
|   |   |--- Sex_M >  0.50
|   |   |   |--- Oldpeak <= -0.70
|   |   |   |   |--- class: 1
|   |   |   |--- Oldpeak >  -0.70
|   |   |   |   |--- class: 1
|--- ST_Slope_Up >  0.50
|   |--- ChestPainType_ASY <= 0.50
|   |   |--- Oldpeak <= 1.03
|   |   |   |--- Cholesterol <= -0.10
|   |   |   |   |--- class: 0
|   |   |   |--- Cholesterol >  -0.10
|   |   |   |   |--- class: 0
|   |   |--- Oldpeak >  1.03
|   |   |   |--- class: 1
|   |--- ChestPainType_ASY >  0.50
|   |   |--- Oldpeak <= -0.41
|   |   |   |--- FastingBS <= 0.63
|   |   |   |   |--- class: 0
|   |   |   |--- FastingBS >  0.63
|   |   |   |   |--- class: 1
|   |   |--- Oldpeak >  -0.41
|   |   |   |--- ExerciseAngina_Y <= 0.50
|   |   |   |   |--- class: 1
|   |   |   |--- ExerciseAngina_Y >  0.50
|   |   |   |   |--- class: 1

Plain‑English explanation (CART):
• Pruning removes branches that add little generalizable signal, keeping rules concise.
• The final tree provides a short list of if‑then criteria suitable for checklists.

Naive Bayes (Gaussian)

Naive Bayes is a probabilistic classifier based on Bayes' theorem, assuming feature independence. It models continuous features using Gaussian distributions.

This section trains and tunes a Gaussian Naive Bayes model, evaluates its performance, and analyzes feature-wise class means to understand which features contribute most to prediction.

Assumptions & notes: Features are assumed independent given the class. Stability is often improved by tuning var_smoothing, which adds a small fraction of the largest feature variance to every variance. While it has fewer tunable parameters than most models, it is great for fast, interpretable baselines.

from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import FunctionTransformer

# Make engineered features dense for GaussianNB
to_dense = FunctionTransformer(lambda X: X.toarray() if hasattr(X, 'toarray') else X)
nb = Pipeline(steps=[('prep', preprocess),
                    ('dense', to_dense),
                    ('clf',  GaussianNB())])
nb.fit(X_train, y_train)
nb_metrics = evaluate_classifier("Naive Bayes (Gaussian)", nb)
=== Naive Bayes (Gaussian) ===
Accuracy: 0.837
Precision: 0.867
Recall: 0.833
F1: 0.850
ROC-AUC: 0.909
PR-AUC: 0.923
Confusion matrix:
 [[69 13]
 [17 85]]
Classification report:
               precision    recall  f1-score   support

           0       0.80      0.84      0.82        82
           1       0.87      0.83      0.85       102

    accuracy                           0.84       184
   macro avg       0.83      0.84      0.84       184
weighted avg       0.84      0.84      0.84       184

Optimising the parameters using grid search and cross validation.

param_grid_nb = [{
    # var_smoothing adds a fraction of the largest variance for numerical stability
    'clf__var_smoothing': np.logspace(-12, -7, 10),
    'clf__priors': [None]  # keep data-driven class priors
}]

base_nb = Pipeline(steps=[('prep', preprocess),
                          ('dense', to_dense),
                          ('clf', GaussianNB())])

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=29)
nb_grid = GridSearchCV(base_nb, param_grid=param_grid_nb, cv=cv,
                       scoring='roc_auc', n_jobs=-1, refit=True, verbose=1)

nb_grid.fit(X_train, y_train)
print('Best AUC (CV):', nb_grid.best_score_)
print('Best params:', nb_grid.best_params_)
Fitting 5 folds for each of 10 candidates, totalling 50 fits
Best AUC (CV): 0.9183289499820665
Best params: {'clf__priors': None, 'clf__var_smoothing': 1e-12}
nb_metrics = evaluate_classifier("Naive Bayes (Gaussian, tuned)", nb_grid.best_estimator_)

# Class means per engineered feature
gnb = nb_grid.best_estimator_.named_steps['clf']
Xtr_engineered = nb_grid.best_estimator_.named_steps['dense'].\
    transform(nb_grid.best_estimator_.named_steps['prep'].transform(X_train))
means_0 = gnb.theta_[0]
means_1 = gnb.theta_[1]
nb_df = pd.DataFrame({'feature': final_feature_names,
                      'mean_no_disease': means_0,
                      'mean_disease':   means_1,
                      'diff(disease - no_disease)': means_1 - means_0})
nb_df = nb_df.sort_values('diff(disease - no_disease)', ascending=False)
print("Plain‑English explanation (Naive Bayes):")
print("• A feature value more typical under the disease class shifts the posterior toward heart disease; the opposite shifts toward no disease.")
print("Top features with largest class‑mean differences (NB view):")
nb_df.head(10).round(3)
=== Naive Bayes (Gaussian, tuned) ===
Accuracy: 0.837
Precision: 0.867
Recall: 0.833
F1: 0.850
ROC-AUC: 0.909
PR-AUC: 0.923
Confusion matrix:
 [[69 13]
 [17 85]]
Classification report:
               precision    recall  f1-score   support

           0       0.80      0.84      0.82        82
           1       0.87      0.83      0.85       102

    accuracy                           0.84       184
   macro avg       0.83      0.84      0.84       184
weighted avg       0.84      0.84      0.84       184

Plain‑English explanation (Naive Bayes):
• A feature value more typical under the disease class shifts the posterior toward heart disease; the opposite shifts toward no disease.
Top features with largest class‑mean differences (NB view):
              feature  mean_no_disease  mean_disease  diff(disease - no_disease)
4             Oldpeak           -0.465         0.376                       0.841
0                 Age           -0.334         0.270                       0.604
5           FastingBS           -0.320         0.259                       0.579
18      ST_Slope_Flat            0.186         0.759                       0.573
8   ChestPainType_ASY            0.253         0.773                       0.520
16   ExerciseAngina_Y            0.143         0.618                       0.475
1           RestingBP           -0.132         0.107                       0.239
7               Sex_M            0.655         0.894                       0.239
2         Cholesterol           -0.079         0.064                       0.144
14      RestingECG_ST            0.149         0.227                       0.077

k-Nearest Neighbors (k-NN)

k-NN is a non-parametric, instance-based learning algorithm. It predicts the class of a sample based on the majority vote of its k nearest neighbors.

This section trains and tunes a k-NN model, evaluates its performance, and inspects the neighbors of a sample to explain the prediction in clinical terms.

Assumptions & notes: k-NN is sensitive to feature scaling and the distance metric. Key hyperparameters include n_neighbors, weights (uniform vs distance), and metric (Minkowski with p=1 or p=2).
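
For reference, the Minkowski distance with p=1 is the Manhattan distance and with p=2 the Euclidean distance. The short sketch below computes both with NumPy; the function minkowski is an illustrative helper, not the scikit-learn implementation.

```python
import numpy as np

def minkowski(a, b, p):
    """Minkowski distance between two vectors for a given order p."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(np.sum(np.abs(a - b) ** p) ** (1.0 / p))

u, v = [0.0, 0.0], [3.0, 4.0]
print(minkowski(u, v, p=1))  # Manhattan: |3| + |4|
print(minkowski(u, v, p=2))  # Euclidean: sqrt(3^2 + 4^2)
```

Because both distances add up raw coordinate differences, a feature measured on a larger scale dominates the result, which is why scaling matters for k-NN.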

from sklearn.neighbors import KNeighborsClassifier

knn = Pipeline(steps=[('prep', preprocess),
                     ('clf',  KNeighborsClassifier(n_neighbors=5, weights='distance'))])
knn.fit(X_train, y_train)
knn_metrics = evaluate_classifier("k-NN (k=5, distance weights)", knn)
=== k-NN (k=5, distance weights) ===
Accuracy: 0.859
Precision: 0.873
Recall: 0.873
F1: 0.873
ROC-AUC: 0.890
PR-AUC: 0.886
Confusion matrix:
 [[69 13]
 [13 89]]
Classification report:
               precision    recall  f1-score   support

           0       0.84      0.84      0.84        82
           1       0.87      0.87      0.87       102

    accuracy                           0.86       184
   macro avg       0.86      0.86      0.86       184
weighted avg       0.86      0.86      0.86       184

Optimising the parameters using grid search and cross validation.

param_grid_knn = [{
    'clf__n_neighbors': [3, 5, 7, 9, 11],
    'clf__weights': ['uniform', 'distance'],
    # Minkowski metric with p=1 (Manhattan) or p=2 (Euclidean)
    'clf__metric': ['minkowski'],
    'clf__p': [1, 2]
}]

base_knn = Pipeline(steps=[('prep', preprocess),
                           ('clf', KNeighborsClassifier())])

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=29)
knn_grid = GridSearchCV(base_knn, param_grid=param_grid_knn, cv=cv,
                        scoring='roc_auc', n_jobs=-1, refit=True, verbose=1)

knn_grid.fit(X_train, y_train)
print('Best AUC (CV):', knn_grid.best_score_)
print('Best params:', knn_grid.best_params_)
Fitting 5 folds for each of 20 candidates, totalling 100 fits
Best AUC (CV): 0.9204453821797994
Best params: {'clf__metric': 'minkowski', 'clf__n_neighbors': 11, 'clf__p': 1, 'clf__weights': 'distance'}

df_knn = df.copy()

# Predicted class from the tuned k-NN for every patient in the dataset (training rows included)
df_knn['class'] = knn_grid.predict(df)
plt.figure(figsize=(8, 6))
sns.scatterplot(data=df_knn, x='Age', y='MaxHR', hue='class', s=50, alpha=0.75)
plt.title('Predicted heart disease class (tuned k-NN) by Age and MaxHR')
plt.grid(True)
plt.show()

Figure: scatter plot of Age vs MaxHR, colored by predicted class

knn_metrics = evaluate_classifier("k-NN (tuned)", knn_grid.best_estimator_)

# Inspect neighbors for one patient
idx = 0
x_query = X_test.iloc[[idx]]
y_true = y_test.iloc[idx]
proba = knn_grid.predict_proba(x_query)[:,1][0]
pred  = knn_grid.predict(x_query)[0]

print("Plain‑English explanation (k‑NN):")
print(f"• The tuned model's prediction is driven by the outcomes of the {knn_grid.best_params_['clf__n_neighbors']} most similar training patients; closer neighbors carry higher weight.")

print(f"Query patient index #{x_query.index[0]} | True={y_true}, Pred={pred}, Prob(HD)={proba:.3f}")
print("Patient details:")
print(x_query)
# Show actual neighbors from engineered space.
# Note: this inspection model uses k=5 with the default Euclidean metric, so its
# neighbors are illustrative and may differ from those used by the tuned model.
Xtr_eng2 = preprocess.transform(X_train)
Xte_eng2 = preprocess.transform(x_query)
knn_inspect = KNeighborsClassifier(n_neighbors=5, weights='distance').fit(Xtr_eng2, y_train)
dist, ind = knn_inspect.kneighbors(Xte_eng2, n_neighbors=5, return_distance=True)
neighbors = pd.DataFrame({'neighbor_train_idx': y_train.index.values[ind[0]],
                          'distance': dist[0],
                          'label':   y_train.iloc[ind[0]].values}).sort_values('distance')

print("Five nearest neighbors from the inspection model (closer = more influence):")
# Display distances and labels only. To attach each neighbor's own record, join on the
# training index: neighbors.join(df, on='neighbor_train_idx'). A bare neighbors.join(df)
# aligns on the positional index 0-4 and would show the wrong patients.
neighbors
=== k-NN (tuned) ===
Accuracy: 0.875
Precision: 0.876
Recall: 0.902
F1: 0.889
ROC-AUC: 0.906
PR-AUC: 0.905
Confusion matrix:
 [[69 13]
 [10 92]]
Classification report:
               precision    recall  f1-score   support

           0       0.87      0.84      0.86        82
           1       0.88      0.90      0.89       102

    accuracy                           0.88       184
   macro avg       0.87      0.87      0.87       184
weighted avg       0.87      0.88      0.87       184

Plain‑English explanation (k‑NN):
• The tuned model's prediction is driven by the outcomes of the 11 most similar training patients; closer neighbors carry higher weight.
Query patient index #95 | True=1, Pred=1, Prob(HD)=0.868
Patient details:
    Age  RestingBP  Cholesterol  MaxHR  Oldpeak Sex ChestPainType RestingECG  \
95   58      130.0        263.0    140      2.0   M           ASY     Normal

   ExerciseAngina ST_Slope FastingBS  
95              Y     Flat         0  
Five nearest neighbors from the inspection model (closer = more influence):
   neighbor_train_idx  distance  label
0                 562  0.806738      0
1                  74  0.926989      1
2                 586  1.016484      1
3                 501  1.032806      1
4                 368  1.053968      1
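
To show how distance weighting turns a handful of neighbors into a probability, the sketch below applies inverse-distance weights (the scheme used by weights='distance' in scikit-learn) to the five distances and labels listed in the neighbor table. The helper weighted_vote_proba is hypothetical and assumes no neighbor sits at distance zero. The result is not expected to reproduce Prob(HD)=0.868, because the tuned model votes over 11 neighbors under the Manhattan metric, per the best params reported earlier.

```python
import numpy as np

def weighted_vote_proba(distances, labels, positive=1):
    """Share of inverse-distance weight held by neighbors of the positive class."""
    d = np.asarray(distances, dtype=float)
    w = 1.0 / d  # closer neighbors receive larger weights
    is_pos = np.asarray(labels) == positive
    return float(w[is_pos].sum() / w.sum())

# Distances and labels copied from the neighbor table
dist_demo = [0.806738, 0.926989, 1.016484, 1.032806, 1.053968]
label_demo = [0, 1, 1, 1, 1]

p_weighted = weighted_vote_proba(dist_demo, label_demo)
p_uniform = sum(label_demo) / len(label_demo)  # plain majority vote share
print(f"distance-weighted: {p_weighted:.3f} | uniform: {p_uniform:.3f}")
```

The closest neighbor has no heart disease, so the weighted estimate falls below the uniform share of 4 out of 5.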
summary = pd.DataFrame([lr_metrics, dt_metrics, cart_metrics, nb_metrics, knn_metrics]).set_index('model').sort_values('roc_auc', ascending=False)
summary.sort_values('accuracy', ascending=False).round(3)
                                                   accuracy  precision  recall     f1  roc_auc  pr_auc
model
k-NN (tuned)                                          0.875      0.876   0.902  0.889    0.906   0.905
Logistic Regression (with hyper parameter tuning)     0.842      0.869   0.843  0.856    0.909   0.924
Naive Bayes (Gaussian, tuned)                         0.837      0.867   0.833  0.850    0.909   0.923
CART (pruned, ccp_alpha=0.00242)                      0.837      0.853   0.853  0.853    0.820   0.817
Decision Tree (tuned)                                 0.821      0.822   0.863  0.842    0.880   0.859

Diagnostics of the Best Model

To ensure robustness and reliability, we perform diagnostics on the model with the highest test accuracy (the tuned k-NN). This includes checking multicollinearity using VIF and plotting learning curves to assess the bias-variance tradeoff.

These diagnostics help validate model assumptions and guide further improvements.

best_model = knn_grid.best_estimator_
from statsmodels.stats.outliers_influence import variance_inflation_factor

X_num = df[num_cols].dropna()
X_num_const = sm.add_constant(X_num)
vif = pd.DataFrame({
    'feature': ['const'] + list(X_num.columns),
    'VIF': [variance_inflation_factor(X_num_const.values, i) for i in range(X_num_const.shape[1])]
})
vif
       feature         VIF
0        const  160.072777
1          Age    1.277888
2    RestingBP    1.098974
3  Cholesterol    1.011748
4        MaxHR    1.205888
5      Oldpeak    1.142342
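
The VIF values for the numeric features are all close to 1, which suggests little multicollinearity; the const row belongs to the intercept and is usually not interpreted. As a hedged sketch of what the statistic measures, the VIF of a feature is 1 / (1 - R^2), where R^2 comes from regressing that feature on the remaining ones. The function vif_by_hand and the synthetic columns a, b, c are assumptions for illustration.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

def vif_by_hand(X, j):
    """VIF of column j: 1 / (1 - R^2) from regressing it on the other columns."""
    others = np.delete(X, j, axis=1)
    r2 = LinearRegression().fit(others, X[:, j]).score(others, X[:, j])
    return 1.0 / (1.0 - r2)

rng = np.random.default_rng(0)
a = rng.normal(size=500)
b = rng.normal(size=500)              # independent of a
c = a + 0.1 * rng.normal(size=500)    # nearly a copy of a
X_demo = np.column_stack([a, b, c])

for j, name in enumerate(["a", "b", "c"]):
    print(name, round(vif_by_hand(X_demo, j), 2))
```

In this toy example the near-duplicate pair a and c gets a large VIF while the independent column b stays near 1, the same pattern as the notebook's numeric features.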

ROC-AUC Curve

The ROC (Receiver Operating Characteristic) curve is a graphical representation of a classifier's performance across different threshold values. It plots the True Positive Rate (Sensitivity) against the False Positive Rate (1 - Specificity).

AUC (Area Under the Curve) quantifies the overall ability of the model to discriminate between positive and negative classes. A higher AUC indicates better performance.
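
An equivalent reading of AUC is the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative case. The sketch below checks this on a small made-up example by counting correctly ordered pairs and comparing with roc_auc_score; auc_by_pairs is an illustrative helper.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def auc_by_pairs(y_true, scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count half."""
    y_true, scores = np.asarray(y_true), np.asarray(scores, dtype=float)
    pos, neg = scores[y_true == 1], scores[y_true == 0]
    diff = pos[:, None] - neg[None, :]  # every positive against every negative
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)

# Made-up labels and scores, for illustration only
y_demo = [0, 0, 1, 1, 0, 1]
s_demo = [0.1, 0.4, 0.35, 0.8, 0.5, 0.7]

print(auc_by_pairs(y_demo, s_demo), roc_auc_score(y_demo, s_demo))
```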

from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

# Predict probabilities
y_probs = best_model.predict_proba(X_test)[:, 1]

# Compute ROC curve and AUC
fpr, tpr, thresholds = roc_curve(y_test.astype(int), y_probs)
roc_auc = auc(fpr, tpr)

# Plot ROC curve
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc='lower right')
plt.show();

Figure: Receiver Operating Characteristic (True Positive Rate vs False Positive Rate)

An AUC of 0.91 on the test set indicates good discrimination between patients with and without heart disease.

Sensitivity-Specificity Curve

This curve shows how sensitivity and specificity vary with different classification thresholds. It helps in selecting a threshold that balances both metrics.

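# Compute sensitivity and specificity from the ROC arrays
sensitivity = tpr
specificity = 1 - fpr
# Plot Sensitivity-Specificity curve; skip the first threshold, which roc_curve
# adds as a sentinel above every score (it is not a usable cut-off)
plt.figure()
plt.plot(thresholds[1:], sensitivity[1:], label='Sensitivity')
plt.plot(thresholds[1:], specificity[1:], label='Specificity')
plt.xlabel('Threshold')
plt.ylabel('Rate')
plt.title('Sensitivity vs Specificity')
plt.legend()
plt.show()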

Figure: Sensitivity vs Specificity (both plotted against Threshold)
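
One common way to pick a single threshold from such a curve is Youden's J statistic (sensitivity + specificity - 1), which selects the point where the two are jointly highest. This is a sketch under that assumption, not the notebook's chosen method; in a clinical setting a cost-aware threshold that penalises missed cases more heavily may be preferable. The function youden_threshold and the toy scores are illustrative.

```python
import numpy as np
from sklearn.metrics import roc_curve

def youden_threshold(y_true, y_score):
    """Return (threshold, sensitivity, specificity) maximising Youden's J."""
    fpr, tpr, thr = roc_curve(y_true, y_score)
    j = tpr - fpr  # equals sensitivity + specificity - 1
    best = int(np.argmax(j))
    return thr[best], tpr[best], 1 - fpr[best]

# Made-up labels and predicted probabilities, for illustration only
y_demo = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1])
s_demo = np.array([0.1, 0.2, 0.3, 0.6, 0.4, 0.65, 0.7, 0.8, 0.9])

t, sens, spec = youden_threshold(y_demo, s_demo)
print(f"threshold={t:.2f}, sensitivity={sens:.2f}, specificity={spec:.2f}")
```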

Learning curves

Learning curves are a valuable diagnostic tool to understand how a model's performance evolves with increasing training data. They help assess whether a model suffers from high bias (underfitting) or high variance (overfitting), and whether adding more data could improve performance.

from sklearn.model_selection import learning_curve

# Learning curve for best estimator
train_sizes, train_scores, val_scores = learning_curve(
    best_model, X_train, y_train, cv=5, scoring='roc_auc', n_jobs=-1,
    train_sizes=np.linspace(0.1, 1.0, 5), random_state=29)
plt.figure(figsize=(7,5))
plt.plot(train_sizes, train_scores.mean(axis=1), 'o-', label='Train AUC')
plt.plot(train_sizes, val_scores.mean(axis=1), 'o-', label='CV AUC')
plt.xlabel('Training examples')
plt.ylim(0, 1.01)
plt.ylabel('ROC-AUC')
plt.title('Learning Curve')
plt.legend()
plt.show()

Figure: Learning Curve (Train AUC and CV AUC vs Training examples)

Ethics and robustness

  • Medical use requires validation: consult clinicians; understand labeling & sampling biases.
  • Test fairness across groups (e.g., Sex) and ensure no proxy discrimination.
  • Perform temporal/center splits if applicable; assess drift.
  • Consider decision support (calibrated probabilities + cost‑aware thresholds) rather than hard automation.
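
The fairness check across groups listed above can start with something as simple as comparing recall (sensitivity) per group. The sketch below is a minimal illustration with made-up labels and predictions; recall_by_group is a hypothetical helper, and in the notebook the inputs would be y_test, the model's test predictions and the Sex column of X_test.

```python
import pandas as pd
from sklearn.metrics import recall_score

def recall_by_group(y_true, y_pred, group):
    """Recall computed separately within each group."""
    frame = pd.DataFrame({"y_true": y_true, "y_pred": y_pred, "group": group})
    return {g: recall_score(part["y_true"], part["y_pred"])
            for g, part in frame.groupby("group")}

# Made-up labels, predictions and group membership, for illustration only
y_true_demo = [1, 1, 1, 0, 1, 1, 0, 0]
y_pred_demo = [1, 1, 0, 0, 1, 0, 0, 1]
sex_demo = ["M", "M", "M", "M", "F", "F", "F", "F"]

group_recall = recall_by_group(y_true_demo, y_pred_demo, sex_demo)
print(group_recall)
```

A large gap between groups would be a prompt to look at sample sizes, label quality and threshold choice before the model is used for screening.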