D is for Decision Tree
What is a decision tree and how do we implement it in Python?
What is a Decision Tree?
A decision tree is a supervised learning algorithm that can be used for either classification or regression tasks.
How do they work?
While the math is fairly complicated, the concept is incredibly straightforward:
- start with all observations in one group
- identify a binary question (e.g., yes/no, over/under) that splits the observations into two groups which are as distinct from each other as possible
- repeat step two on each subgroup until every subgroup is homogeneous or some other stopping criterion has been met (a minimal sketch of a single splitting step follows below)
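To make the "as distinct as possible" part concrete, here is a minimal sketch of one splitting step using Gini impurity. It's illustrative only (plain NumPy rather than scikit-learn's actual implementation), and the function names are my own:
import numpy as np

def gini(labels):
    # Gini impurity of a group of 0/1 labels.
    if len(labels) == 0:
        return 0.0
    p = labels.mean()  # proportion of the positive class
    return 1.0 - (p ** 2 + (1 - p) ** 2)

def best_split(feature, labels):
    # Find the threshold that leaves the two resulting groups as pure as possible.
    best_threshold, best_impurity = None, float("inf")
    for threshold in np.unique(feature):
        left, right = labels[feature <= threshold], labels[feature > threshold]
        # weighted average impurity of the two child groups
        impurity = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
        if impurity < best_impurity:
            best_threshold, best_impurity = threshold, impurity
    return best_threshold, best_impurity
A full tree simply applies this to every feature, keeps the best split, and then recurses on each of the two resulting groups.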
What are the advantages?
Transparency: They are considered a white/glass-box algorithm because you can see exactly which decisions the algorithm made, which leads to,
Interpretability: Again, since you can see which decisions were made, it's easy to comprehend and explain the predictions.
Ease: Decision trees do not require feature scaling or normalization1.
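As a quick illustration of that transparency, here is a minimal sketch using scikit-learn's export_text; the toy iris dataset and the depth of two are arbitrary choices of mine, just to show the printed rules:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# Print the learned if/else rules, one split per line.
print(export_text(clf, feature_names=list(iris.feature_names)))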
What are the disadvantages?
Overfitting: Since they try to find the purest groups, they have a tendency to overfit.
Non-Linear Data: Relationships between features are not directly considered, since each split looks at only one feature at a time.
How do we train a decision tree?
I'm a task-based person, so let's set up a problem.
Step 0: Frame the Problem
Can we determine who would have survived the Titanic?
Step 1: Collect/Load our Data
Now let's get the data.
import seaborn as sns
sns.set(palette="colorblind")
titanic = sns.load_dataset("titanic")
Do we have any missing values?
titanic.isnull().sum()
Yes we do.
What are the data types?
titanic.dtypes
We have a mix.
Let's get a feel for the values by looking at the first five rows.
titanic.head().T
Hmmm. Looks like seaborn has already done some feature engineering (i.e., alone is a combination of sibsp and parch). Much appreciated!
Now, since this is a toy dataset, I'm not going to do much EDA but, if you want to learn more, you can find some good examples of it here and here.
However, here is an obligatory bar chart of who survived.
sns.catplot(y="survived",
hue="sex",
kind="count",
data=titanic);
Since this is a quick example, let's deal with the missing values and mixed data types by:
- dropping any columns that contain missing values
titanic = titanic.dropna(axis=1)
titanic.info()
- keeping only numerical features
titanic = titanic.select_dtypes(include=['float64', 'int64'])
titanic.info()
Cool.
Time to make the train test split.
from sklearn.model_selection import train_test_split

# The target is survived; the features are every other (numeric) column.
x = titanic.iloc[:, 1:]
y = titanic.survived
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=42)
and then train a decision tree classifier:
from sklearn.tree import DecisionTreeClassifier
tree_clf = DecisionTreeClassifier(random_state=0)
tree_clf = tree_clf.fit(x_train, y_train)
Now, let's visualize the tree.
from sklearn.tree import export_graphviz
from IPython.display import Image
import pydotplus
dot_data = export_graphviz(tree_clf,
out_file=None,
feature_names=x_train.columns,
class_names=['perished', 'survived'],
rounded=True,
filled=True)
# Draw graph
graph = pydotplus.graph_from_dot_data(dot_data)
# Show graph
Image(graph.create_png())
That is a gigantic tree!!! In reality, we'd probably prune it by limiting the depth of the tree in an attempt to avoid overfitting.
But let's see how it performs.
#Predict the response for test dataset
y_pred = tree_clf.predict(x_test)
from sklearn.metrics import accuracy_score
# Model Accuracy, how often is the classifier correct?
print("The training accuracy of our model is %.2f." % (accuracy_score(y_train, tree_clf.predict(x_train))*100))
print("The test accuracy of our model is %.2f." % (accuracy_score(y_test, y_pred) *100))
Yep, definitely overfitting.
So, let's prune our tree by limiting the max_depth to three:
tree_clf = DecisionTreeClassifier(max_depth=3,
random_state=0)
retrain our model,
tree_clf = tree_clf.fit(x_train, y_train)
and visualize it one more time.
dot_data = export_graphviz(tree_clf,
out_file=None,
feature_names=x_train.columns,
class_names=["perished", "survived"],
rounded=True,
filled=True)
# Draw graph
graph = pydotplus.graph_from_dot_data(dot_data)
# Show graph
Image(graph.create_png())
That is a much easier tree to comprehend but does it perform better?
#Predict the response for test dataset
y_pred = tree_clf.predict(x_test)
# Model Accuracy, how often is the classifier correct?
print("The training accuracy of our model is %.2f." % (accuracy_score(y_train, tree_clf.predict(x_train))*100))
print("The test accuracy of our model is %.2f." % (accuracy_score(y_test, y_pred) *100))
Indeed it does!
Now imagine how much better it would perform if we added gender and age to the model.
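For the curious, here is a rough sketch of how that might look, reusing the imports from above; the column selection, the 0/1 encoding of sex, and the median-age imputation are my own assumptions rather than anything we've done so far:
# Reload the full dataset so we can recover the columns we dropped earlier.
titanic_full = sns.load_dataset("titanic")

features = titanic_full[["pclass", "sibsp", "parch", "fare", "sex", "age"]].copy()
features["sex"] = (features["sex"] == "female").astype(int)          # simple 0/1 encoding
features["age"] = features["age"].fillna(features["age"].median())   # fill the missing ages

x_tr, x_te, y_tr, y_te = train_test_split(features, titanic_full.survived,
                                          test_size=0.3, random_state=42)

tree_clf2 = DecisionTreeClassifier(max_depth=3, random_state=0).fit(x_tr, y_tr)
print("The test accuracy with sex and age is %.2f." % (accuracy_score(y_te, tree_clf2.predict(x_te)) * 100))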
Summary
Decision trees are incredibly powerful because of their:
- ease of use
- flexibility
- simplicity to understand
However, overfitting can be a real issue so, as with all algorithms, compare the training and test accuracies and tune the hyperparameters as needed.
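One common way to do that tuning is a cross-validated grid search, sketched below with scikit-learn's GridSearchCV; the parameter grid is just an illustrative guess, not a recommendation:
from sklearn.model_selection import GridSearchCV

# Try a handful of depths and leaf sizes and keep the best cross-validated combination.
param_grid = {"max_depth": [2, 3, 4, 5, None],
              "min_samples_leaf": [1, 5, 10]}
search = GridSearchCV(DecisionTreeClassifier(random_state=0),
                      param_grid, cv=5, scoring="accuracy")
search.fit(x_train, y_train)

print(search.best_params_)
print("The cross-validated accuracy is %.2f." % (search.best_score_ * 100))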
Also, some of the splits defy common sense; for example, how can anyone have half a sibling or half a parent?2
Therefore, while the trees may be easy to interpret, they might not always be logical so, as always, we have to be prepared to defend what we've made.
Happy coding!
Further Reading
Hands-On Machine Learning with Scikit-Learn and TensorFlow: 'Chapter 6 Decision Trees'
DataCamp Decision Tree Classifier Tutorial
Understanding the Mathematics Behind Decision Trees
Visualizing Decision Trees
Footnotes
1. Géron, 2019, page 177.
2. If anyone knows how to set the tree to only split on whole values, please write it in the comments below.