diff --git a/Gemfile b/Gemfile index cc56bff..ea32a13 100644 --- a/Gemfile +++ b/Gemfile @@ -1,4 +1,4 @@ -source 'https://rubygems.org' +source "https://rubygems.org" # Specify your gem's dependencies in ..gemspec gemspec diff --git a/Rakefile b/Rakefile index ef12b2a..878cde2 100644 --- a/Rakefile +++ b/Rakefile @@ -1,7 +1,7 @@ -require 'bundler' +require "bundler" Bundler::GemHelper.install_tasks -require 'rspec/core/rake_task' +require "rspec/core/rake_task" RSpec::Core::RakeTask.new -task :default => :spec +task default: :spec diff --git a/decisiontree.gemspec b/decisiontree.gemspec index 9c21da7..eac8761 100644 --- a/decisiontree.gemspec +++ b/decisiontree.gemspec @@ -1,16 +1,15 @@ -# -*- encoding: utf-8 -*- $:.push File.expand_path("../lib", __FILE__) Gem::Specification.new do |s| - s.name = "decisiontree" - s.version = "0.5.0" - s.platform = Gem::Platform::RUBY - s.authors = ["Ilya Grigorik"] - s.email = ["ilya@igvita.com"] - s.homepage = "https://github.com/igrigorik/decisiontree" - s.summary = %q{ID3-based implementation of the M.L. Decision Tree algorithm} + s.name = "decisiontree" + s.version = "0.5.0" + s.platform = Gem::Platform::RUBY + s.authors = ["Ilya Grigorik"] + s.email = ["ilya@igvita.com"] + s.homepage = "https://github.com/igrigorik/decisiontree" + s.summary = "ID3-based implementation of the M.L. Decision Tree algorithm" s.description = s.summary - s.license = "MIT" + s.license = "MIT" s.rubyforge_project = "decisiontree" @@ -19,8 +18,8 @@ Gem::Specification.new do |s| s.add_development_dependency "rspec-given" s.add_development_dependency "pry" - s.files = `git ls-files`.split("\n") - s.test_files = `git ls-files -- {test,spec,features}/*`.split("\n") - s.executables = `git ls-files -- bin/*`.split("\n").map{ |f| File.basename(f) } + s.files = `git ls-files`.split("\n") + s.test_files = `git ls-files -- {test,spec,features}/*`.split("\n") + s.executables = `git ls-files -- bin/*`.split("\n").map { |f| File.basename(f) } s.require_paths = ["lib"] end diff --git a/examples/continuous-id3.rb b/examples/continuous-id3.rb index 42f4432..17cdb90 100644 --- a/examples/continuous-id3.rb +++ b/examples/continuous-id3.rb @@ -1,55 +1,55 @@ -require 'rubygems' -require 'decisiontree' -include DecisionTree - -# ---Continuous--- - -# Read in the training data -training = [] -attributes = nil - -File.open('data/continuous-training.txt', 'r').each_line do |line| - data = line.strip.chomp('.').split(',') - attributes ||= data - training_data = data.collect do |v| - case v - when 'healthy' - 1 - when 'colic' - 0 - else - v.to_f - end - end - training.push(training_data) -end - -# Remove the attribute row from the training data -training.shift - -# Instantiate the tree, and train it based on the data (set default to '1') -dec_tree = ID3Tree.new(attributes, training, 1, :continuous) -dec_tree.train - -# ---Test the tree--- - -# Read in the test cases -# Note: omit the attribute line (first line), we know the labels from the training data -test = [] -File.open('data/continuous-test.txt', 'r').each_line do |line| - data = line.strip.chomp('.').split(',') - test_data = data.collect do |v| - if v == 'healthy' || v == 'colic' - v == 'healthy' ? 1 : 0 - else - v.to_f - end - end - test.push(test_data) -end - -# Let the tree predict the output and compare it to the true specified value -test.each do |t| - predict = dec_tree.predict(t) - puts "Predict: #{predict} ... True: #{t.last}" -end +require "rubygems" +require "decisiontree" +include DecisionTree + +# ---Continuous--- + +# Read in the training data +training = [] +attributes = nil + +File.open("data/continuous-training.txt", "r").each_line do |line| + data = line.strip.chomp(".").split(",") + attributes ||= data + training_data = data.collect do |v| + case v + when "healthy" + 1 + when "colic" + 0 + else + v.to_f + end + end + training.push(training_data) +end + +# Remove the attribute row from the training data +training.shift + +# Instantiate the tree, and train it based on the data (set default to '1') +dec_tree = ID3Tree.new(attributes, training, 1, :continuous) +dec_tree.train + +# ---Test the tree--- + +# Read in the test cases +# Note: omit the attribute line (first line), we know the labels from the training data +test = [] +File.open("data/continuous-test.txt", "r").each_line do |line| + data = line.strip.chomp(".").split(",") + test_data = data.collect do |v| + if v == "healthy" || v == "colic" + v == "healthy" ? 1 : 0 + else + v.to_f + end + end + test.push(test_data) +end + +# Let the tree predict the output and compare it to the true specified value +test.each do |t| + predict = dec_tree.predict(t) + puts "Predict: #{predict} ... True: #{t.last}" +end diff --git a/examples/discrete-id3.rb b/examples/discrete-id3.rb index 6f06c79..ef3b0df 100644 --- a/examples/discrete-id3.rb +++ b/examples/discrete-id3.rb @@ -1,60 +1,60 @@ -require 'rubygems' -require 'decisiontree' - -# ---Discrete--- - -# Read in the training data -training = [] -attributes = nil - -File.open('data/discrete-training.txt', 'r').each_line do |line| - data = line.strip.split(',') - attributes ||= data - training_data = data.collect do |v| - case v - when 'will buy' - 1 - when "won't buy" - 0 - else - v - end - end - training.push(training_data) -end - -# Remove the attribute row from the training data -training.shift - -# Instantiate the tree, and train it based on the data (set default to '1') -dec_tree = DecisionTree::ID3Tree.new(attributes, training, 1, :discrete) -dec_tree.train - -# ---Test the tree--- - -# Read in the test cases -# Note: omit the attribute line (first line), we know the labels from the training data -test = [] -File.open('data/discrete-test.txt', 'r').each_line do |line| - data = line.strip.split(',') - test_data = data.collect do |v| - case v - when 'will buy' - 1 - when "won't buy" - 0 - else - v - end - end - test.push(test_data) -end - -# Let the tree predict the output and compare it to the true specified value -test.each do |t| - predict = dec_tree.predict(t) - puts "Predict: #{predict} ... True: #{t.last}" -end - -# Graph the tree, save to 'discrete.png' -dec_tree.graph('discrete') +require "rubygems" +require "decisiontree" + +# ---Discrete--- + +# Read in the training data +training = [] +attributes = nil + +File.open("data/discrete-training.txt", "r").each_line do |line| + data = line.strip.split(",") + attributes ||= data + training_data = data.collect do |v| + case v + when "will buy" + 1 + when "won't buy" + 0 + else + v + end + end + training.push(training_data) +end + +# Remove the attribute row from the training data +training.shift + +# Instantiate the tree, and train it based on the data (set default to '1') +dec_tree = DecisionTree::ID3Tree.new(attributes, training, 1, :discrete) +dec_tree.train + +# ---Test the tree--- + +# Read in the test cases +# Note: omit the attribute line (first line), we know the labels from the training data +test = [] +File.open("data/discrete-test.txt", "r").each_line do |line| + data = line.strip.split(",") + test_data = data.collect do |v| + case v + when "will buy" + 1 + when "won't buy" + 0 + else + v + end + end + test.push(test_data) +end + +# Let the tree predict the output and compare it to the true specified value +test.each do |t| + predict = dec_tree.predict(t) + puts "Predict: #{predict} ... True: #{t.last}" +end + +# Graph the tree, save to 'discrete.png' +dec_tree.graph("discrete") diff --git a/examples/simple.rb b/examples/simple.rb index e023675..d714227 100755 --- a/examples/simple.rb +++ b/examples/simple.rb @@ -1,26 +1,26 @@ #!/usr/bin/ruby -require 'rubygems' -require 'decisiontree' +require "rubygems" +require "decisiontree" -attributes = ['Temperature'] +attributes = ["Temperature"] training = [ - [36.6, 'healthy'], - [37, 'sick'], - [38, 'sick'], - [36.7, 'healthy'], - [40, 'sick'], - [50, 'really sick'] + [36.6, "healthy"], + [37, "sick"], + [38, "sick"], + [36.7, "healthy"], + [40, "sick"], + [50, "really sick"] ] # Instantiate the tree, and train it based on the data (set default to '1') -dec_tree = DecisionTree::ID3Tree.new(attributes, training, 'sick', :continuous) +dec_tree = DecisionTree::ID3Tree.new(attributes, training, "sick", :continuous) dec_tree.train -test = [37, 'sick'] +test = [37, "sick"] decision = dec_tree.predict(test) puts "Predicted: #{decision} ... True decision: #{test.last}" # Graph the tree, save to 'tree.png' -dec_tree.graph('tree') +dec_tree.graph("tree") diff --git a/lib/core_extensions/array.rb b/lib/core_extensions/array.rb index 34e762a..81f8aed 100644 --- a/lib/core_extensions/array.rb +++ b/lib/core_extensions/array.rb @@ -17,4 +17,3 @@ module ArrayClassification end end end - diff --git a/lib/core_extensions/object.rb b/lib/core_extensions/object.rb index 0b79fd9..75afe7c 100644 --- a/lib/core_extensions/object.rb +++ b/lib/core_extensions/object.rb @@ -1,6 +1,6 @@ class Object def save_to_file(filename) - File.open(filename, 'w+') { |f| f << Marshal.dump(self) } + File.open(filename, "w+") { |f| f << Marshal.dump(self) } end def self.load_from_file(filename) diff --git a/lib/decisiontree.rb b/lib/decisiontree.rb index 190fc0b..efd7a08 100644 --- a/lib/decisiontree.rb +++ b/lib/decisiontree.rb @@ -1,3 +1,3 @@ -require 'core_extensions/object' -require 'core_extensions/array' -require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb' +require "core_extensions/object" +require "core_extensions/array" +require File.dirname(__FILE__) + "/decisiontree/id3_tree.rb" diff --git a/lib/decisiontree/id3_tree.rb b/lib/decisiontree/id3_tree.rb index 8df1483..0e8aaa7 100755 --- a/lib/decisiontree/id3_tree.rb +++ b/lib/decisiontree/id3_tree.rb @@ -23,14 +23,13 @@ module DecisionTree initialize(attributes, data, default, @type) # Remove samples with same attributes leaving most common classification - data2 = data.inject({}) do |hash, d| + data2 = data.each_with_object({}) do |d, hash| hash[d.slice(0..-2)] ||= Hash.new(0) hash[d.slice(0..-2)][d.last] += 1 - hash end data2 = data2.map do |key, val| - key + [val.sort_by { |_, v| v }.last.first] + key + [val.max_by { |_, v| v }.first] end @tree = id3_train(data2, attributes, default) @@ -49,7 +48,7 @@ module DecisionTree end end - def id3_train(data, attributes, default, _used={}) + def id3_train(data, attributes, default, _used = {}) return default if data.empty? # return classification if all examples have the same classification @@ -60,13 +59,13 @@ module DecisionTree # 2. Pick best attribute # 3. If attributes all score the same, then pick a random one to avoid infinite recursion. performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) } - max = performance.max { |a,b| a[0] <=> b[0] } - min = performance.min { |a,b| a[0] <=> b[0] } + max = performance.max_by { |a| a[0] } + min = performance.min_by { |a| a[0] } max = performance.sample if max[0] == min[0] best = Node.new(attributes[performance.index(max)], max[1], max[0]) best.threshold = nil if @type == :discrete @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold] - tree, l = {best => {}}, ['>=', '<'] + tree, l = {best => {}}, [">=", "<"] case type(best.attribute) when :continuous @@ -74,7 +73,11 @@ module DecisionTree d[attributes.index(best.attribute)] >= best.threshold end partitioned_data.each_with_index do |examples, i| - tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0)) + tree[best][String.new(l[i])] = id3_train(examples, attributes, begin + data.classification.mode + rescue + 0 + end) end when :discrete values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort @@ -84,7 +87,11 @@ module DecisionTree end end partitions.each_with_index do |examples, i| - tree[best][values[i]] = id3_train(examples, attributes - [values[i]], (data.classification.mode rescue 0)) + tree[best][values[i]] = id3_train(examples, attributes - [values[i]], begin + data.classification.mode + rescue + 0 + end) end end @@ -100,16 +107,16 @@ module DecisionTree thresholds.push((values[i] + (values[i + 1].nil? ? values[i] : values[i + 1])).to_f / 2) end thresholds.pop - #thresholds -= used[attribute] if used.has_key? attribute + # thresholds -= used[attribute] if used.has_key? attribute gain = thresholds.collect do |threshold| sp = data.partition { |d| d[attributes.index(attribute)] >= threshold } - pos = (sp[0].size).to_f / data.size - neg = (sp[1].size).to_f / data.size + pos = sp[0].size.to_f / data.size + neg = sp[1].size.to_f / data.size [data.classification.entropy - pos * sp[0].classification.entropy - neg * sp[1].classification.entropy, threshold] end - gain = gain.max { |a, b| a[0] <=> b[0] } + gain = gain.max_by { |a| a[0] } return [-1, -1] if gain.size == 0 gain @@ -135,16 +142,16 @@ module DecisionTree descend(@tree, test) end - def graph(filename, file_type = 'png') - require 'graphr' + def graph(filename, file_type = "png") + require "graphr" dgp = DotGraphPrinter.new(build_tree) - dgp.size = '' + dgp.size = "" dgp.node_labeler = proc { |n| n.split("\n").first } dgp.write_to_file("#{filename}.#{file_type}", file_type) rescue LoadError - STDERR.puts "Error: Cannot generate graph." - STDERR.puts " The 'graphr' gem doesn't seem to be installed." - STDERR.puts " Run 'gem install graphr' or add it to your Gemfile." + warn "Error: Cannot generate graph." + warn " The 'graphr' gem doesn't seem to be installed." + warn " Run 'gem install graphr' or add it to your Gemfile." end def ruleset @@ -177,19 +184,19 @@ module DecisionTree attr = tree.to_a.first return @default unless attr if type(attr.first.attribute) == :continuous - return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] >= attr.first.threshold - return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] < attr.first.threshold - return descend(attr[1]['>='], test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold - return descend(attr[1]['<'], test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold + return attr[1][">="] if !attr[1][">="].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] >= attr.first.threshold + return attr[1]["<"] if !attr[1]["<"].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] < attr.first.threshold + return descend(attr[1][">="], test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold + return descend(attr[1]["<"], test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold else - return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash) - return descend(attr[1][test[@attributes.index(attr[0].attribute)]], test) + return attr[1][test[@attributes.index(attr[0].attribute)]] unless attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash) + descend(attr[1][test[@attributes.index(attr[0].attribute)]], test) end end def build_tree(tree = @tree) return [] unless tree.is_a?(Hash) - return [['Always', @default]] if tree.empty? + return [["Always", @default]] if tree.empty? attr = tree.to_a.first @@ -203,10 +210,10 @@ module DecisionTree child_text = "#{child}\n(#{child.to_s.clone.object_id})" end - if type(attr[0].attribute) == :continuous - label_text = "#{key} #{attr[0].threshold}" + label_text = if type(attr[0].attribute) == :continuous + "#{key} #{attr[0].threshold}" else - label_text = key + key end [parent_text, child_text, label_text] @@ -229,12 +236,12 @@ module DecisionTree end def to_s - str = '' + str = "" @premises.each do |p| - if p.first.threshold - str += "#{p.first.attribute} #{p.last} #{p.first.threshold}" + str += if p.first.threshold + "#{p.first.attribute} #{p.last} #{p.first.threshold}" else - str += "#{p.first.attribute} = #{p.last}" + "#{p.first.attribute} = #{p.last}" end str += "\n" end @@ -245,15 +252,13 @@ module DecisionTree verifies = true @premises.each do |p| if p.first.threshold # Continuous - if !(p.last == '>=' && test[@attributes.index(p.first.attribute)] >= p.first.threshold) && !(p.last == '<' && test[@attributes.index(p.first.attribute)] < p.first.threshold) - verifies = false - break - end - else # Discrete - if test[@attributes.index(p.first.attribute)] != p.last + if !(p.last == ">=" && test[@attributes.index(p.first.attribute)] >= p.first.threshold) && !(p.last == "<" && test[@attributes.index(p.first.attribute)] < p.first.threshold) verifies = false break end + elsif test[@attributes.index(p.first.attribute)] != p.last # Discrete + verifies = false + break end end return @conclusion if verifies @@ -312,7 +317,7 @@ module DecisionTree end def to_s - str = '' + str = "" @rules.each { |rule| str += "#{rule}\n\n" } str end @@ -355,9 +360,8 @@ module DecisionTree predictions[p] += accuracy unless p.nil? end return @default, 0.0 if predictions.empty? - winner = predictions.sort_by { |_k, v| -v }.first + winner = predictions.min_by { |_k, v| -v } [winner[0], winner[1].to_f / @classifiers.size.to_f] end end end - diff --git a/spec/id3_spec.rb b/spec/id3_spec.rb index 618573f..b388223 100644 --- a/spec/id3_spec.rb +++ b/spec/id3_spec.rb @@ -1,19 +1,18 @@ -require 'spec_helper' +require "spec_helper" describe describe DecisionTree::ID3Tree do - describe "simple discrete case" do - Given(:labels) { ["sun", "rain"]} + Given(:labels) { ["sun", "rain"] } Given(:data) do [ - [1,0,1], - [0,1,0] + [1, 0, 1], + [0, 1, 0] ] end Given(:tree) { DecisionTree::ID3Tree.new(labels, data, 1, :discrete) } When { tree.train } - Then { expect(tree.predict([1,0])).to eq 1 } - Then { expect(tree.predict([0,1])).to eq 0 } + Then { expect(tree.predict([1, 0])).to eq 1 } + Then { expect(tree.predict([0, 1])).to eq 0 } end describe "discrete attributes" do @@ -84,7 +83,7 @@ describe describe DecisionTree::ID3Tree do end Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "RED", :discrete) } When { tree.train } - Then { expect(tree.predict(["a1","b0","c0"])).to eq "RED" } + Then { expect(tree.predict(["a1", "b0", "c0"])).to eq "RED" } end describe "numerical labels case" do @@ -109,11 +108,11 @@ describe describe DecisionTree::ID3Tree do File.delete("#{FIGURE_FILENAME}.png") if File.file?("#{FIGURE_FILENAME}.png") end - Given(:labels) { ["sun", "rain"]} + Given(:labels) { ["sun", "rain"] } Given(:data) do [ - [1,0,1], - [0,1,0] + [1, 0, 1], + [0, 1, 0] ] end Given(:tree) { DecisionTree::ID3Tree.new(labels, data, 1, :discrete) } diff --git a/spec/spec_helper.rb b/spec/spec_helper.rb index 04de356..c399ec5 100644 --- a/spec/spec_helper.rb +++ b/spec/spec_helper.rb @@ -1,5 +1,5 @@ -require 'rspec/given' -require 'decisiontree' -require 'pry' +require "rspec/given" +require "decisiontree" +require "pry" FIGURE_FILENAME = "just_a_spec"