
Decision Tree - Explained with Python Sklearn

Decision trees are a prevalent machine learning algorithm known for their simplicity, interpretability, and effectiveness in both classification and regression tasks. They resemble the flowcharts we encounter in everyday life, guiding us through a series of decisions to reach a conclusion. In the realm of data science, decision trees excel at transforming complex datasets into understandable rules that can be applied to predict outcomes or make informed decisions.

In this blog, we will delve into the fundamentals of decision trees, explore their mathematical foundations, discuss their strengths and weaknesses, evaluate performance metrics, and finally showcase an implementation in Python with the scikit-learn library.

Flow of Article:

  1. What is a Decision Tree?
  2. Mathematical Explanation
  3. Classification and Regression
  4. Strengths and Weaknesses
  5. Python Code Implementation
  6. Real-life Uses
  7. Interview Questions

You may also want to explore Logistic Regression, Best 10 Regression Model Coded, Linear Regression, Transfer Learning using Regression, or Automated EDA.

 

What is a Decision Tree?

A decision tree is a supervised learning model that makes a prediction by asking a sequence of simple questions about the input features. Starting from a single root question, each answer routes an example down a branch to the next question until it reaches a leaf, which holds the final prediction. Because every prediction can be traced back through the questions that produced it, decision trees turn complex datasets into rules that are easy to read, explain, and apply to both classification and regression problems.

 

Anatomy of a Decision Tree

A decision tree, like a living organism, has a well-defined structure:

  • Root Node: The decision tree’s genesis, the point where the journey begins.

  • Internal Nodes: Decision points that split the data based on specific criteria.

  • Leaf Nodes: The terminal points, where the final predictions or outcomes are revealed.

  • Branches: Connections between nodes, representing the different paths taken based on the decision rules.

 

Maths Behind It

Decision trees are a type of machine learning algorithm that can be used for both classification and regression tasks. They work by repeatedly splitting the data into smaller and smaller subsets, based on certain criteria, until a final prediction is made.

One of the classic algorithms for building a decision tree is ID3 (Iterative Dichotomiser 3). ID3 uses a measure called information gain to determine which feature to split on at each node of the tree. (Scikit-learn itself uses an optimized version of the closely related CART algorithm, but the information-gain intuition carries over.)

Now, how is a split actually decided? Which feature, and at what threshold?

Firstly, we calculate the entropy of the set S according to the formula:

H(S) = − Σ_{x ∈ X} p(x) log₂ p(x)

Where,

  • S – The dataset for which entropy is being calculated.
  • X – The set of classes in S
  • p(x) – The proportion of the number of elements in class x to the number of elements in set S

When H(S)=0, the set S is perfectly classified (i.e. all elements in S are of the same class).
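
To make this concrete, here is a minimal sketch of the entropy calculation (the function name entropy and the spam/ham labels are purely illustrative, not part of the original example):

# Entropy H(S) = -sum over classes x of p(x) * log2(p(x))
import math
from collections import Counter

def entropy(labels):
    total = len(labels)
    counts = Counter(labels)
    return -sum((count / total) * math.log2(count / total) for count in counts.values())

# A perfectly pure set has entropy 0, while a 50/50 split has entropy 1
print(entropy(["spam", "spam", "spam"]), entropy(["spam", "ham", "spam", "ham"]))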

After calculating entropy, we compute the information gain associated with each candidate feature.

Information gain is a measure of how much splitting a set of data on a particular feature reduces the uncertainty about the target variable. Information gain IG(A) is the difference in entropy from before to after the set S is split on an attribute A; in other words, it tells us how much uncertainty in S was removed by splitting on A. The formula for information gain is:

IG(A, S) = H(S) − Σ_{t ∈ T} p(t) H(t)

Where,

  • H(S) – Entropy of set S
  • T – The subsets created from splitting set S by attribute A.
  • p(t) – The proportion of the number of elements in t to the number of elements in set S
  • H(t) – Entropy of subset t

In ID3, the information gain is computed for each remaining attribute at every node, and the attribute with the largest information gain is used to split the set S on that iteration.
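
As a rough sketch of how ID3 would score one candidate attribute (the helper names and the tiny spam example are made up for illustration):

import math
from collections import Counter

# Entropy of a set of class labels
def subset_entropy(labels):
    total = len(labels)
    return -sum((count / total) * math.log2(count / total) for count in Counter(labels).values())

# IG(A) = H(S) - sum over subsets t of p(t) * H(t)
def information_gain(labels, attribute_values):
    gain = subset_entropy(labels)
    for value in set(attribute_values):
        subset = [label for label, v in zip(labels, attribute_values) if v == value]
        gain -= (len(subset) / len(labels)) * subset_entropy(subset)
    return gain

# Splitting on "contains_link" yields pure subsets, so the gain equals H(S), about 0.97
y = ["spam", "spam", "ham", "ham", "spam"]
contains_link = [1, 1, 0, 0, 1]
print(information_gain(y, contains_link))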

 

How to Use It for Classification and Regression

Classification Trees

Classification trees specialize in predicting categorical outcomes, like classifying emails as spam or not spam. To achieve this, the tree progressively partitions the data based on specific features, leading to a specific class label at each leaf node.

Regression Trees

Regression trees, on the other hand, are masters of predicting continuous outcomes, like forecasting house prices or estimating customer lifetime value. The tree structure remains similar, but the leaf nodes no longer contain class labels. Instead, they hold the average value of the target variable for the corresponding subset of data.

During training, the algorithm selects the feature and split threshold that minimize the mean squared error (MSE) at each internal node. MSE measures the average squared difference between the predicted values and the actual values.
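
As a quick illustration (the diabetes dataset and the max_depth value here are arbitrary choices, not part of the original example), a regression tree can be fit and scored with MSE like this:

from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error

# Load a small regression dataset and split it
X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Limit the depth so the tree stays readable and overfits less
reg = DecisionTreeRegressor(max_depth=3, random_state=42)
reg.fit(X_train, y_train)

# Each leaf predicts the average target value of its training samples
y_pred = reg.predict(X_test)
print("Test MSE:", mean_squared_error(y_test, y_pred))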

 

Strengths of Decision Trees

  • Interpretability: Decision trees are one of the most interpretable machine learning algorithms. This is because their structure is easy to understand, even for non-experts. This makes them a good choice for tasks where it is important to understand how the model is making its predictions.

  • Non-parametric: Decision trees do not make any assumptions about the underlying data distribution. This makes them versatile for various data types, including categorical, numerical, and ordinal data.

  • Robust to outliers: Decision trees are less sensitive to outliers than many other algorithms, because splits depend on the ordering of feature values rather than on their exact magnitudes.

  • Efficient training: Decision trees are relatively fast to train, even with large datasets, because splits are chosen greedily one node at a time rather than through iterative global optimization.

 

Weaknesses of Decision Trees

  • Overfitting: Decision trees are prone to overfitting, where they memorize the training data and generalize poorly to unseen data, because an unpruned tree keeps growing until it fits the training set almost exactly. (A short scikit-learn sketch showing how to rein this in follows this list.)

  • High variance: Small changes in the training data can produce a very different tree and therefore different predictions, which makes a single tree less stable than algorithms such as logistic regression.

  • Many features: With a large number of features, the tree tends to grow deep and complex, which hurts both interpretability and generalization unless the feature set or the tree size is constrained.

  • Continuous values: Regression trees predict piecewise-constant values (the average of each leaf), so they struggle to capture smooth relationships and cannot extrapolate beyond the range of the training data.
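
To address the overfitting weakness above, here is a minimal sketch of pre-pruning and post-pruning with scikit-learn (the Iris data and the specific parameter values are only illustrative):

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# An unconstrained tree usually fits the training data almost perfectly
full_tree = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

# Pre-pruning: cap the depth and require a minimum number of samples per leaf
pre_pruned = DecisionTreeClassifier(max_depth=3, min_samples_leaf=5, random_state=42).fit(X_train, y_train)

# Post-pruning: cost-complexity pruning removes branches that add little purity gain
post_pruned = DecisionTreeClassifier(ccp_alpha=0.01, random_state=42).fit(X_train, y_train)

for name, tree in [("unpruned", full_tree), ("pre-pruned", pre_pruned), ("post-pruned", post_pruned)]:
    print(name, "train:", tree.score(X_train, y_train), "test:", tree.score(X_test, y_test))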

Python Code Implementation

Let’s implement it using the scikit-learn library with a simple example:

# Import necessary libraries
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Load the Iris dataset
iris = load_iris()

# Convert the dataset to a pandas DataFrame
iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)
iris_df['target'] = iris.target

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(iris_df.drop('target', axis=1), iris_df['target'], test_size=0.2, random_state=42)

# Create a decision tree classifier
clf = DecisionTreeClassifier(random_state=42)

# Train the decision tree classifier
clf.fit(X_train, y_train)

# Make predictions on the test set
y_pred = clf.predict(X_test)

# Evaluate the accuracy of the decision tree classifier
accuracy = clf.score(X_test, y_test)
print("Accuracy:", accuracy)

# Visualize the decision tree
plt.figure(figsize=(12, 8))
plot_tree(clf, feature_names=iris.feature_names, class_names=iris.target_names)
plt.show()

The plot_tree function renders the fitted tree, so you can trace every split from the root node down to the leaf that produces each prediction.
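
Accuracy alone can hide per-class behavior, so it is worth inspecting the confusion matrix and per-class metrics as well. This short continuation assumes the clf, y_test, y_pred and iris objects from the code above:

from sklearn.metrics import classification_report, confusion_matrix

# Per-class precision, recall and F1 for the Iris classifier trained above
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred, target_names=iris.target_names))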

 

Applications of Decision Trees

Decision trees find applications in a wide range of domains:

  • Fraud detection: Identifying patterns in financial transactions indicative of fraudulent activity.

  • Medical diagnosis: Assisting in diagnostic decision-making based on patient data.

  • Customer segmentation: Grouping customers based on their characteristics for targeted marketing campaigns.

  • Risk assessment: Evaluating the likelihood of events like loan defaults or insurance claims.

 

Most important Interview Questions:

  1. What is a decision tree?

A decision tree is a machine learning algorithm that makes predictions by recursively splitting the data into smaller and smaller subsets, based on certain criteria, until a final prediction is made.

  2. What are the different types of decision trees?

There are two main types of decision trees: classification trees and regression trees. Classification trees are used to predict categorical outcomes, while regression trees are used to predict continuous values.

  3. What is information gain?

Information gain is a measure of how much splitting a set of data on a particular feature reduces the uncertainty about the target variable. The higher the information gain, the more relevant the feature is to the prediction task.

  4. What is entropy?

Entropy is a measure of the uncertainty in a set of data. The higher the entropy, the more uncertain the data is.

  5. What is Gini impurity?

Gini impurity is another measure of the uncertainty in a set of data, computed as one minus the sum of the squared class proportions. It behaves very similarly to entropy but avoids computing logarithms, which is why it is the default splitting criterion in scikit-learn.
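
For intuition, here is a tiny sketch comparing the two impurity measures on the same class proportions (the example proportions are arbitrary):

import math

# Gini impurity: 1 - sum of squared class proportions
def gini(proportions):
    return 1 - sum(p ** 2 for p in proportions)

# Entropy: -sum of p * log2(p) over the non-zero class proportions
def entropy(proportions):
    return -sum(p * math.log2(p) for p in proportions if p > 0)

# Both measures are 0 for a pure node and largest for a 50/50 split
for p in [(1.0, 0.0), (0.9, 0.1), (0.5, 0.5)]:
    print(p, "gini:", round(gini(p), 3), "entropy:", round(entropy(p), 3))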

  6. What is overfitting?

Overfitting occurs when a decision tree memorizes the training data and struggles to generalize well to unseen data. This can be prevented by pruning the tree, which involves removing nodes that do not improve the model’s performance.

  7. What is pruning?

Pruning is a technique for preventing overfitting in decision trees. Pre-pruning limits how far the tree can grow in the first place (for example, by capping its depth), while post-pruning grows the full tree and then removes branches that do not improve performance on held-out data.

  8. How can you improve the interpretability of a decision tree?

There are a few things you can do to improve the interpretability of a decision tree:

  • Limit the depth of the tree.
  • Limit the number of features.
  • Use descriptive feature names.
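
As a small illustration of these points, a shallow tree with named features can also be printed as plain-text rules using scikit-learn's export_text (the Iris data here is just an example):

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()

# A depth-limited tree with descriptive feature names stays easy to read
clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(iris.data, iris.target)
print(export_text(clf, feature_names=list(iris.feature_names)))
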
  9. What are some of the strengths of decision trees?

Decision trees are known for their simplicity, interpretability, and effectiveness. They are also non-parametric, which means that they do not make assumptions about the underlying data distribution.

  10. What are some of the weaknesses of decision trees?

Decision trees are prone to overfitting and high variance, and they can become overly complex and hard to interpret when many features are involved.