Decision Tree Regression - From Scratch - Python
In the realm of data science, decision trees are revered for their versatility. While often associated with classification tasks, decision trees spread their roots into the realm of regression, offering a robust approach to predicting continuous outcomes.
In this blog, we will embark on a journey through the captivating world of decision trees for regression, unraveling the intricacies, delving into the mathematics, and exploring the advantages and disadvantages of this powerful technique.
For Classification – Refer to this article
Flow of Article:
- What is Decision Tree Regressor?
- Mathematical Explanation
- Strength and Weakness
- Coding the Decision Tree from scratch
- Conclusion
What is a Decision Tree?
Decision trees are a prevalent machine learning algorithm known for their simplicity, interpretability, and effectiveness in both classification and regression tasks. They resemble the flowcharts we encounter in everyday life, guiding us through a series of decisions to reach a conclusion. In the realm of data science, decision trees excel at transforming complex datasets into understandable rules that can be applied to predict outcomes or make informed decisions. Research level information can be seen here.
Decision Tree Regressor
At its core, a decision tree for regression is a predictive model that maps input features to continuous output values. Unlike classification trees that predict categorical labels, regression trees aim to forecast numeric outcomes. Each leaf node of the tree represents a predicted value based on the combination of features encountered along the path from the root.
Anatomy of Decision Tree
A decision tree, like a living organism, has a well-defined structure:
Root Node: The decision tree’s genesis, the point where the journey begins.
Internal Nodes: Decision points that split the data based on specific criteria.
Leaf Nodes: The terminal points, where the final predictions or outcomes are revealed.
Branches: Connections between nodes, representing the different paths taken based on the decision rules.
Maths Behind It
The magic of decision trees lies in their ability to recursively split the data into subsets based on feature thresholds. This process is driven by mathematical criteria, typically involving an impurity function. For regression trees, the impurity is commonly measured using the Mean Squared Error (MSE). The tree seeks to minimize the MSE at each split, ensuring that the predicted values are as close as possible to the true values in each leaf.
The Impurity Function – Mean Squared Error (MSE)
The MSE for a node with N data points and target values y_1, y_2, …, y_N is calculated as:
$$\mathrm{MSE} = \frac{1}{N}\sum_{i=1}^{N}\left(y_i - \bar{y}\right)^2$$
where $\bar{y}$ (y_bar) is the mean of the target values in the node, which is also the value the node predicts. The goal is to minimize this measure of impurity, guiding the tree towards optimal splits.
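To make the formula concrete, here is a minimal sketch that computes the MSE of a node and the size-weighted MSE of its two children after a split. The helper names `node_mse` and `weighted_split_mse` and the target values are illustrative assumptions, not part of the article's code.

```python
import numpy as np

def node_mse(y):
    """Mean squared deviation of the targets from their own mean."""
    y = np.asarray(y, dtype=float)
    return float(np.mean((y - y.mean()) ** 2))

def weighted_split_mse(y_left, y_right):
    """Average of the two child MSEs, weighted by child size."""
    n_left, n_right = len(y_left), len(y_right)
    total = n_left + n_right
    return (n_left * node_mse(y_left) + n_right * node_mse(y_right)) / total

# Illustrative targets only, not taken from a real dataset
y_node = np.array([10.0, 12.0, 30.0, 32.0])

parent = node_mse(y_node)                             # impurity before splitting
children = weighted_split_mse(y_node[:2], y_node[2:])  # impurity after splitting

print(f"parent MSE: {parent}, weighted child MSE: {children}")
```

Separating the two low targets from the two high ones drops the impurity sharply, which is the kind of split the tree looks for.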
Advantages of Decision Trees for Regression:
Non-Linearity Handling: Decision trees can model complex, non-linear relationships in the data. This flexibility is particularly advantageous when dealing with datasets that don’t adhere to linear assumptions.
Interpretability: The transparent nature of decision trees allows for easy interpretation. Users can trace the decision-making process and understand how input features influence the predicted outcomes.
Robustness to Outliers: Splits depend only on the ordering of feature values, so extreme values in the input features have little effect on the tree structure. Outliers in the target variable can still pull leaf means and the MSE criterion, so this robustness is only partial.
Dissecting the Disadvantages:
Overfitting: Decision trees are prone to overfitting, capturing noise in the training data and producing overly complex models. Pruning techniques and careful parameter tuning are essential to mitigate this risk.
Sensitivity to Small Variations: Small changes in the input data can lead to different tree structures. This sensitivity can result in high variability between models trained on similar datasets.
Limited Extrapolation: Decision trees struggle with extrapolation, meaning they may not perform well on data points outside the range of the training data. This limitation should be considered when applying regression trees to real-world scenarios.
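The extrapolation limit can be seen with a one-split tree (a "stump") fitted to a perfectly linear relationship. This is a small sketch under stated assumptions: `fit_stump` and `predict_stump` are hypothetical helpers written for this illustration, and the data is synthetic.

```python
import numpy as np

def fit_stump(x, y):
    """Fit a depth-1 regression tree on a single feature.

    Assumes x contains at least two distinct values.
    """
    best = None
    values = np.unique(x)
    for threshold in values[:-1]:  # skip the maximum so both sides are non-empty
        left, right = y[x <= threshold], y[x > threshold]
        sse = np.sum((left - left.mean()) ** 2) + np.sum((right - right.mean()) ** 2)
        if best is None or sse < best[0]:
            best = (sse, threshold, left.mean(), right.mean())
    _, threshold, left_value, right_value = best
    return threshold, left_value, right_value

def predict_stump(stump, x_new):
    """Return the stored leaf value for each input."""
    threshold, left_value, right_value = stump
    return np.where(np.asarray(x_new) <= threshold, left_value, right_value)

# Synthetic training data: y = 2x for x in 1..10
x_train = np.arange(1.0, 11.0)
y_train = 2.0 * x_train

stump = fit_stump(x_train, y_train)
preds = predict_stump(stump, [10.0, 100.0, 1000.0])
print(preds)
```

All three inputs fall in the right-hand leaf, so they receive the same prediction even though the underlying relationship keeps growing. A deeper tree adds more steps inside the training range but behaves the same way outside it.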
Python Code Implementation
Now, we’ll embark on a journey to demystify the process of building a decision tree for regression from scratch, without relying on machine learning libraries. We’ll dive into the code, step by step, and explore how each component contributes to the creation of this powerful predictive model.
Defining Dataset
Our first block of code involves loading the dataset and splitting it into features and target variables. In our example, we’ll assume a housing production dataset with numerical features and a numerical target variable.
import numpy as np
import pandas as pd

# Load the dataset (replace 'your_dataset.csv' with the actual file)
# Assume the dataset has columns 'feature1', 'feature2', …, 'target'
data = pd.read_csv('your_dataset.csv')

# Split the data into features and target variable
X = data.drop('target', axis=1)
y = data['target']
Node Structure: The Heart of the Tree:
Now, let’s introduce the concept of a node. In our decision tree, a node represents a point where the tree makes a decision. We’ve defined a Node class with attributes like depth, feature_index, threshold, and value. The left and right attributes represent the two branches that result from a decision.
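class Node:
    def __init__(self, depth, max_depth=None):
        self.depth = depth
        self.max_depth = max_depth
        self.feature_index = None  # column used for the split (internal nodes)
        self.threshold = None      # split value (internal nodes)
        self.value = None          # prediction (leaf nodes)
        self.left = None
        self.right = None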
Impurity Measure: Mean Squared Error (MSE):
To guide the tree’s decision-making process, we need a measure of impurity. For regression tasks, Mean Squared Error (MSE) is a suitable metric. The calculate_mse function computes the MSE of a given set of target values.
# Mean Squared Error (MSE) as the impurity measure
def calculate_mse(y):
    if len(y) == 0:
        return 0
    mean = np.mean(y)
    return np.mean((y - mean) ** 2)
Data Splitting: The Foundation of Decision Making:
The split_data function is crucial for determining how the data is divided at each node based on a chosen feature and threshold. This function is fundamental to the decision tree’s ability to make informed splits.
# Split the data based on a given feature and threshold
def split_data(X, y, feature_index, threshold):
    left_mask = X[:, feature_index] <= threshold
    right_mask = ~left_mask
    return X[left_mask], X[right_mask], y[left_mask], y[right_mask]
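As a quick check of how the boolean masks behave, the sketch below restates the same function so it runs on its own and applies it to a tiny hand-made array. The values in `X_demo` and `y_demo` are made up for illustration.

```python
import numpy as np

def split_data(X, y, feature_index, threshold):
    # Rows whose feature value is <= threshold go left, the rest go right
    left_mask = X[:, feature_index] <= threshold
    right_mask = ~left_mask
    return X[left_mask], X[right_mask], y[left_mask], y[right_mask]

# Four samples, two features (illustrative values)
X_demo = np.array([[1.0, 50.0],
                   [2.0, 40.0],
                   [3.0, 30.0],
                   [4.0, 20.0]])
y_demo = np.array([100.0, 110.0, 200.0, 210.0])

X_left, X_right, y_left, y_right = split_data(X_demo, y_demo, feature_index=0, threshold=2.0)
print("left targets:", y_left)
print("right targets:", y_right)
```

Because the mask is built from one column but applied to whole rows, the features and targets stay aligned on both sides of the split.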
Finding the Best Split:
The find_best_split function is the engine of the decision tree. It iterates over features and thresholds, searching for the split that minimizes the overall MSE. This function plays a pivotal role in constructing an effective decision tree.
# Find the best split for a node
def find_best_split(X, y):
    m, n = X.shape
    if m <= 1:
        return None, None
    best_mse = float('inf')
    best_feature = None
    best_threshold = None
    for feature_index in range(n):
        thresholds = np.unique(X[:, feature_index])
        for threshold in thresholds:
            X_left, X_right, y_left, y_right = split_data(X, y, feature_index, threshold)
            # Skip splits that leave one side empty (e.g. the largest value)
            if len(y_left) == 0 or len(y_right) == 0:
                continue
            # Weight each child's MSE by its share of the samples
            mse = (len(y_left) * calculate_mse(y_left)
                   + len(y_right) * calculate_mse(y_right)) / m
            if mse < best_mse:
                best_mse = mse
                best_feature = feature_index
                best_threshold = threshold
    return best_feature, best_threshold
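The search above uses every observed feature value as a candidate threshold. A common variation, shown here only as an optional sketch and not as what the article's code does, is to test the midpoints between consecutive sorted unique values, so that each candidate separates at least one sample to each side. The helper name `candidate_thresholds` is an assumption made for this example.

```python
import numpy as np

def candidate_thresholds(column):
    """Midpoints between consecutive sorted unique values of a feature column."""
    values = np.unique(column)  # np.unique returns sorted unique values
    return (values[:-1] + values[1:]) / 2.0

column = np.array([3.0, 1.0, 2.0, 2.0, 5.0])
midpoints = candidate_thresholds(column)
print(midpoints)

# A constant column yields no candidates, so it can never be chosen for a split
constant = candidate_thresholds(np.array([7.0, 7.0]))
print(len(constant))
```

With midpoints, a feature that has only one distinct value produces an empty candidate list, which removes the need for a separate empty-child check.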
Building the Tree:
With the foundation laid, it’s time to construct the decision tree itself. The build_tree function recursively builds the tree, considering depth and a maximum depth parameter to control the tree’s size.
# Build the decision tree recursively
def build_tree(X, y, depth, max_depth=None):
    # Stop when the depth limit is reached or all targets are identical
    if depth == max_depth or len(np.unique(y)) == 1:
        leaf = Node(depth, max_depth)
        leaf.value = np.mean(y)
        return leaf
    feature_index, threshold = find_best_split(X, y)
    # No usable split was found: make this node a leaf
    if feature_index is None:
        leaf = Node(depth, max_depth)
        leaf.value = np.mean(y)
        return leaf
    node = Node(depth, max_depth)
    node.feature_index = feature_index
    node.threshold = threshold
    X_left, X_right, y_left, y_right = split_data(X, y, feature_index, threshold)
    node.left = build_tree(X_left, y_left, depth + 1, max_depth)
    node.right = build_tree(X_right, y_right, depth + 1, max_depth)
    return node
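Once a tree is built, it helps to read it back as nested if/else rules. The following sketch assumes nodes expose the same attributes as the Node class above (feature_index, threshold, value, left, right). The `describe_tree` and `make_node` helpers and the feature name "rooms" are hypothetical, and the demo tree is assembled by hand rather than trained.

```python
from types import SimpleNamespace

def describe_tree(node, feature_names=None, indent=""):
    """Return the tree as a list of indented text lines."""
    # A node with no children is a leaf and carries the prediction
    if node.left is None and node.right is None:
        return [f"{indent}predict {node.value:.2f}"]
    if feature_names:
        name = feature_names[node.feature_index]
    else:
        name = f"feature[{node.feature_index}]"
    lines = [f"{indent}if {name} <= {node.threshold}:"]
    lines += describe_tree(node.left, feature_names, indent + "    ")
    lines.append(f"{indent}else:")
    lines += describe_tree(node.right, feature_names, indent + "    ")
    return lines

def make_node(feature_index=None, threshold=None, value=None, left=None, right=None):
    """Stand-in for the article's Node class, used only for this demo."""
    return SimpleNamespace(feature_index=feature_index, threshold=threshold,
                           value=value, left=left, right=right)

# Hand-built tree with one split and two leaves
demo_tree = make_node(feature_index=0, threshold=2.5,
                      left=make_node(value=105.0),
                      right=make_node(value=205.0))

report = describe_tree(demo_tree, feature_names=["rooms"])
print("\n".join(report))
```

The same function should work on a tree returned by build_tree, since it only relies on the attribute names.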
Making Predictions:
With the decision tree constructed, we need to make predictions. The predict_sample and predict functions traverse the tree to predict the target variable for individual data points or entire datasets.
# Make predictions for a single data point
def predict_sample(tree, x):
    # A node without children is a leaf: return its stored value
    if tree.left is None and tree.right is None:
        return tree.value
    if x[tree.feature_index] <= tree.threshold:
        return predict_sample(tree.left, x)
    else:
        return predict_sample(tree.right, x)

# Make predictions for a set of data points
def predict(tree, X):
    return np.array([predict_sample(tree, x) for x in X])
Evaluation:
Finally, we evaluate the performance of our decision tree on the test set. The calculate_mse_predictions function computes the Mean Squared Error between the true and predicted values.
# scikit-learn is used only to split the data; the tree itself is built from scratch
from sklearn.model_selection import train_test_split

# Mean Squared Error for predictions
def calculate_mse_predictions(y_true, y_pred):
    return np.mean((y_true - y_pred) ** 2)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Build the decision tree
max_depth = 5  # Set the maximum depth of the tree
tree = build_tree(X_train.values, y_train.values, depth=0, max_depth=max_depth)

# Make predictions on the test set
y_pred = predict(tree, X_test.values)

# Evaluate the model
mse = calculate_mse_predictions(y_test.values, y_pred)
print(f'Mean Squared Error: {mse}')
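A raw MSE is hard to judge without a reference point. One option, sketched here as an addition rather than part of the original walkthrough, is to compare the model against a baseline that always predicts the mean, which is what the R² score expresses. The function name `r2_score_manual` and the demo arrays are assumptions for illustration.

```python
import numpy as np

def r2_score_manual(y_true, y_pred):
    """1 - (residual sum of squares / total sum of squares).

    Assumes y_true is not constant, otherwise the denominator is zero.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)

# Illustrative values standing in for y_test and y_pred
y_true_demo = np.array([3.0, 5.0, 7.0, 9.0])
y_pred_demo = np.array([4.0, 4.0, 8.0, 8.0])

# Baseline: predict the mean of the true values for every sample
baseline_pred = np.full_like(y_true_demo, y_true_demo.mean())

model_r2 = r2_score_manual(y_true_demo, y_pred_demo)
baseline_r2 = r2_score_manual(y_true_demo, baseline_pred)
print(f"model R^2: {model_r2:.2f}, baseline R^2: {baseline_r2:.2f}")
```

A score near 1 means the predictions explain most of the variance in the targets, while a score at or below 0 means the model does no better than predicting the mean.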
Conclusion:
In this journey through the code, we’ve uncovered the intricacies of building a decision tree for regression from scratch. Each function and component contributes to the tree’s decision-making prowess, from impurity measures to data splitting and recursive tree construction. Armed with this understanding, you’re equipped to
Applications of Decision Trees
Decision trees find applications in a wide range of domains:
Fraud detection: Identifying patterns in financial transactions indicative of fraudulent activity.
Medical diagnosis: Assisting in diagnostic decision-making based on patient data.
Customer segmentation: Grouping customers based on their characteristics for targeted marketing campaigns.
Risk assessment: Evaluating the likelihood of events like loan defaults or insurance claims.
Most important Interview Questions:
- What is a decision tree?
A decision tree is a machine learning algorithm that makes predictions by recursively splitting the data into smaller and smaller subsets, based on certain criteria, until a final prediction is made.
- What are the different types of decision trees?
There are two main types of decision trees: classification trees and regression trees. Classification trees are used to predict categorical outcomes, while regression trees are used to predict continuous values.
- What is information gain?
Information gain is a measure of how much splitting a set of data on a particular feature reduces the uncertainty about the target variable. The higher the information gain, the more relevant the feature is to the prediction task.
- What is entropy?
Entropy is a measure of the uncertainty in a set of data. The higher the entropy, the more uncertain the data is.
- What is Gini impurity?
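Gini impurity is another measure of the uncertainty in a set of data. It is the probability that a randomly chosen sample would be mislabeled if it were labeled at random according to the class distribution in the node. It usually leads to splits similar to those chosen with entropy, but it is cheaper to compute because it needs no logarithm.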
- What is overfitting?
Overfitting occurs when a decision tree memorizes the training data and struggles to generalize well to unseen data. This can be prevented by pruning the tree, which involves removing nodes that do not improve the model’s performance.
- What is pruning?
Pruning is a technique for preventing overfitting in decision trees. It involves removing nodes that do not improve the model’s performance.
- How can you improve the interpretability of a decision tree?
There are a few things you can do to improve the interpretability of a decision tree:
- Limit the depth of the tree.
- Limit the number of features.
- Use descriptive feature names.
- What are some of the strengths of decision trees?
Decision trees are known for their simplicity, interpretability, and effectiveness. They are also non-parametric, which means that they do not make assumptions about the underlying data distribution.
- What are some of the weaknesses of decision trees?
Decision trees are prone to overfitting and high variance. They also extrapolate poorly, since predictions for inputs outside the range of the training data are limited to the values stored in the leaves.
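The entropy, Gini impurity, and information gain answers above apply to classification trees and can be sketched in a few lines. This is a minimal illustration with made-up labels; the function names are assumptions chosen for readability.

```python
import numpy as np

def class_probabilities(labels):
    """Relative frequency of each class in the label array."""
    _, counts = np.unique(labels, return_counts=True)
    return counts / counts.sum()

def entropy(labels):
    """Shannon entropy in bits."""
    p = class_probabilities(labels)
    return float(-np.sum(p * np.log2(p)))

def gini(labels):
    """Gini impurity: 1 minus the sum of squared class probabilities."""
    p = class_probabilities(labels)
    return float(1.0 - np.sum(p ** 2))

def information_gain(parent, left, right):
    """Parent entropy minus the size-weighted entropy of the two children."""
    n = len(parent)
    weighted = (len(left) / n) * entropy(left) + (len(right) / n) * entropy(right)
    return entropy(parent) - weighted

# Two classes, evenly mixed (illustrative labels)
labels = np.array(["a", "a", "b", "b"])

h = entropy(labels)
g = gini(labels)
gain = information_gain(labels, labels[:2], labels[2:])  # a split that separates the classes
print(f"entropy: {h}, gini: {g}, information gain: {gain}")
```

For an evenly mixed two-class node both measures are at their maximum, and a split that separates the classes perfectly recovers all of the parent's entropy as information gain.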