Decision Trees — Understanding Explainable AI

Grant Holtes
Towards Data Science
5 min read · Mar 1, 2018

Explainable AI, or XAI, is a sub-category of AI where the decisions made by a model can be interpreted by humans, as opposed to “black box” models. As AI moves from correcting our spelling and targeting ads to driving our cars and diagnosing patients, the need to verify and justify the conclusions models reach is beginning to be prioritised.

To begin to delve into the field, let's look at one simple XAI model: the decision tree. Decision trees are easy to read and even mimic a human approach to decision making by breaking the choice into many small sub-choices. A simple example is how a student might evaluate local universities when they leave high school. Given the student has a course in mind, a simple decision-making process could be a short series of questions: Is the course offered? Is the campus close enough to home? Does the student meet the entry requirements?

How the student reached their conclusion can be easily justified if a third party has access to the “model” and the required variables.

This same structure can be applied with supervised learning, with the goal of creating a decision tree which best describes the training data. This model can then be used to understand the relationship between the variables or in a predictive application.

The Algorithm

The decision tree is constructed through a recursive process; a minimal code sketch of the recursion follows the steps below.

  1. Estimate which independent variable gives the largest information gain. Information gain is the reduction in entropy of the dependent variable when the state of the independent variable is known.
    Lot of big words there.
    Essentially this measures how much more organised the dependent variable becomes when we split the data into groups according to the independent variable’s value.
  2. The independent variable which provides the greatest increase in organisation is selected and the dataset is split according to this variable.
  3. At this point one of three things must be true for each branch:
    - The dependent variable now takes only one value. In this case this branch of the tree is complete, and we have reached our “decision”.
    - The dependent variable takes more than one value. Here we simply go back to step one and try to narrow it down further.
    - The dependent variable takes more than one value, but we have no more independent variables to split the data by. Here we simply say what values the decision could take and estimate a probability for each according to the relative proportions of each option.
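
Put together, and leaving aside the tree drawing that the full code later in this post handles, a minimal sketch of this recursion might look like the following. The function and variable names here are illustrative rather than the article's, and InfoGain is the helper described in the next section.

from collections import Counter

def build_tree(Y, X_dict):
    #Case one: only one outcome left, so this branch is a decision
    if len(set(Y)) == 1:
        return Y[0]
    #Case three: no independent variables left, so report outcome probabilities
    if not X_dict:
        return {y: n / float(len(Y)) for y, n in Counter(Y).items()}
    #Case two: split on the variable with the largest information gain and recurse
    best = max(X_dict, key=lambda x: InfoGain(Y, X_dict[x]))
    tree = {best: {}}
    for value in set(X_dict[best]):
        rows = [i for i, v in enumerate(X_dict[best]) if v == value]
        sub_Y = [Y[i] for i in rows]
        sub_X = {k: [vals[i] for i in rows] for k, vals in X_dict.items() if k != best}
        tree[best][value] = build_tree(sub_Y, sub_X)
    return tree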

Calculating Information Gain

First, we need a formula for organisation, or entropy. To calculate the entropy of our dependent variable Y we use:

H(Y) = -Σ P(Y=y) log2 P(Y=y)

where the sum runs over every value y that Y can take.

The following chart shows how the entropy of Y (where Y has two states) changes with the probability of each state. When the probability of one state is 0 the entropy is also 0, as this is when Y is most organised, while when Y is evenly split between the two states the entropy is at its maximum.
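
The chart can be reproduced with a short snippet; this is a minimal sketch (not part of the article's code) that assumes numpy and matplotlib are available.

import numpy as np
import matplotlib.pyplot as plt

#Sweep the probability of the first state and compute the two-state entropy
p = np.linspace(0.001, 0.999, 200)
H = -(p * np.log2(p) + (1 - p) * np.log2(1 - p))

plt.plot(p, H)
plt.xlabel("P(Y = state 1)")
plt.ylabel("Entropy H(Y)")
plt.title("Entropy of a two-state variable")
plt.show()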

Expanding this to include the effect on entropy of knowing an independent variable, X:

H(Y|X) = Σ P(X=x) H(Y|X=x)

where H(Y|X=x) is the entropy of Y calculated over only the records where X takes the value x.

Information gain is now given as the difference between the entropy when we know X and when we don't:

IG(Y, X) = H(Y) - H(Y|X)

For example, if Y is evenly split between “yes” and “no” (H(Y) = 1 bit) and some X separates the two perfectly (H(Y|X) = 0), then knowing X gives the full 1 bit of information.

Give me the code!

Some functions to calculate entropy and create the graphics have not been included here.
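
As a rough guide to what those entropy calculations could look like, here is a minimal sketch assuming Y and X are parallel lists of equal length. The InfoGain name matches the call in the code below; the Entropy helper and its details are assumptions, and the article's own versions may differ.

import math

def Entropy(Y):
    #Entropy (in bits) of the values in the list Y
    total = float(len(Y))
    H = 0.0
    for value in set(Y):
        p = Y.count(value) / total
        H -= p * math.log(p, 2)
    return H

def InfoGain(Y, X):
    #Reduction in the entropy of Y once the value of X is known
    total = float(len(Y))
    H_given_X = 0.0
    for value in set(X):
        subset = [Y[i] for i in range(len(Y)) if X[i] == value]
        H_given_X += (len(subset) / total) * Entropy(subset)
    return Entropy(Y) - H_given_X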

def decide(Y, X_dict, previous_node):
    #Calculate the information gain for each remaining X to find the variable to split on
    max_IG = 0
    var_to_split = None
    for x in X_dict.keys():
        IG = InfoGain(Y, X_dict[x])
        if IG > max_IG:
            max_IG = IG
            var_to_split = x
    #If no variable is left to split on (or none gives any information gain), make a probability node
    if var_to_split is None:
        Y_options = list(set(Y))
        tot = float(len(Y))
        #Count how often each outcome occurs
        count = [0 for _ in range(len(Y_options))]
        for op in range(len(Y_options)):
            for i in range(len(Y)):
                if Y[i] == Y_options[op]:
                    count[op] += 1
        #Format the node label; the last option's probability is implied by the others
        Prob = ""
        for op in range(len(Y_options) - 1):
            Prob += "P("
            Prob += str(Y_options[op]) + ")-> "
            P = float(count[op]) / tot
            Prob += "{0:.2f}".format(P)
        #Make a new node
        nodename = node(Prob, color = "orange")
        edge(previous_node, nodename)
    else:
        print("Splitting on {0}".format(var_to_split))
        X_options = list(set(X_dict[var_to_split]))
        #Make a node for the decision variable
        Var_nodename = node(var_to_split, color = "red")
        edge(previous_node, Var_nodename)
        #Build a new, smaller dataset for each branch of the split
        for X_option in X_options:
            X_nodename = node(str(X_option))
            edge(Var_nodename, X_nodename)
            #Keep only the variables we have not yet split on
            New_X_dict = {}
            for key in X_dict.keys():
                if key != var_to_split:
                    New_X_dict[key] = []
            New_Y = []
            #Populate the branch with the rows where the split variable takes this value
            for i in range(len(Y)):
                if X_dict[var_to_split][i] == X_option:
                    New_Y.append(Y[i])
                    for key in New_X_dict.keys():
                        New_X_dict[key].append(X_dict[key][i])
            #Check if this is a terminal node (only one outcome remains)
            if len(set(New_Y)) == 1:
                nodename = node(str(New_Y[0]), color = "green")
                edge(X_nodename, nodename)
            else:
                #Not a terminal node, so recurse on the remaining data
                decide(New_Y, New_X_dict, X_nodename)

Y, X_dict = import_golf('golf.csv') #import data
root_node = node("root", color = "blue") #Create the first node
decide(Y, X_dict, root_node) #start the tree
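
The node and edge helpers that draw the tree, and the import_golf loader, are also not shown above. A minimal sketch of what they could look like, assuming the graphviz Python package and a golf.csv whose last column is the dependent variable (both assumptions, not details from the article), is below; these would need to be defined before the code above is run.

import csv
import itertools
from graphviz import Digraph

dot = Digraph("decision_tree")
_ids = itertools.count()

def node(label, color = "black"):
    #Add a node with a unique internal name and return that name
    name = "n{0}".format(next(_ids))
    dot.node(name, label = label, color = color)
    return name

def edge(parent, child):
    #Connect two existing nodes
    dot.edge(parent, child)

def import_golf(path):
    #Read a CSV with a header row; the last column is the dependent variable
    with open(path) as f:
        rows = list(csv.reader(f))
    header, data = rows[0], rows[1:]
    Y = [r[-1] for r in data]
    X_dict = {header[j]: [r[j] for r in data] for j in range(len(header) - 1)}
    return Y, X_dict

#After decide() has run, dot.render("golf_tree", view = True) writes and opens the tree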

For the golf dataset the following tree is output, which makes the decision-making process easy to interpret.
