Random Forest
Definition
Core Statement
Random Forest is an ensemble learning method that constructs a multitude of Decision Trees at training time. For classification, it outputs the mode (majority vote) of the trees' predicted classes; for regression, the mean of their predictions.
It combines Bagging (Bootstrap Aggregation) with Feature Randomness.
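A minimal from-scratch sketch of that statement, assuming scikit-learn's `DecisionTreeClassifier` and a toy dataset (the names `n_trees`, `votes`, etc. are illustrative, not part of any library API):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, n_features=10, random_state=0)

n_trees = 100
rng = np.random.default_rng(0)
trees = []
for _ in range(n_trees):
    # Bagging: draw a bootstrap sample (same size, with replacement)
    idx = rng.integers(0, len(X), size=len(X))
    tree = DecisionTreeClassifier(max_features="sqrt")  # column randomness per split
    tree.fit(X[idx], y[idx])
    trees.append(tree)

# Classification: majority vote across all trees
votes = np.stack([t.predict(X) for t in trees])  # shape (n_trees, n_samples)
y_hat = np.apply_along_axis(lambda v: np.bincount(v).argmax(), 0, votes)
```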
Purpose
- Reduce Variance: Individual trees are noisy and prone to overfitting; averaging the predictions of many de-correlated trees cancels out much of that noise.
- High Accuracy: Often the "gold standard" for tabular data (rivaled only by Gradient Boosting).
- Feature Importance: Provides a robust estimate of which variables matter.
The "Random" Ingredients
- Bagging (Row Randomness): Each tree is trained on a random sample of the data (with replacement). This means each tree sees slightly different data.
- Feature Subsampling (Column Randomness): At each split in a tree, the algorithm considers only a random subset of features (e.g., √p out of p features for classification). This forces the trees to be de-correlated: if one feature were very strong, normally ALL trees would use it first; the feature subsampling prevents this.
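In scikit-learn, both ingredients map directly to constructor arguments; a sketch of the relevant knobs (the values shown are illustrative, not recommendations):

```python
from sklearn.ensemble import RandomForestClassifier

rf = RandomForestClassifier(
    n_estimators=200,     # number of bootstrap-trained trees
    bootstrap=True,       # row randomness: sample rows with replacement (default)
    max_features="sqrt",  # column randomness: ~sqrt(p) features tried per split
    max_samples=None,     # None => each bootstrap sample has n_samples rows
    random_state=0,
)
```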
Random Forest vs Single Tree
| Property | Single Decision Tree | Random Forest |
|---|---|---|
| Bias | Low | Low |
| Variance | High (Overfits) | Low (Stable) |
| Interpretability | High (Visualizable) | Low (Black Box) |
| Speed | Fast | Slower (~100× the trees to train and evaluate) |
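One way to see the variance row empirically, assuming a synthetic dataset; the exact scores will vary, but the forest's fold-to-fold spread is typically much tighter:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=20, random_state=0)

tree_scores = cross_val_score(DecisionTreeClassifier(random_state=0), X, y, cv=5)
rf_scores = cross_val_score(RandomForestClassifier(random_state=0), X, y, cv=5)

# The forest's scores are usually higher and less spread out (lower variance)
print(f"Tree:   mean={tree_scores.mean():.3f}  std={tree_scores.std():.3f}")
print(f"Forest: mean={rf_scores.mean():.3f}  std={rf_scores.std():.3f}")
```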
Limitations & Pitfalls
- Extrapolation: RF predicts by averaging training labels, so it cannot predict values outside the range seen in training (e.g., if the max historical price was $100, it can never predict $110), whereas linear regression can. See the sketch after this list.
- Slow Prediction: Real-time applications might find 500 trees too slow to evaluate.
- Black Box: Hard to explain why it rejected a loan, other than "500 trees voted No".
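A sketch of the extrapolation pitfall on a 1-D toy problem (an assumed setup for illustration): the forest's prediction plateaus at the edge of the training range, while the linear model keeps going.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression

# Train on x in [0, 100] with a simple linear target y = 2x
X_train = np.arange(0, 101).reshape(-1, 1)
y_train = 2.0 * X_train.ravel()

rf = RandomForestRegressor(random_state=0).fit(X_train, y_train)
lin = LinearRegression().fit(X_train, y_train)

X_new = np.array([[150.0]])  # outside the training range
print(rf.predict(X_new))     # ~200: capped near max(y_train), cannot extrapolate
print(lin.predict(X_new))    # ~300: extrapolates along the fitted line
```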
Python Implementation
```python
from sklearn.ensemble import RandomForestClassifier
import pandas as pd

# 1. Fit model (100 trees, all CPU cores)
rf = RandomForestClassifier(n_estimators=100, max_depth=None, n_jobs=-1)
rf.fit(X_train, y_train)

# 2. Variable importance (impurity-based, normalized to sum to 1)
importance = pd.Series(rf.feature_importances_, index=X_train.columns)
importance.sort_values(ascending=False).plot(kind='barh')
```
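A typical follow-up once the model is fitted, assuming a held-out `X_test` exists:

```python
# 3. Predict on new data
y_pred = rf.predict(X_test)         # hard class labels (majority vote)
y_proba = rf.predict_proba(X_test)  # averaged per-tree class probabilities, shape (n, n_classes)
```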
Related Concepts
- Decision Tree - The building block.
- Gradient Boosting (XGBoost) - The sequential alternative (often slightly more accurate but harder to tune).
- Bootstrap Methods - The sampling technique used.
- Ensemble Methods