From 185f1170e5d6b4819473c6c8065e30167e421838 Mon Sep 17 00:00:00 2001
From: Ilya Grigorik
Date: Wed, 16 Sep 2009 22:46:59 -0400
Subject: [PATCH] fix return values; add missing files (doh!)

---
 CHANGELOG.txt                |  17 ++
 decisiontree.gemspec         |  25 +--
 lib/decisiontree.rb          |   2 +-
 lib/decisiontree/id3_tree.rb | 322 +++++++++++++++++++++++++++++++++++
 test/helper.rb               |   3 +
 test/test_decisiontree.rb    |   4 +-
 6 files changed, 359 insertions(+), 14 deletions(-)
 create mode 100755 CHANGELOG.txt
 create mode 100755 lib/decisiontree/id3_tree.rb
 create mode 100644 test/helper.rb

diff --git a/CHANGELOG.txt b/CHANGELOG.txt
new file mode 100755
index 0000000..9c1d3b5
--- /dev/null
+++ b/CHANGELOG.txt
@@ -0,0 +1,17 @@
+0.1.0 - Apr. 04/07
+  * ID3 algorithms for continuous and discrete cases
+  * Graphviz component to visualize the learned tree
+
+0.2.0 - Jul. 07/07
+  * Modified and improved by Jose Ignacio (joseignacio.fernandez@gmail.com)
+  * Added support for multiple and symbolic outputs, and graphing of continuous trees.
+  * Modified to return the default value when no branches are suitable for the input.
+  * Refactored entropy code.
+
+0.3.0 - Sept. 15/07
+  * ID3Tree can now handle inconsistent datasets.
+  * Ruleset is a new class that trains an ID3Tree with 2/3 of the training data,
+    converts it into a set of rules and prunes the rules with the remaining 1/3
+    of the training data (in a C4.5 way).
+  * Bagging is a bagging-based trainer that trains 10 Ruleset trainers and,
+    when predicting, chooses the best output by voting (see the sketch below).
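
A rough usage sketch of the Ruleset and Bagging classes described above. The labels and data rows are invented for illustration; only the constructor signature (attributes, data, default, type) and the [prediction, accuracy] return shape come from lib/decisiontree/id3_tree.rb further down in this patch:

    require 'rubygems'
    require 'decisiontree'

    # Invented toy dataset: attribute values first, classification last.
    labels = ['fever', 'cough']
    data   = [[1, 1, 1],
              [1, 0, 1],
              [0, 1, 0],
              [0, 0, 0]]

    # Ruleset trains an ID3Tree on ~2/3 of the data, converts it to rules,
    # and prunes them against the remaining ~1/3 (C4.5-style).
    ruleset = DecisionTree::Ruleset.new(labels, data, 1, :discrete)
    ruleset.train
    prediction, accuracy = ruleset.predict([1, 0])

    # Bagging trains 10 Rulesets and resolves predictions by weighted voting.
    bagger = DecisionTree::Bagging.new(labels, data, 1, :discrete)
    bagger.train
    prediction, confidence = bagger.predict([1, 0])
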
diff --git a/decisiontree.gemspec b/decisiontree.gemspec
index 2204b49..63d5c70 100644
--- a/decisiontree.gemspec
+++ b/decisiontree.gemspec
@@ -11,16 +11,19 @@ spec = Gem::Specification.new do |s|
   s.rubyforge_project = "decisiontree"
 
   # ruby -rpp -e' pp `git ls-files`.split("\n") '
-  s.files = ["README.rdoc",
-             "examples/continuous-id3.rb",
-             "examples/data/continuous-test.txt",
-             "examples/data/continuous-training.txt",
-             "examples/data/discrete-test.txt",
-             "examples/data/discrete-training.txt",
-             "examples/discrete-id3.rb",
-             "examples/simple.rb",
-             "lib/decisiontree.rb",
-             "lib/id3_tree.rb",
-             "test/test_decisiontree.rb"]
+  s.files = ["CHANGELOG.txt",
+             "README.rdoc",
+             "decisiontree.gemspec",
+             "examples/continuous-id3.rb",
+             "examples/data/continuous-test.txt",
+             "examples/data/continuous-training.txt",
+             "examples/data/discrete-test.txt",
+             "examples/data/discrete-training.txt",
+             "examples/discrete-id3.rb",
+             "examples/simple.rb",
+             "lib/decisiontree.rb",
+             "lib/decisiontree/id3_tree.rb",
+             "test/helper.rb",
+             "test/test_decisiontree.rb"]
 end
 
diff --git a/lib/decisiontree.rb b/lib/decisiontree.rb
index c7ece9b..c6c3f28 100644
--- a/lib/decisiontree.rb
+++ b/lib/decisiontree.rb
@@ -1 +1 @@
-require File.dirname(__FILE__) + "/id3_tree.rb"
\ No newline at end of file
+require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb'
\ No newline at end of file
diff --git a/lib/decisiontree/id3_tree.rb b/lib/decisiontree/id3_tree.rb
new file mode 100755
index 0000000..c8cce8b
--- /dev/null
+++ b/lib/decisiontree/id3_tree.rb
@@ -0,0 +1,322 @@
+# The MIT License
+#
+### Copyright (c) 2007 Ilya Grigorik
+### Modified in 2007 by José Ignacio Fernández
+
+begin
+  require 'graph/graphviz_dot'
+rescue LoadError
+  STDERR.puts "graph/graphviz_dot not installed, graphing functionality not included."
+end
+
+class Object
+  def save_to_file(filename)
+    File.open(filename, 'w+' ) { |f| f << Marshal.dump(self) }
+  end
+
+  def self.load_from_file(filename)
+    Marshal.load( File.read( filename ) )
+  end
+end
+
+class Array
+  def classification; collect { |v| v.last }; end
+
+  # Calculate the information entropy of the classification distribution
+  def entropy
+    return 0 if empty?
+
+    info = {}
+    total = 0
+    each {|i| info[i] = !info[i] ? 1 : (info[i] + 1); total += 1}
+
+    result = 0
+    info.each do |symbol, count|
+      result += -count.to_f/total*Math.log(count.to_f/total)/Math.log(2.0) if (count > 0)
+    end
+    result
+  end
+end
+
+module DecisionTree
+  Node = Struct.new(:attribute, :threshold, :gain)
+
+  class ID3Tree
+    def initialize(attributes, data, default, type)
+      @used, @tree, @type = {}, {}, type
+      @data, @attributes, @default = data, attributes, default
+    end
+
+    def train(data=@data, attributes=@attributes, default=@default)
+      initialize(attributes, data, default, @type)
+
+      # Remove samples with same attributes, leaving the most common classification
+      data2 = data.inject({}) {|hash, d| hash[d.slice(0..-2)] ||= Hash.new(0); hash[d.slice(0..-2)][d.last] += 1; hash }.map{|key,val| key + [val.sort_by{ |k, v| v }.last.first]}
+
+      @tree = id3_train(data2, attributes, default)
+    end
+
+    def id3_train(data, attributes, default, used={})
+      # Choose a fitness algorithm
+      case @type
+      when :discrete;   fitness = proc{|a,b,c| id3_discrete(a,b,c)}
+      when :continuous; fitness = proc{|a,b,c| id3_continuous(a,b,c)}
+      end
+
+      return default if data.empty?
+
+      # Return the classification if all examples have the same classification
+      return data.first.last if data.classification.uniq.size == 1
+
+      # Choose the best attribute (1. enumerate all attributes / 2. pick the best one)
+      performance = attributes.collect { |attribute| fitness.call(data, attributes, attribute) }
+      max = performance.max { |a,b| a[0] <=> b[0] }
+      best = Node.new(attributes[performance.index(max)], max[1], max[0])
+      best.threshold = nil if @type == :discrete
+      @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
+      tree, l = {best => {}}, ['>=', '<']
+
+      case @type
+      when :continuous
+        data.partition { |d| d[attributes.index(best.attribute)] >= best.threshold }.each_with_index { |examples, i|
+          tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
+        }
+      when :discrete
+        values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort
+        partitions = values.collect { |val| data.select { |d| d[attributes.index(best.attribute)] == val } }
+        partitions.each_with_index { |examples, i|
+          tree[best][values[i]] = id3_train(examples, attributes-[values[i]], (data.classification.mode rescue 0), &fitness)
+        }
+      end
+
+      tree
+    end
+
+    # ID3 for binary classification of continuous variables (e.g. healthy / sick based on temperature thresholds)
+    def id3_continuous(data, attributes, attribute)
+      values, thresholds = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort, []
+      return [-1, -1] if values.size == 1
+      values.each_index { |i| thresholds.push((values[i]+(values[i+1].nil? ? values[i] : values[i+1])).to_f / 2) }
+      thresholds.pop
+      #thresholds -= used[attribute] if used.has_key? attribute
+
+      gain = thresholds.collect { |threshold|
+        sp = data.partition { |d| d[attributes.index(attribute)] >= threshold }
+        pos = (sp[0].size).to_f / data.size
+        neg = (sp[1].size).to_f / data.size
+
+        [data.classification.entropy - pos*sp[0].classification.entropy - neg*sp[1].classification.entropy, threshold]
+      }.max { |a,b| a[0] <=> b[0] }
+
+      return [-1, -1] if gain.size == 0
+      gain
+    end
+
+    # ID3 for discrete label cases
+    def id3_discrete(data, attributes, attribute)
+      values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
+      partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } }
+      remainder = partitions.collect {|p| (p.size.to_f / data.size) * p.classification.entropy}.inject(0) {|i,s| s+=i }
+
+      [data.classification.entropy - remainder, attributes.index(attribute)]
+    end
+
+    def predict(test)
+      return (@type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test))
+    end
+
+    def graph(filename)
+      dgp = DotGraphPrinter.new(build_tree)
+      dgp.write_to_file("#{filename}.png", "png")
+    end
+
+    def ruleset
+      rs = Ruleset.new(@attributes, @data, @default, @type)
+      rs.rules = build_rules
+      rs
+    end
+
+    def build_rules(tree=@tree)
+      attr = tree.to_a.first
+      cases = attr[1].to_a
+      rules = []
+      cases.each do |c,child|
+        if child.is_a?(Hash) then
+          build_rules(child).each do |r|
+            r2 = r.clone
+            r2.premises.unshift([attr.first, c])
+            rules << r2
+          end
+        else
+          rules << Rule.new(@attributes, [[attr.first, c]], child)
+        end
+      end
+      rules
+    end
+
+    private
+    def descend_continuous(tree, test)
+      attr = tree.to_a.first
+      return @default if !attr
+      return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
+      return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold
+      return descend_continuous(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
+      return descend_continuous(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
+    end
+
+    def descend_discrete(tree, test)
+      attr = tree.to_a.first
+      return @default if !attr
+      return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
+      return descend_discrete(attr[1][test[@attributes.index(attr[0].attribute)]],test)
+    end
+
+    def build_tree(tree = @tree)
+      return [] unless tree.is_a?(Hash)
+      return [["Always", @default]] if tree.empty?
+
+      attr = tree.to_a.first
+
+      links = attr[1].keys.collect do |key|
+        parent_text = "#{attr[0].attribute}\n(#{attr[0].object_id})"
+        if attr[1][key].is_a?(Hash) then
+          child = attr[1][key].to_a.first[0]
+          child_text = "#{child.attribute}\n(#{child.object_id})"
+        else
+          child = attr[1][key]
+          child_text = "#{child}\n(#{child.to_s.clone.object_id})"
+        end
+        label_text = "#{key} #{@type == :continuous ? attr[0].threshold : ""}"
+
+        [parent_text, child_text, label_text]
+      end
+      attr[1].keys.each { |key| links += build_tree(attr[1][key]) }
+
+      return links
+    end
+  end
+
+  class Rule
+    attr_accessor :premises
+    attr_accessor :conclusion
+    attr_accessor :attributes
+
+    def initialize(attributes,premises=[],conclusion=nil)
+      @attributes, @premises, @conclusion = attributes, premises, conclusion
+    end
+
+    def to_s
+      str = ''
+      @premises.each do |p|
+        str += "#{p.first.attribute} #{p.last} #{p.first.threshold}" if p.first.threshold
+        str += "#{p.first.attribute} = #{p.last}" if !p.first.threshold
+        str += "\n"
+      end
+      str += "=> #{@conclusion} (#{accuracy})"
+    end
+
+    def predict(test)
+      verifies = true
+      @premises.each do |p|
+        if p.first.threshold then # Continuous
+          if !(p.last == '>=' && test[@attributes.index(p.first.attribute)] >= p.first.threshold) && !(p.last == '<' && test[@attributes.index(p.first.attribute)] < p.first.threshold) then
+            verifies = false; break
+          end
+        else # Discrete
+          if test[@attributes.index(p.first.attribute)] != p.last then
+            verifies = false; break
+          end
+        end
+      end
+      return @conclusion if verifies
+      return nil
+    end
+
+    def get_accuracy(data)
+      correct = 0; total = 0
+      data.each do |d|
+        prediction = predict(d)
+        correct += 1 if d.last == prediction
+        total += 1 if !prediction.nil?
+      end
+      (correct.to_f + 1) / (total.to_f + 2)
+    end
+
+    def accuracy(data=nil)
+      data.nil? ? @accuracy : @accuracy = get_accuracy(data)
+    end
+  end
+
+  class Ruleset
+    attr_accessor :rules
+
+    def initialize(attributes, data, default, type)
+      @attributes, @default, @type = attributes, default, type
+      mixed_data = data.sort_by {rand}
+      cut = (mixed_data.size.to_f * 0.67).to_i
+      @train_data = mixed_data.slice(0..cut-1)
+      @prune_data = mixed_data.slice(cut..-1)
+    end
+
+    def train(train_data=@train_data, attributes=@attributes, default=@default)
+      dec_tree = DecisionTree::ID3Tree.new(attributes, train_data, default, @type)
+      dec_tree.train
+      @rules = dec_tree.build_rules
+      @rules.each { |r| r.accuracy(train_data) } # Calculate accuracy
+      prune
+    end
+
+    def prune(data=@prune_data)
+      @rules.each do |r|
+        (1..r.premises.size).each do
+          acc1 = r.accuracy(data)
+          p = r.premises.pop
+          if acc1 > r.get_accuracy(data) then
+            r.premises.push(p); break
+          end
+        end
+      end
+      @rules = @rules.sort_by{|r| -r.accuracy(data)}
+    end
+
+    def to_s
+      str = ''; @rules.each { |rule| str += "#{rule}\n\n" }
+      str
+    end
+
+    def predict(test)
+      @rules.each do |r|
+        prediction = r.predict(test)
+        return prediction, r.accuracy unless prediction.nil?
+      end
+      return @default, 0.0
+    end
+  end
+
+  class Bagging
+    attr_accessor :classifiers
+    def initialize(attributes, data, default, type)
+      @classifiers, @type = [], type
+      @data, @attributes, @default = data, attributes, default
+    end
+
+    def train(data=@data, attributes=@attributes, default=@default)
+      @classifiers = []
+      10.times { @classifiers << Ruleset.new(attributes, data, default, @type) }
+      @classifiers.each do |c|
+        c.train(data, attributes, default)
+      end
+    end
+
+    def predict(test)
+      predictions = Hash.new(0)
+      @classifiers.each do |c|
+        p, accuracy = c.predict(test)
+        predictions[p] += accuracy unless p.nil?
+      end
+      return @default, 0.0 if predictions.empty?
+      winner = predictions.sort_by {|k,v| -v}.first
+      return winner[0], winner[1].to_f / @classifiers.size.to_f
+    end
+  end
+end
diff --git a/test/helper.rb b/test/helper.rb
new file mode 100644
index 0000000..77d51fa
--- /dev/null
+++ b/test/helper.rb
@@ -0,0 +1,3 @@
+require 'rubygems'
+require 'spec'
+require 'lib/decisiontree'
diff --git a/test/test_decisiontree.rb b/test/test_decisiontree.rb
index 7383deb..93bde6b 100644
--- a/test/test_decisiontree.rb
+++ b/test/test_decisiontree.rb
@@ -8,10 +8,10 @@ describe DecisionTree::ID3Tree do
       [1,0,1],
       [0,1,0]
     ]
-    
+
     dec_tree = DecisionTree::ID3Tree.new(labels, data, 1, :discrete)
     dec_tree.train
-    
+
     dec_tree.predict([1,0]).should == 1
     dec_tree.predict([0,1]).should == 0
   end
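
The test above exercises only the discrete path. A matching sketch for the continuous case and the Graphviz output — the temperatures and labels here are invented; only the ID3Tree API calls come from lib/decisiontree/id3_tree.rb in this patch:

    require 'rubygems'
    require 'decisiontree'

    # Invented data: a temperature reading, then a symbolic classification
    # (symbolic outputs are supported as of 0.2.0).
    labels = ['temperature']
    data   = [[36.6, 'healthy'],
              [37.0, 'healthy'],
              [38.5, 'sick'],
              [39.2, 'sick']]

    dec_tree = DecisionTree::ID3Tree.new(labels, data, 'healthy', :continuous)
    dec_tree.train
    dec_tree.predict([38.9])    # => 'sick' on this toy data
    dec_tree.graph('tree')      # writes tree.png when graph/graphviz_dot is installed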