# Toy dataset.
# Format: each row is an example.
# The last column is the label.
# The first two columns are features.
# Feel free to play with it by adding more features & examples.
# Interesting note: the 2nd and 5th examples have the same features
# but different labels - so we can see how the tree handles this case.
# (The `from __future__ import print_function` shim was removed: this
# notebook runs on Python 3, where it is a no-op, and a mid-file
# __future__ import is a syntax hazard.)
training_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]

# Column labels.
# These are used only to print the tree.
header = ["color", "diameter", "label"]


def unique_vals(rows, col):
    """Find the unique values for a column in a dataset.

    Args:
        rows: list of examples (each example is a list).
        col: index of the column to collect values from.

    Returns:
        A set of the distinct values found in that column.
    """
    # Set comprehension instead of set([...]) - same result, no
    # intermediate list.
    return {row[col] for row in rows}
def class_counts(rows):
    """Count how many examples of each label appear in `rows`.

    In our dataset format, the label is always the last column.
    Returns a dict of label -> count.
    """
    counts = {}
    for row in rows:
        label = row[-1]  # the label lives in the last column
        counts[label] = counts.get(label, 0) + 1
    return counts


def is_numeric(value):
    """Test if a value is numeric (an int or a float)."""
    return isinstance(value, (int, float))


class Question:
    """A Question is used to partition a dataset.

    Records a column number (e.g., 0 for color) and a column value
    (e.g., 'Green'). The `match` method compares the corresponding
    feature value in an example against the stored value: numeric
    features are thresholded with >=, categorical features are tested
    for equality. See the demo below.
    """

    def __init__(self, column, value):
        self.column = column
        self.value = value

    def match(self, example):
        # Numeric features: threshold test. Categorical: equality test.
        feature = example[self.column]
        if is_numeric(feature):
            return feature >= self.value
        return feature == self.value

    def __repr__(self):
        # Helper to print the question in a readable format.
        condition = ">=" if is_numeric(self.value) else "=="
        return "Is %s %s %s?" % (header[self.column], condition, str(
            self.value))
def partition(rows, question):
    """Split `rows` into two lists using `question`.

    Rows for which question.match(...) is True go into the first
    returned list ('true rows'); all others go into the second
    ('false rows').
    """
    true_rows, false_rows = [], []
    for row in rows:
        # Pick the destination list once, then append.
        bucket = true_rows if question.match(row) else false_rows
        bucket.append(row)
    return true_rows, false_rows
def gini(rows):
    """Calculate the Gini Impurity for a list of rows.

    Impurity = 1 - sum over labels of p(label)^2: the chance of
    misclassifying a randomly chosen example if it were labeled at
    random according to the label distribution. See:
    https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity

    Returns 0.0 for an empty dataset instead of raising
    ZeroDivisionError - an empty node is trivially pure.
    """
    if not rows:
        return 0.0
    counts = class_counts(rows)  # label -> count (label is the last column)
    impurity = 1
    for lbl in counts:
        prob_of_lbl = counts[lbl] / float(len(rows))
        impurity -= prob_of_lbl**2
    return impurity


def info_gain(left, right, current_uncertainty):
    """Information Gain of a split.

    The uncertainty of the starting node, minus the weighted impurity
    of the two child nodes.

    Args:
        left, right: the two partitions produced by a question
            (assumed not both empty - partition() of a non-empty
            dataset guarantees this for callers here).
        current_uncertainty: gini(...) of the unsplit node.
    """
    # Weight each child's impurity by the fraction of examples it holds.
    p = float(len(left)) / (len(left) + len(right))
    return current_uncertainty - p * gini(left) - (1 - p) * gini(right)
# Demo: calculate the uncertainty of our training data.
current_uncertainty = gini(training_data)
current_uncertainty

# How much information do we gain by partitioning on 'Green'?
true_rows, false_rows = partition(training_data, Question(0, 'Green'))
info_gain(true_rows, false_rows, current_uncertainty)

# What about if we partitioned on 'Red' instead?
true_rows, false_rows = partition(training_data, Question(0, 'Red'))
info_gain(true_rows, false_rows, current_uncertainty)

# It looks like we learned more using 'Red' (0.37), than 'Green' (0.14).
# Why? Look at the different splits that result, and see which one
# looks more 'unmixed' to you.
true_rows, false_rows = partition(training_data, Question(0, 'Red'))

# Here, the true_rows contain only 'Grapes'.
true_rows

