Sunday, February 3, 2008

Classification via Decision Tree

Decision Tree is another model-based learning approach where the learned model is a tree. Here, we are given a set of training data with the structure [x1, x2 ..., y], where y is the output. The learning algorithm learns (from the training set) a decision tree and uses it to predict the output y for future unseen data.

x1, x2 ... can be either numeric or categorical; y is categorical.

In the decision tree, a leaf node contains relatively "pure" data represented by a histogram {class_j => count}. An intermediate node is called a "decision node" and contains a test that compares an input attribute (e.g. x2) against a constant value. Two branches lead out of it according to whether the test evaluates to true or false. If x2 is a categorical value, the test is an equality test (e.g. whether "weather" equals "sunny"). If x2 is numeric, the test is a greater-than / less-than test (e.g. whether "age" >= 40).
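
For concreteness, a node could be represented like this (a rough Ruby sketch; the class and field names are just one possible choice, matching the pseudocode below):

class DecisionNode
  # attribute / test_value: the test held by a decision node
  # left / right: subtrees for the true / false branch of the test
  # result: leaf nodes only -- the histogram {class_j => count}
  attr_accessor :attribute, :test_value, :left, :right, :result

  def initialize(attribute = nil, test_value = nil)
    @attribute = attribute
    @test_value = test_value
  end

  def leaf?
    !@result.nil?
  end
end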


Building the Decision Tree

We start at the root node, which contains all the training samples. We need to figure out what our first test should be. Our strategy is to pick the test that divides the training samples into two groups with the highest combined "purity".

A set of data records is "pure" if their outcomes gravitate towards a single value; otherwise it is impure. Purity can be measured by the Entropy or the Gini impurity function (strictly speaking, both measure impurity: the higher the value, the lower the purity).

Gini impurity is measured by calculating the probability of picking two records at random (with replacement) from the set such that their outcomes differ, which works out to 1 - sum_over_j(P(class_j)^2).
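
A rough sketch of the Gini calculation (assuming each record is an array [x1, x2 ..., y] with the outcome/class label as its last element):

def gini_impurity(set)
  return 0.0 if set.empty?
  counts = Hash.new(0)
  set.each { |record| counts[record.last] += 1 }   # class label is the last element
  total = set.size.to_f
  1.0 - counts.values.sum { |count| (count / total) ** 2 }
end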

Entropy is measured by calculating the following (P(class_j) is the fraction of records in the set that belong to class_j; log base 2 by convention) ...
entropy = -sum_over_j(P(class_j) * log(P(class_j)))

Note that the term -P * log(P) is close to zero when P is close to either zero or one. For a two-class set, the entropy is largest when P is around 0.5. The higher the entropy, the lower the purity.
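
And a corresponding sketch of entropy, under the same record layout:

def entropy(set)
  return 0.0 if set.empty?
  counts = Hash.new(0)
  set.each { |record| counts[record.last] += 1 }   # class label is the last element
  total = set.size.to_f
  counts.values.sum { |count| p = count / total; -p * Math.log2(p) }
end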

Keep doing the following (recursively on each resulting subset) until the overall purity cannot be improved further:
  1. Try every combination of attribute (x1, x2 ...) and split value (value1a, value1b, value2a, value2b ...)
  2. Pick the combination that gives the divided data sets the best combined purity (the purity of each subset weighted by its relative frequency)
  3. To avoid the decision tree overfitting the training data, split a node only when the purity gain of the split exceeds a threshold value (called pre-pruning)


def build_tree(set, func, threshold)
  # func is an impurity measure (entropy or gini_impurity); a "purer" split is
  # one that reduces the weighted impurity, so the gain is that reduction.
  orig_impurity = calculate_purity(func, set)
  purity_gain = 0.0
  best_attribute = nil
  best_test_value = nil
  best_split_left_set = nil
  best_split_right_set = nil

  num_attributes = set.first.size - 1                       # last element of a record is y
  (0...num_attributes).each do |x|                          # for each attribute ...
    set.map { |record| record[x] }.uniq.each do |x_value|   # ... and each observed value
      left_set, right_set = split(x, x_value, set)
      next if left_set.empty? || right_set.empty?

      left_freq  = left_set.size.to_f  / set.size
      right_freq = right_set.size.to_f / set.size
      left_impurity  = calculate_purity(func, left_set)
      right_impurity = calculate_purity(func, right_set)
      split_impurity = left_freq * left_impurity + right_freq * right_impurity

      improvement = orig_impurity - split_impurity          # reduction in weighted impurity
      if improvement > purity_gain
        purity_gain = improvement
        best_attribute = x
        best_test_value = x_value
        best_split_left_set = left_set
        best_split_right_set = right_set
      end
    end
  end

  if purity_gain > threshold
    node = DecisionNode.new(best_attribute, best_test_value)
    node.left  = build_tree(best_split_left_set, func, threshold)
    node.right = build_tree(best_split_right_set, func, threshold)
  else
    node = DecisionNode.new
    node.result = calculate_outcome_freq(set)               # leaf: histogram {class_j => count}
  end
  node
end
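
The code above relies on a few helpers that are not shown. Here is a rough sketch of what they could look like, again assuming each record is an array [x1, x2 ..., y] with the outcome y last (these particular definitions are illustrative):

def calculate_purity(func, set)
  func.call(set)          # func is e.g. method(:entropy) or method(:gini_impurity)
end

def split(x, x_value, set)
  if x_value.is_a?(Numeric)                          # numeric attribute: >= test
    set.partition { |record| record[x] >= x_value }
  else                                               # categorical attribute: equality test
    set.partition { |record| record[x] == x_value }
  end
end

def calculate_outcome_freq(set)
  set.each_with_object(Hash.new(0)) { |record, hist| hist[record.last] += 1 }
end

With these in place, something like tree = build_tree(training_set, method(:entropy), 0.01) builds the tree (the 0.01 pre-pruning threshold is just an illustrative value).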


Overfitting and Tree Pruning

Since the training data set is not exactly the same as the actual population, the tree may become biased towards characteristics of the training data set which do not exist in the actual population. Such bias is called overfitting, and we want to avoid it.

In the above algorithm, we only create an intermediate node when the purity gain is significant (pre-pruning). We can also do "post-pruning": build the full tree first and then work backwards, merging sibling leaf nodes whenever the resulting decrease of purity is small.
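
A rough sketch of such post-pruning, reusing the impurity function and the DecisionNode sketch from above (the min_gain parameter plays the same role as the pre-pruning threshold):

def prune(node, func, min_gain)
  return node if node.leaf?
  node.left  = prune(node.left,  func, min_gain)
  node.right = prune(node.right, func, min_gain)

  if node.left.leaf? && node.right.leaf?
    # rebuild record-level label lists from the two leaf histograms
    left_labels  = node.left.result.flat_map  { |label, count| [[label]] * count }
    right_labels = node.right.result.flat_map { |label, count| [[label]] * count }
    merged = left_labels + right_labels
    merged_impurity = func.call(merged)
    split_impurity  = (left_labels.size.to_f / merged.size) * func.call(left_labels) +
                      (right_labels.size.to_f / merged.size) * func.call(right_labels)
    if merged_impurity - split_impurity < min_gain
      # the split buys little purity: collapse the two leaves into one
      leaf = DecisionNode.new
      leaf.result = node.left.result.merge(node.right.result) { |_label, a, b| a + b }
      return leaf
    end
  end
  node
end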

We can also set aside 1/4 of the training data as validation data: use the other 3/4 of the training data to build the tree first, and then use the 1/4 validation data to decide which branches to prune. This approach reduces the effective training data size and so is only practical when the training data set is large.


Classification and Missing Data Handling

When a new query point arrives, we traverse the decision tree from its root, answering the test at each decision node, until we reach a leaf node; there we pick the class with the highest instance count.
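
A sketch of this traversal, assuming the query record holds the attribute values in the same order as the training records:

def classify(node, record)
  # leaf: pick the class with the highest count in the histogram
  return node.result.max_by { |_label, count| count }.first if node.leaf?

  value = record[node.attribute]
  if node.test_value.is_a?(Numeric)
    branch = value >= node.test_value ? node.left : node.right
  else
    branch = value == node.test_value ? node.left : node.right
  end
  classify(branch, record)
end

For example, classify(tree, [value_of_x1, value_of_x2, ...]) returns the predicted class.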

However, what if the values of some attributes (say x3 and x5) are missing? When we reach a decision node where x3 needs to be tested, we cannot proceed.

One simple solution is to fill in the most likely value of x3 based on the probability distribution of x3 within the training set. We can also roll a die according to that distribution. Or we can use a three-way branch by adding "missing data" as one of the conditions.
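
A sketch of the first option, assuming missing attribute values are represented as nil (the helper name impute_missing is just illustrative):

def impute_missing(record, training_set)
  record.each_index.map do |i|
    next record[i] unless record[i].nil?
    # replace a nil attribute with its most common value in the training set
    values = training_set.map { |r| r[i] }.compact
    counts = Hash.new(0)
    values.each { |v| counts[v] += 1 }
    counts.max_by { |_value, count| count }.first
  end
end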

Another, more sophisticated solution is to walk down both paths when the attribute under test is missing. We know how the training data was distributed between the two branches at that decision node, and we use that distribution to weight each branch; the final result aggregates the class counts from both subtrees according to these weights.
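
A sketch of this weighted walk, where the weight of each branch is recovered from the counts stored in the leaf histograms:

def classify_with_missing(node, record)
  return node.result.dup if node.leaf?

  value = record[node.attribute]
  if value.nil?
    # attribute is missing: follow both branches and weight each branch by
    # the fraction of training records that ended up on that side
    left_counts  = classify_with_missing(node.left,  record)
    right_counts = classify_with_missing(node.right, record)
    left_size    = subtree_size(node.left)
    right_size   = subtree_size(node.right)
    total        = (left_size + right_size).to_f
    merged = Hash.new(0.0)
    left_counts.each  { |label, count| merged[label] += count * (left_size / total) }
    right_counts.each { |label, count| merged[label] += count * (right_size / total) }
    merged
  else
    if node.test_value.is_a?(Numeric)
      branch = value >= node.test_value ? node.left : node.right
    else
      branch = value == node.test_value ? node.left : node.right
    end
    classify_with_missing(branch, record)
  end
end

def subtree_size(node)
  return node.result.values.sum if node.leaf?
  subtree_size(node.left) + subtree_size(node.right)
end

The caller then picks the class with the highest weighted count, e.g. classify_with_missing(tree, query).max_by { |_label, count| count }.first.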
