Summary
Master decision tree classification: understand Gini impurity and entropy, learn the CART algorithm step-by-step, and apply pruning to build interpretable yet powerful models.
Learning Objectives
- Understand how decision trees make splits
- Compute information gain and Gini impurity
- Implement tree building with CART
- Apply pruning to prevent overfitting
Theory
What is a Decision Tree?
A decision tree is a flowchart-like structure that makes predictions by asking a series of questions about feature values.
Why are decision trees popular?
| Advantage | Description |
|---|---|
| Interpretable | Every prediction can be explained (“Loan rejected because Age < 25 AND Income < 30k”) |
| No preprocessing | Handles numeric and categorical features, missing values |
| Non-parametric | No assumptions about data distribution |
| Visual | Can be drawn and understood by non-experts |
Tree Structure
graph TB
A["🔷 Root: Age > 30?"] -->|No| B["🔶 Internal: Income > 50k?"]
A -->|Yes| C["🟢 Leaf: Approve"]
B -->|No| D["🔴 Leaf: Reject"]
B -->|Yes| E["🟢 Leaf: Approve"]
style C fill:#c8e6c9
style E fill:#c8e6c9
style D fill:#ffcdd2
| Node Type | Role | Example |
|---|---|---|
| Root | First split, most informative feature | “Age > 30?” |
| Internal | Intermediate decision | “Income > 50k?” |
| Leaf | Final prediction | “Approve” or “Reject” |
Geometric View: Each split creates an axis-aligned boundary. The tree partitions the feature space into rectangles, each assigned to a class.
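To see those rectangles concretely, here is a minimal sketch (assuming scikit-learn ≥ 1.1 and matplotlib are installed) that fits a shallow tree on just the two petal features of Iris and draws its axis-aligned decision regions:

```python
# Minimal sketch of the geometric view: a shallow tree on two features
# partitions the plane into axis-aligned rectangles.
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X2 = iris.data[:, 2:4]  # petal length, petal width (2D so we can plot)
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X2, iris.target)

# Each split adds one horizontal or vertical boundary line
disp = DecisionBoundaryDisplay.from_estimator(
    clf, X2, response_method="predict", alpha=0.3,
    xlabel=iris.feature_names[2], ylabel=iris.feature_names[3],
)
disp.ax_.scatter(X2[:, 0], X2[:, 1], c=iris.target, edgecolor="k")
plt.show()
```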
Splitting Criteria: How to Choose the Best Split?
The goal is to find splits that create purer child nodes (more samples of one class).
Gini Impurity
Measures probability of misclassifying a random sample:
$$\boxed{Gini(D) = 1 - \sum_{k=1}^{K} p_k^2}$$
where $p_k$ is the proportion of class $k$ in dataset $D$.
Intuition: If all samples belong to one class → $Gini = 0$ (pure). If classes are equally mixed → $Gini$ is maximized.
| Dataset Example | $p_1$ | $p_2$ | Gini |
|---|---|---|---|
| [🔵🔵🔵🔵🔵] | 1.0 | 0.0 | $1 - 1^2 = 0$ (pure) |
| [🔵🔵🔵🔴🔴] | 0.6 | 0.4 | $1 - (0.6^2 + 0.4^2) = 0.48$ |
| [🔵🔵🔴🔴🔴] | 0.4 | 0.6 | $1 - (0.4^2 + 0.6^2) = 0.48$ |
| [🔵🔴🔵🔴] | 0.5 | 0.5 | $1 - (0.5^2 + 0.5^2) = 0.5$ (max impurity) |
Entropy and Information Gain
Entropy measures uncertainty (from information theory):
$$\boxed{Entropy(D) = -\sum_{k=1}^{K} p_k \log_2 p_k}$$
Information Gain = reduction in entropy after a split:
$$IG(D, A) = Entropy(D) - \sum_{v \in values(A)} \frac{|D_v|}{|D|} Entropy(D_v)$$
| Metric | Intuition | Range (binary) |
|---|---|---|
| Gini | Probability of misclassification | [0, 0.5] |
| Entropy | Bits needed to encode class | [0, 1] |
| Info Gain | Reduction in uncertainty | Higher = better |
Gini vs Entropy: In practice, they produce nearly identical trees. Gini is slightly faster (no logarithm). sklearn uses Gini by default.
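To make the information-gain formula concrete, here is a small hand-rolled sketch; the helpers `entropy` and `info_gain` are our own illustrative names, not a library API:

```python
import numpy as np

def entropy(y):
    """Entropy of a label array, in bits."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))

def info_gain(y_parent, y_left, y_right):
    """Reduction in entropy achieved by splitting parent into (left, right)."""
    n = len(y_parent)
    weighted = (len(y_left) / n) * entropy(y_left) + (len(y_right) / n) * entropy(y_right)
    return entropy(y_parent) - weighted

# Parent: 3 positives, 2 negatives; this split separates the classes perfectly,
# so the gain equals the full parent entropy (~0.971 bits).
parent = np.array([1, 1, 1, 0, 0])
print(info_gain(parent, np.array([0, 0]), np.array([1, 1, 1])))
```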
The CART Algorithm
CART (Classification and Regression Trees) is the standard algorithm used by sklearn.
Algorithm Steps
function BuildTree(D, depth):
    if stopping_condition(D, depth):
        return LeafNode(majority_class(D))

    best_feature, best_threshold = None, None
    best_impurity_reduction = 0

    for each feature f:
        for each threshold t in unique_values(f):
            D_left  = {x ∈ D : x[f] ≤ t}
            D_right = {x ∈ D : x[f] > t}
            reduction = Impurity(D) - weighted_avg(Impurity(D_left), Impurity(D_right))
            if reduction > best_impurity_reduction:
                best_feature, best_threshold = f, t
                best_impurity_reduction = reduction

    if best_feature is None:        # no split reduces impurity → stop here
        return LeafNode(majority_class(D))

    D_left  = {x ∈ D : x[best_feature] ≤ best_threshold}
    D_right = {x ∈ D : x[best_feature] > best_threshold}
    left_child  = BuildTree(D_left, depth + 1)
    right_child = BuildTree(D_right, depth + 1)
    return DecisionNode(best_feature, best_threshold, left_child, right_child)
Key Characteristics
| Property | Description |
|---|---|
| Binary splits | Each node has exactly 2 children |
| Greedy | Chooses locally optimal split, not globally |
| Top-down | Starts at root, recursively splits |
Greedy ≠ Optimal: CART doesn’t guarantee the best possible tree. Finding the optimal tree is NP-complete!
Worked Example: Splitting on a Simple Dataset
| ID | Age | Income | Approved |
|---|---|---|---|
| 1 | 25 | 40k | ❌ |
| 2 | 35 | 60k | ✅ |
| 3 | 45 | 80k | ✅ |
| 4 | 20 | 30k | ❌ |
| 5 | 30 | 50k | ✅ |
Step 1: Calculate parent Gini
- 3 approved, 2 rejected → $Gini = 1 - (0.6^2 + 0.4^2) = 0.48$
Step 2: Try split on Age ≤ 25
- Left (Age ≤ 25): [❌, ❌] → $Gini = 0$
- Right (Age > 25): [✅, ✅, ✅] → $Gini = 0$
- Weighted: $\frac{2}{5} \times 0 + \frac{3}{5} \times 0 = 0$
- Reduction: 0.48 - 0 = 0.48 ✅ Perfect split!
Step 3: Try split on Age ≤ 30
- Left: [❌, ❌, ✅] → $Gini = 0.444$
- Right: [✅, ✅] → $Gini = 0$
- Weighted: $\frac{3}{5} \times 0.444 + \frac{2}{5} \times 0 = 0.267$
- Reduction: 0.48 - 0.267 = 0.213
Best split: Age ≤ 25 (highest impurity reduction)
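The arithmetic above can be checked with a few lines of Python; the `gini` helper is illustrative, not part of any library:

```python
import numpy as np

def gini(y):
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - float(np.sum(p ** 2))

age      = np.array([25, 35, 45, 20, 30])
approved = np.array([0, 1, 1, 0, 1])  # 0 = rejected, 1 = approved

for threshold in (25, 30):
    left, right = approved[age <= threshold], approved[age > threshold]
    weighted = len(left) / len(approved) * gini(left) + len(right) / len(approved) * gini(right)
    print(f"Age <= {threshold}: impurity reduction = {gini(approved) - weighted:.3f}")
# Age <= 25: impurity reduction = 0.480  (perfect split)
# Age <= 30: impurity reduction = 0.213
```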
Overfitting and Pruning
Decision trees are prone to overfitting because they can grow until each leaf is pure, memorizing the training data.
graph LR
A["Underfitting<br/>depth=1"] --> B["Good Fit<br/>depth=3-5"] --> C["Overfitting<br/>depth=∞"]
style A fill:#ffcdd2
style B fill:#c8e6c9
style C fill:#ffcdd2
Pre-pruning (Early Stopping)
Stop growing before overfitting:
| Parameter | Typical values | Effect |
|---|---|---|
| max_depth | 3-10 | Limits tree height |
| min_samples_split | 10-50 | Requires a minimum number of samples to split a node |
| min_samples_leaf | 5-20 | Every leaf must contain a minimum number of samples |
| max_features | 'sqrt', 'log2' | Limits the number of features considered per split |
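As a brief sketch, these constraints are plain constructor arguments on DecisionTreeClassifier; the specific values below are illustrative starting points, not tuned settings:

```python
from sklearn.tree import DecisionTreeClassifier

# The tree stops splitting whenever any of these constraints would be violated.
clf = DecisionTreeClassifier(
    max_depth=5,            # cap the height of the tree
    min_samples_split=20,   # a node needs >= 20 samples to be split further
    min_samples_leaf=5,     # every leaf must keep >= 5 samples
    max_features="sqrt",    # consider only sqrt(n_features) candidates per split
    random_state=42,
)
```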
Post-pruning (Cost-Complexity)
Grow full tree, then prune based on complexity penalty:
$$R_\alpha(T) = R(T) + \alpha \cdot |T|$$
where:
- $R(T)$ = misclassification rate
- $|T|$ = number of leaves
- $\alpha$ = complexity parameter (ccp_alpha in sklearn)
Use clf.cost_complexity_pruning_path(X, y) to find optimal α via cross-validation.
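A sketch of that workflow on Iris, scoring each candidate α with 5-fold cross-validation (the choice of dataset and CV setup is illustrative):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# 1. Compute the pruning path: the candidate alpha values for this dataset
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X, y)

# 2. Cross-validate a pruned tree for each alpha and keep the best one
best_alpha, best_score = 0.0, 0.0
for alpha in path.ccp_alphas:
    clf = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    score = cross_val_score(clf, X, y, cv=5).mean()
    if score > best_score:
        best_alpha, best_score = alpha, score

print(f"best ccp_alpha = {best_alpha:.4f}, CV accuracy = {best_score:.3f}")
```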
Code Practice
The following examples demonstrate building decision trees and exploring their behavior.
Building and Visualizing a Decision Tree
🐍 Python
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

# Load Iris dataset
X, y = load_iris(return_X_y=True)
feature_names = load_iris().feature_names
class_names = load_iris().target_names

# Build tree with max_depth=3
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(X, y)

# Print text representation
print("🌳 Decision Tree Structure:")
print(export_text(clf, feature_names=list(feature_names)))

# Tree statistics
print(f"\n📊 Tree Statistics:")
print(f"   Depth: {clf.get_depth()}")
print(f"   Leaves: {clf.get_n_leaves()}")
print(f"   Features used: {sum(clf.feature_importances_ > 0)}/{len(feature_names)}")

# Draw the tree graphically
plt.figure(figsize=(12, 6))
plot_tree(clf, feature_names=feature_names, class_names=class_names, filled=True)
plt.show()
Output:
🌳 Decision Tree Structure:
|--- petal length (cm) <= 2.45
| |--- class: 0
|--- petal length (cm) > 2.45
| |--- petal width (cm) <= 1.75
| | |--- petal length (cm) <= 4.95
| | | |--- class: 1
| | |--- petal length (cm) > 4.95
| | | |--- class: 2
| |--- petal width (cm) > 1.75
| | |--- petal length (cm) <= 4.85
| | | |--- class: 2
| | |--- petal length (cm) > 4.85
| | | |--- class: 2
📊 Tree Statistics:
Depth: 3
Leaves: 5
Features used: 2/4
The tree only uses 2 features (petal width and petal length) out of 4! This is automatic feature selection.
Computing Gini Impurity from Scratch
🐍 Python
import numpy as np

def gini(y):
    """Compute Gini impurity."""
    _, counts = np.unique(y, return_counts=True)
    probs = counts / len(y)
    return 1 - np.sum(probs ** 2)

def entropy(y):
    """Compute entropy (in bits)."""
    _, counts = np.unique(y, return_counts=True)
    probs = counts / len(y)
    return -np.sum(probs * np.log2(probs + 1e-10))  # small epsilon avoids log2(0)

# Compare different distributions
examples = [
    ("Pure [A,A,A,A,A]", [0, 0, 0, 0, 0]),
    ("Mixed [A,A,A,B,B]", [0, 0, 0, 1, 1]),
    ("Even [A,A,B,B,B]", [0, 0, 1, 1, 1]),
    ("50/50 [A,B,A,B]", [0, 1, 0, 1]),
]

print("📊 Gini vs Entropy Comparison")
print("=" * 50)
print(f"{'Distribution':<22} {'Gini':>8} {'Entropy':>10}")
print("-" * 50)
for name, y in examples:
    print(f"{name:<22} {gini(y):>8.3f} {entropy(y):>10.3f}")
Output:
📊 Gini vs Entropy Comparison
==================================================
Distribution Gini Entropy
--------------------------------------------------
Pure [A,A,A,A,A] 0.000 -0.000
Mixed [A,A,A,B,B] 0.480 0.971
Even [A,A,B,B,B] 0.480 0.971
50/50 [A,B,A,B] 0.500 1.000
Both Gini and Entropy are minimized at 0 for pure nodes and maximized for 50/50 splits. They typically produce identical trees!
Effect of Max Depth (Bias-Variance Tradeoff)
🐍 Python
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load Iris dataset and hold out 30% for testing
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

print("📊 Effect of Max Depth on Accuracy")
print("=" * 50)
print(f"{'Depth':<10} {'Train Acc':>12} {'Test Acc':>12} {'# Leaves':>10}")
print("-" * 50)

for depth in [1, 2, 3, 5, 10, None]:
    clf = DecisionTreeClassifier(max_depth=depth, random_state=42)
    clf.fit(X_train, y_train)
    train_acc = clf.score(X_train, y_train)
    test_acc = clf.score(X_test, y_test)
    n_leaves = clf.get_n_leaves()
    depth_str = "∞" if depth is None else str(depth)
    print(f"{depth_str:<10} {train_acc:>12.2%} {test_acc:>12.2%} {n_leaves:>10}")
Output:
📊 Effect of Max Depth on Accuracy
==================================================
Depth Train Acc Test Acc # Leaves
--------------------------------------------------
1 64.76% 71.11% 2
2 94.29% 97.78% 3
3 95.24% 100.00% 5
5 99.05% 100.00% 9
10 100.00% 100.00% 10
∞ 100.00% 100.00% 10
Sweet spot at depth=3: 100% test accuracy with only 5 leaves. Deeper trees achieve the same accuracy but with more leaves — no benefit from added complexity!
Feature Importance
🐍 Python
clf = DecisionTreeClassifier(max_depth=4, random_state=42)
clf.fit(X, y)

print("📊 Feature Importance")
print("=" * 40)
for name, importance in sorted(zip(feature_names, clf.feature_importances_),
                               key=lambda x: -x[1]):
    bar = "█" * int(importance * 30)
    print(f"{name:<18} {importance:.3f} {bar}")
Output:
📊 Feature Importance
========================================
petal length (cm) 0.558 ████████████████
petal width (cm) 0.428 ████████████
sepal width (cm) 0.014
sepal length (cm) 0.000
Feature importance is based on total impurity reduction. Petal length is the most important feature for Iris classification!
Regression Tree
🐍 Python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Generate a noisy sine wave
np.random.seed(42)
X_reg = np.sort(5 * np.random.rand(80, 1), axis=0)
y_reg = np.sin(X_reg).ravel() + np.random.randn(80) * 0.2

# Compare trees of different depths on the training data
print("📊 Regression Tree - Mean Squared Error")
print("=" * 40)
for depth in [2, 4, None]:
    reg = DecisionTreeRegressor(max_depth=depth, random_state=42)
    reg.fit(X_reg, y_reg)
    mse = np.mean((reg.predict(X_reg) - y_reg) ** 2)
    depth_str = "∞" if depth is None else str(depth)
    print(f"Depth={depth_str:<4} MSE={mse:.4f} Leaves={reg.get_n_leaves()}")
Output:
📊 Regression Tree - Mean Squared Error
========================================
Depth=2 MSE=0.0747 Leaves=4
Depth=4 MSE=0.0312 Leaves=16
Depth=∞ MSE=0.0000 Leaves=80
With depth=∞, MSE=0 means the tree memorized every training point. This is severe overfitting — use pruning!
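One way to rein that in is cost-complexity pruning; a minimal sketch on the same regression setup (the ccp_alpha values are illustrative, not tuned):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Same noisy sine wave as above
np.random.seed(42)
X_reg = np.sort(5 * np.random.rand(80, 1), axis=0)
y_reg = np.sin(X_reg).ravel() + np.random.randn(80) * 0.2

# An unpruned tree memorizes the noise; a small ccp_alpha collapses leaves
# whose impurity reduction does not justify the added complexity.
for alpha in [0.0, 0.01, 0.05]:
    reg = DecisionTreeRegressor(ccp_alpha=alpha, random_state=42)
    reg.fit(X_reg, y_reg)
    mse = np.mean((reg.predict(X_reg) - y_reg) ** 2)
    print(f"ccp_alpha={alpha:<5} MSE={mse:.4f} Leaves={reg.get_n_leaves()}")
```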
Deep Dive
Frequently Asked Questions
Q1: Gini vs Entropy — which is better?
| Aspect | Gini | Entropy |
|---|---|---|
| Speed | Faster (no log) | Slower |
| Interpretation | Probability of misclassification | Information content |
| Result | Nearly identical trees | Nearly identical trees |
Verdict: Use Gini (sklearn default). Only use Entropy if information-theoretic justification is needed.
Q2: When should Decision Trees be used?
| Use Tree When | Avoid Tree When |
|---|---|
| Interpretability is key | High accuracy is critical |
| Data has simple patterns | Data has complex boundaries |
| Few training samples | Large dataset (use ensemble) |
| Quick baseline needed | Robustness required |
| Mixed feature types | Features are highly correlated |
Q3: Decision Tree vs Random Forest vs Gradient Boosting
| Aspect | Decision Tree | Random Forest | Gradient Boosting |
|---|---|---|---|
| Interpretability | ✅ High | ⚠️ Low | ❌ Very low |
| Training speed | ⚡ Fast | 🔶 Medium | 🐢 Slow |
| Overfitting risk | ⚠️ High | ✅ Low | ⚠️ Medium |
| Accuracy | 🔶 Medium | ✅ High | ✅✅ Very high |
| Hyperparameter tuning | Easy | Medium | Hard |
Use single trees for explainability, Random Forest for robust baseline, XGBoost/LightGBM for competitions.
Q4: How are categorical features handled?
sklearn’s DecisionTreeClassifier only handles numeric features. Options:
| Approach | Pros | Cons |
|---|---|---|
| One-Hot Encoding | Works with sklearn | High cardinality = many features |
| Label Encoding | Compact | Implies ordering (not ideal) |
| Category Encoders | Target encoding | Risk of leakage |
| Native categorical handling | LightGBM handles categories natively | Requires a different library |
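A brief sketch of the one-hot route with a ColumnTransformer; the column names and toy data are made up for illustration:

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeClassifier

# Toy data: one numeric and one categorical column (illustrative only)
df = pd.DataFrame({
    "age": [25, 35, 45, 20, 30, 40],
    "city": ["paris", "lyon", "paris", "nice", "lyon", "nice"],
})
y = [0, 1, 1, 0, 1, 1]

# One-hot encode the categorical column, pass the numeric one through unchanged
pre = ColumnTransformer(
    [("cat", OneHotEncoder(handle_unknown="ignore"), ["city"])],
    remainder="passthrough",
)
model = make_pipeline(pre, DecisionTreeClassifier(max_depth=3, random_state=42))
model.fit(df, y)
print(model.predict(df.iloc[:2]))
```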
Practical Tips
Hyperparameter tuning checklist (see the grid-search sketch below):
- Start with max_depth=5
- Tune min_samples_split (10-50 for large datasets)
- Try min_samples_leaf=5 to smooth predictions
- Use ccp_alpha for post-pruning
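One possible way to run that checklist as a single grid search (the grid values below are illustrative starting points, not prescriptions):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Search over the pre- and post-pruning knobs from the checklist
param_grid = {
    "max_depth": [3, 5, 7, None],
    "min_samples_split": [2, 10, 50],
    "min_samples_leaf": [1, 5, 20],
    "ccp_alpha": [0.0, 0.005, 0.02],
}
search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5)
search.fit(X, y)
print(search.best_params_, f"CV accuracy = {search.best_score_:.3f}")
```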
Preventing Overfitting
graph TD
A["Train Accuracy High<br/>Test Accuracy Low?"] -->|Yes| B{"Overfitting!"}
B --> C["Reduce max_depth"]
B --> D["Increase min_samples_split"]
B --> E["Increase min_samples_leaf"]
B --> F["Use ccp_alpha pruning"]
Common Pitfalls
| Pitfall | Symptom | Solution |
|---|---|---|
| Overfitting | 100% train, low test accuracy | Add pruning constraints |
| Feature leakage | Unrealistic high accuracy | Remove leaky features |
| Imbalanced data | Tree ignores minority class | Use class_weight='balanced' |
| High cardinality | Tree splits on ID-like features | Remove or encode properly |
| Missing values | sklearn doesn’t handle them | Impute first |
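For the imbalanced-data row, class_weight='balanced' reweights samples inversely to class frequency inside the impurity computation. A minimal sketch on a synthetic 95/5 dataset (the data is generated, not real):

```python
from sklearn.datasets import make_classification
from sklearn.metrics import recall_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic dataset with roughly 95% of class 0 and 5% of class 1
X, y = make_classification(n_samples=2000, weights=[0.95, 0.05], random_state=42)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, stratify=y, random_state=42)

for cw in (None, "balanced"):
    clf = DecisionTreeClassifier(max_depth=4, class_weight=cw, random_state=42)
    clf.fit(X_tr, y_tr)
    rec = recall_score(y_te, clf.predict(X_te))  # recall on the minority class (label 1)
    print(f"class_weight={cw}: minority-class recall = {rec:.2f}")
```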
When to Choose Something Else
graph TD
A["Classification Problem"] --> B{"Need interpretability?"}
B -->|Yes| C["Decision Tree"]
B -->|No| D{"Dataset size?"}
D -->|Small| E["Random Forest"]
D -->|Large| F["Gradient Boosting<br/>(XGBoost/LightGBM)"]
Summary
| Concept | Key Points |
|---|---|
| Splitting | Choose feature/threshold minimizing impurity |
| Gini Impurity | $1 - \sum p_k^2$, prob of misclassification |
| Entropy | $-\sum p_k \log_2 p_k$, bits needed |
| CART | Greedy binary splits, top-down |
| Pruning | Pre (max_depth) or Post (ccp_alpha) |
| Interpretability | Can explain each prediction |
References
- Breiman, L. et al. (1984). “Classification and Regression Trees”
- sklearn Decision Trees
- Quinlan, J.R. (1986). “Induction of Decision Trees”