# And the false rows contain two types of fruit. Not too bad.
false_rows
def find_best_split(rows):
    """Find the best question to ask by iterating over every feature /
    value and calculating the information gain.

    Args:
        rows: list of examples; the last column of each row is the label.

    Returns:
        (best_gain, best_question): the highest information gain found
        and the Question that produced it. (0, None) when `rows` is
        empty or when no split divides the dataset.
    """
    best_gain = 0  # keep track of the best information gain
    best_question = None  # keep track of the feature / value that produced it

    if not rows:  # nothing to split - avoids IndexError on rows[0] below
        return best_gain, best_question

    current_uncertainty = gini(rows)
    n_features = len(rows[0]) - 1  # number of feature columns (label excluded)

    for col in range(n_features):  # for each feature

        values = {row[col] for row in rows}  # unique values in the column

        for val in values:  # for each value

            question = Question(col, val)

            # try splitting the dataset
            true_rows, false_rows = partition(rows, question)

            # Skip this split if it doesn't divide the dataset.
            if len(true_rows) == 0 or len(false_rows) == 0:
                continue

            # Calculate the information gain from this split
            gain = info_gain(true_rows, false_rows, current_uncertainty)

            # '>=' (instead of '>') means later, equally-good questions
            # win ties - chosen so the tree looks a certain way for our
            # toy dataset.
            if gain >= best_gain:
                best_gain, best_question = gain, question

    return best_gain, best_question
class Leaf:
    """A Leaf node classifies data.

    This holds a dictionary of class (e.g., "Apple") -> number of times
    it appears in the rows from the training data that reach this leaf.
    """
    def __init__(self, rows):
        # predictions: label -> count of the training rows that reached
        # this leaf; this (possibly mixed) histogram is the prediction.
        self.predictions = class_counts(rows)


class Decision_Node:
    """A Decision Node asks a question.

    This holds a reference to the question, and to the two child nodes.
    """
    def __init__(self, question, true_branch, false_branch):
        self.question = question          # the Question asked at this node
        self.true_branch = true_branch    # subtree for rows matching the question
        self.false_branch = false_branch  # subtree for the remaining rows
def build_tree(rows):
    """Builds the tree.

    Rules of recursion: 1) Believe that it works. 2) Start by checking
    for the base case (no further information gain). 3) Prepare for
    giant stack traces.
    """

    # Try partitioning the dataset on each of the unique attributes,
    # calculate the information gain,
    # and return the question that produces the highest gain.
    gain, question = find_best_split(rows)

    # Base case: no further info gain
    # Since we can ask no further questions,
    # we'll return a leaf.
    if gain == 0:
        return Leaf(rows)

    # If we reach here, we have found a useful feature / value
    # to partition on.
    true_rows, false_rows = partition(rows, question)

    # Recursively build the true branch.
    true_branch = build_tree(true_rows)

    # Recursively build the false branch.
    false_branch = build_tree(false_rows)

    # Return a Question node.
    # This records the best feature / value to ask at this point,
    # as well as the branches to follow
    # depending on the answer.
    return Decision_Node(question, true_branch, false_branch)


def print_tree(node, spacing=""):
    """World's most elegant tree printing function.

    Prints the question at each decision node and the prediction counts
    at each leaf, indenting by `spacing` per tree level.
    """

    # Base case: we've reached a leaf
    if isinstance(node, Leaf):
        print(spacing + "Predict", node.predictions)
        return

    # Print the question at this node
    print(spacing + str(node.question))

    # Call this function recursively on the true branch
    print(spacing + '--> True:')
    print_tree(node.true_branch, spacing + " ")

    # Call this function recursively on the false branch
    print(spacing + '--> False:')
    print_tree(node.false_branch, spacing + " ")


my_tree = build_tree(training_data)
def classify(row, node):
    """Walk `row` down the tree and return the leaf's prediction counts.

    See the 'rules of recursion' above - here unrolled into a loop:
    at each decision node, follow the true-branch or the false-branch
    depending on whether the node's question matches the example.
    """
    while not isinstance(node, Leaf):
        node = node.true_branch if node.question.match(row) else node.false_branch
    return node.predictions


def print_leaf(counts):
    """A nicer way to print the predictions at a leaf.

    Converts raw label counts into percentage strings,
    e.g. {'Apple': 1} -> {'Apple': '100%'}.
    """
    total = float(sum(counts.values()))
    return {lbl: str(int(n / total * 100)) + "%" for lbl, n in counts.items()}


# Evaluate on data the tree did not see during training.
testing_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 4, 'Apple'],
    ['Red', 2, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]
# Evaluate: compare each test row's true label with the tree's prediction.
for example in testing_data:
    print("Actual: %s. Predicted: %s" %
          (example[-1], print_leaf(classify(example, my_tree))))