mirror of
https://github.com/dkam/decisiontree.git
synced 2025-12-28 07:04:53 +00:00
Tidy code style
This commit is contained in:
2
Gemfile
2
Gemfile
@@ -1,4 +1,4 @@
|
|||||||
source 'https://rubygems.org'
|
source "https://rubygems.org"
|
||||||
|
|
||||||
# Specify your gem's dependencies in ..gemspec
|
# Specify your gem's dependencies in ..gemspec
|
||||||
gemspec
|
gemspec
|
||||||
|
|||||||
6
Rakefile
6
Rakefile
@@ -1,7 +1,7 @@
|
|||||||
require 'bundler'
|
require "bundler"
|
||||||
Bundler::GemHelper.install_tasks
|
Bundler::GemHelper.install_tasks
|
||||||
|
|
||||||
require 'rspec/core/rake_task'
|
require "rspec/core/rake_task"
|
||||||
RSpec::Core::RakeTask.new
|
RSpec::Core::RakeTask.new
|
||||||
|
|
||||||
task :default => :spec
|
task default: :spec
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# -*- encoding: utf-8 -*-
|
|
||||||
$:.push File.expand_path("../lib", __FILE__)
|
$:.push File.expand_path("../lib", __FILE__)
|
||||||
|
|
||||||
Gem::Specification.new do |s|
|
Gem::Specification.new do |s|
|
||||||
@@ -8,7 +7,7 @@ Gem::Specification.new do |s|
|
|||||||
s.authors = ["Ilya Grigorik"]
|
s.authors = ["Ilya Grigorik"]
|
||||||
s.email = ["ilya@igvita.com"]
|
s.email = ["ilya@igvita.com"]
|
||||||
s.homepage = "https://github.com/igrigorik/decisiontree"
|
s.homepage = "https://github.com/igrigorik/decisiontree"
|
||||||
s.summary = %q{ID3-based implementation of the M.L. Decision Tree algorithm}
|
s.summary = "ID3-based implementation of the M.L. Decision Tree algorithm"
|
||||||
s.description = s.summary
|
s.description = s.summary
|
||||||
s.license = "MIT"
|
s.license = "MIT"
|
||||||
|
|
||||||
@@ -21,6 +20,6 @@ Gem::Specification.new do |s|
|
|||||||
|
|
||||||
s.files = `git ls-files`.split("\n")
|
s.files = `git ls-files`.split("\n")
|
||||||
s.test_files = `git ls-files -- {test,spec,features}/*`.split("\n")
|
s.test_files = `git ls-files -- {test,spec,features}/*`.split("\n")
|
||||||
s.executables = `git ls-files -- bin/*`.split("\n").map{ |f| File.basename(f) }
|
s.executables = `git ls-files -- bin/*`.split("\n").map { |f| File.basename(f) }
|
||||||
s.require_paths = ["lib"]
|
s.require_paths = ["lib"]
|
||||||
end
|
end
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
require 'rubygems'
|
require "rubygems"
|
||||||
require 'decisiontree'
|
require "decisiontree"
|
||||||
include DecisionTree
|
include DecisionTree
|
||||||
|
|
||||||
# ---Continuous---
|
# ---Continuous---
|
||||||
@@ -8,14 +8,14 @@ include DecisionTree
|
|||||||
training = []
|
training = []
|
||||||
attributes = nil
|
attributes = nil
|
||||||
|
|
||||||
File.open('data/continuous-training.txt', 'r').each_line do |line|
|
File.open("data/continuous-training.txt", "r").each_line do |line|
|
||||||
data = line.strip.chomp('.').split(',')
|
data = line.strip.chomp(".").split(",")
|
||||||
attributes ||= data
|
attributes ||= data
|
||||||
training_data = data.collect do |v|
|
training_data = data.collect do |v|
|
||||||
case v
|
case v
|
||||||
when 'healthy'
|
when "healthy"
|
||||||
1
|
1
|
||||||
when 'colic'
|
when "colic"
|
||||||
0
|
0
|
||||||
else
|
else
|
||||||
v.to_f
|
v.to_f
|
||||||
@@ -36,11 +36,11 @@ dec_tree.train
|
|||||||
# Read in the test cases
|
# Read in the test cases
|
||||||
# Note: omit the attribute line (first line), we know the labels from the training data
|
# Note: omit the attribute line (first line), we know the labels from the training data
|
||||||
test = []
|
test = []
|
||||||
File.open('data/continuous-test.txt', 'r').each_line do |line|
|
File.open("data/continuous-test.txt", "r").each_line do |line|
|
||||||
data = line.strip.chomp('.').split(',')
|
data = line.strip.chomp(".").split(",")
|
||||||
test_data = data.collect do |v|
|
test_data = data.collect do |v|
|
||||||
if v == 'healthy' || v == 'colic'
|
if v == "healthy" || v == "colic"
|
||||||
v == 'healthy' ? 1 : 0
|
v == "healthy" ? 1 : 0
|
||||||
else
|
else
|
||||||
v.to_f
|
v.to_f
|
||||||
end
|
end
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
require 'rubygems'
|
require "rubygems"
|
||||||
require 'decisiontree'
|
require "decisiontree"
|
||||||
|
|
||||||
# ---Discrete---
|
# ---Discrete---
|
||||||
|
|
||||||
@@ -7,12 +7,12 @@ require 'decisiontree'
|
|||||||
training = []
|
training = []
|
||||||
attributes = nil
|
attributes = nil
|
||||||
|
|
||||||
File.open('data/discrete-training.txt', 'r').each_line do |line|
|
File.open("data/discrete-training.txt", "r").each_line do |line|
|
||||||
data = line.strip.split(',')
|
data = line.strip.split(",")
|
||||||
attributes ||= data
|
attributes ||= data
|
||||||
training_data = data.collect do |v|
|
training_data = data.collect do |v|
|
||||||
case v
|
case v
|
||||||
when 'will buy'
|
when "will buy"
|
||||||
1
|
1
|
||||||
when "won't buy"
|
when "won't buy"
|
||||||
0
|
0
|
||||||
@@ -35,11 +35,11 @@ dec_tree.train
|
|||||||
# Read in the test cases
|
# Read in the test cases
|
||||||
# Note: omit the attribute line (first line), we know the labels from the training data
|
# Note: omit the attribute line (first line), we know the labels from the training data
|
||||||
test = []
|
test = []
|
||||||
File.open('data/discrete-test.txt', 'r').each_line do |line|
|
File.open("data/discrete-test.txt", "r").each_line do |line|
|
||||||
data = line.strip.split(',')
|
data = line.strip.split(",")
|
||||||
test_data = data.collect do |v|
|
test_data = data.collect do |v|
|
||||||
case v
|
case v
|
||||||
when 'will buy'
|
when "will buy"
|
||||||
1
|
1
|
||||||
when "won't buy"
|
when "won't buy"
|
||||||
0
|
0
|
||||||
@@ -57,4 +57,4 @@ test.each do |t|
|
|||||||
end
|
end
|
||||||
|
|
||||||
# Graph the tree, save to 'discrete.png'
|
# Graph the tree, save to 'discrete.png'
|
||||||
dec_tree.graph('discrete')
|
dec_tree.graph("discrete")
|
||||||
|
|||||||
@@ -1,26 +1,26 @@
|
|||||||
#!/usr/bin/ruby
|
#!/usr/bin/ruby
|
||||||
|
|
||||||
require 'rubygems'
|
require "rubygems"
|
||||||
require 'decisiontree'
|
require "decisiontree"
|
||||||
|
|
||||||
attributes = ['Temperature']
|
attributes = ["Temperature"]
|
||||||
training = [
|
training = [
|
||||||
[36.6, 'healthy'],
|
[36.6, "healthy"],
|
||||||
[37, 'sick'],
|
[37, "sick"],
|
||||||
[38, 'sick'],
|
[38, "sick"],
|
||||||
[36.7, 'healthy'],
|
[36.7, "healthy"],
|
||||||
[40, 'sick'],
|
[40, "sick"],
|
||||||
[50, 'really sick']
|
[50, "really sick"]
|
||||||
]
|
]
|
||||||
|
|
||||||
# Instantiate the tree, and train it based on the data (set default to '1')
|
# Instantiate the tree, and train it based on the data (set default to '1')
|
||||||
dec_tree = DecisionTree::ID3Tree.new(attributes, training, 'sick', :continuous)
|
dec_tree = DecisionTree::ID3Tree.new(attributes, training, "sick", :continuous)
|
||||||
dec_tree.train
|
dec_tree.train
|
||||||
|
|
||||||
test = [37, 'sick']
|
test = [37, "sick"]
|
||||||
|
|
||||||
decision = dec_tree.predict(test)
|
decision = dec_tree.predict(test)
|
||||||
puts "Predicted: #{decision} ... True decision: #{test.last}"
|
puts "Predicted: #{decision} ... True decision: #{test.last}"
|
||||||
|
|
||||||
# Graph the tree, save to 'tree.png'
|
# Graph the tree, save to 'tree.png'
|
||||||
dec_tree.graph('tree')
|
dec_tree.graph("tree")
|
||||||
|
|||||||
@@ -17,4 +17,3 @@ module ArrayClassification
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
class Object
|
class Object
|
||||||
def save_to_file(filename)
|
def save_to_file(filename)
|
||||||
File.open(filename, 'w+') { |f| f << Marshal.dump(self) }
|
File.open(filename, "w+") { |f| f << Marshal.dump(self) }
|
||||||
end
|
end
|
||||||
|
|
||||||
def self.load_from_file(filename)
|
def self.load_from_file(filename)
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
require 'core_extensions/object'
|
require "core_extensions/object"
|
||||||
require 'core_extensions/array'
|
require "core_extensions/array"
|
||||||
require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb'
|
require File.dirname(__FILE__) + "/decisiontree/id3_tree.rb"
|
||||||
|
|||||||
@@ -23,14 +23,13 @@ module DecisionTree
|
|||||||
initialize(attributes, data, default, @type)
|
initialize(attributes, data, default, @type)
|
||||||
|
|
||||||
# Remove samples with same attributes leaving most common classification
|
# Remove samples with same attributes leaving most common classification
|
||||||
data2 = data.inject({}) do |hash, d|
|
data2 = data.each_with_object({}) do |d, hash|
|
||||||
hash[d.slice(0..-2)] ||= Hash.new(0)
|
hash[d.slice(0..-2)] ||= Hash.new(0)
|
||||||
hash[d.slice(0..-2)][d.last] += 1
|
hash[d.slice(0..-2)][d.last] += 1
|
||||||
hash
|
|
||||||
end
|
end
|
||||||
|
|
||||||
data2 = data2.map do |key, val|
|
data2 = data2.map do |key, val|
|
||||||
key + [val.sort_by { |_, v| v }.last.first]
|
key + [val.max_by { |_, v| v }.first]
|
||||||
end
|
end
|
||||||
|
|
||||||
@tree = id3_train(data2, attributes, default)
|
@tree = id3_train(data2, attributes, default)
|
||||||
@@ -49,7 +48,7 @@ module DecisionTree
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def id3_train(data, attributes, default, _used={})
|
def id3_train(data, attributes, default, _used = {})
|
||||||
return default if data.empty?
|
return default if data.empty?
|
||||||
|
|
||||||
# return classification if all examples have the same classification
|
# return classification if all examples have the same classification
|
||||||
@@ -60,13 +59,13 @@ module DecisionTree
|
|||||||
# 2. Pick best attribute
|
# 2. Pick best attribute
|
||||||
# 3. If attributes all score the same, then pick a random one to avoid infinite recursion.
|
# 3. If attributes all score the same, then pick a random one to avoid infinite recursion.
|
||||||
performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
|
performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
|
||||||
max = performance.max { |a,b| a[0] <=> b[0] }
|
max = performance.max_by { |a| a[0] }
|
||||||
min = performance.min { |a,b| a[0] <=> b[0] }
|
min = performance.min_by { |a| a[0] }
|
||||||
max = performance.sample if max[0] == min[0]
|
max = performance.sample if max[0] == min[0]
|
||||||
best = Node.new(attributes[performance.index(max)], max[1], max[0])
|
best = Node.new(attributes[performance.index(max)], max[1], max[0])
|
||||||
best.threshold = nil if @type == :discrete
|
best.threshold = nil if @type == :discrete
|
||||||
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
|
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
|
||||||
tree, l = {best => {}}, ['>=', '<']
|
tree, l = {best => {}}, [">=", "<"]
|
||||||
|
|
||||||
case type(best.attribute)
|
case type(best.attribute)
|
||||||
when :continuous
|
when :continuous
|
||||||
@@ -74,7 +73,11 @@ module DecisionTree
|
|||||||
d[attributes.index(best.attribute)] >= best.threshold
|
d[attributes.index(best.attribute)] >= best.threshold
|
||||||
end
|
end
|
||||||
partitioned_data.each_with_index do |examples, i|
|
partitioned_data.each_with_index do |examples, i|
|
||||||
tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0))
|
tree[best][String.new(l[i])] = id3_train(examples, attributes, begin
|
||||||
|
data.classification.mode
|
||||||
|
rescue
|
||||||
|
0
|
||||||
|
end)
|
||||||
end
|
end
|
||||||
when :discrete
|
when :discrete
|
||||||
values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort
|
values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort
|
||||||
@@ -84,7 +87,11 @@ module DecisionTree
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
partitions.each_with_index do |examples, i|
|
partitions.each_with_index do |examples, i|
|
||||||
tree[best][values[i]] = id3_train(examples, attributes - [values[i]], (data.classification.mode rescue 0))
|
tree[best][values[i]] = id3_train(examples, attributes - [values[i]], begin
|
||||||
|
data.classification.mode
|
||||||
|
rescue
|
||||||
|
0
|
||||||
|
end)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -100,16 +107,16 @@ module DecisionTree
|
|||||||
thresholds.push((values[i] + (values[i + 1].nil? ? values[i] : values[i + 1])).to_f / 2)
|
thresholds.push((values[i] + (values[i + 1].nil? ? values[i] : values[i + 1])).to_f / 2)
|
||||||
end
|
end
|
||||||
thresholds.pop
|
thresholds.pop
|
||||||
#thresholds -= used[attribute] if used.has_key? attribute
|
# thresholds -= used[attribute] if used.has_key? attribute
|
||||||
|
|
||||||
gain = thresholds.collect do |threshold|
|
gain = thresholds.collect do |threshold|
|
||||||
sp = data.partition { |d| d[attributes.index(attribute)] >= threshold }
|
sp = data.partition { |d| d[attributes.index(attribute)] >= threshold }
|
||||||
pos = (sp[0].size).to_f / data.size
|
pos = sp[0].size.to_f / data.size
|
||||||
neg = (sp[1].size).to_f / data.size
|
neg = sp[1].size.to_f / data.size
|
||||||
|
|
||||||
[data.classification.entropy - pos * sp[0].classification.entropy - neg * sp[1].classification.entropy, threshold]
|
[data.classification.entropy - pos * sp[0].classification.entropy - neg * sp[1].classification.entropy, threshold]
|
||||||
end
|
end
|
||||||
gain = gain.max { |a, b| a[0] <=> b[0] }
|
gain = gain.max_by { |a| a[0] }
|
||||||
|
|
||||||
return [-1, -1] if gain.size == 0
|
return [-1, -1] if gain.size == 0
|
||||||
gain
|
gain
|
||||||
@@ -135,16 +142,16 @@ module DecisionTree
|
|||||||
descend(@tree, test)
|
descend(@tree, test)
|
||||||
end
|
end
|
||||||
|
|
||||||
def graph(filename, file_type = 'png')
|
def graph(filename, file_type = "png")
|
||||||
require 'graphr'
|
require "graphr"
|
||||||
dgp = DotGraphPrinter.new(build_tree)
|
dgp = DotGraphPrinter.new(build_tree)
|
||||||
dgp.size = ''
|
dgp.size = ""
|
||||||
dgp.node_labeler = proc { |n| n.split("\n").first }
|
dgp.node_labeler = proc { |n| n.split("\n").first }
|
||||||
dgp.write_to_file("#{filename}.#{file_type}", file_type)
|
dgp.write_to_file("#{filename}.#{file_type}", file_type)
|
||||||
rescue LoadError
|
rescue LoadError
|
||||||
STDERR.puts "Error: Cannot generate graph."
|
warn "Error: Cannot generate graph."
|
||||||
STDERR.puts " The 'graphr' gem doesn't seem to be installed."
|
warn " The 'graphr' gem doesn't seem to be installed."
|
||||||
STDERR.puts " Run 'gem install graphr' or add it to your Gemfile."
|
warn " Run 'gem install graphr' or add it to your Gemfile."
|
||||||
end
|
end
|
||||||
|
|
||||||
def ruleset
|
def ruleset
|
||||||
@@ -177,19 +184,19 @@ module DecisionTree
|
|||||||
attr = tree.to_a.first
|
attr = tree.to_a.first
|
||||||
return @default unless attr
|
return @default unless attr
|
||||||
if type(attr.first.attribute) == :continuous
|
if type(attr.first.attribute) == :continuous
|
||||||
return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
return attr[1][">="] if !attr[1][">="].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
||||||
return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
return attr[1]["<"] if !attr[1]["<"].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
||||||
return descend(attr[1]['>='], test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
return descend(attr[1][">="], test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
||||||
return descend(attr[1]['<'], test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
return descend(attr[1]["<"], test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
||||||
else
|
else
|
||||||
return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
|
return attr[1][test[@attributes.index(attr[0].attribute)]] unless attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
|
||||||
return descend(attr[1][test[@attributes.index(attr[0].attribute)]], test)
|
descend(attr[1][test[@attributes.index(attr[0].attribute)]], test)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def build_tree(tree = @tree)
|
def build_tree(tree = @tree)
|
||||||
return [] unless tree.is_a?(Hash)
|
return [] unless tree.is_a?(Hash)
|
||||||
return [['Always', @default]] if tree.empty?
|
return [["Always", @default]] if tree.empty?
|
||||||
|
|
||||||
attr = tree.to_a.first
|
attr = tree.to_a.first
|
||||||
|
|
||||||
@@ -203,10 +210,10 @@ module DecisionTree
|
|||||||
child_text = "#{child}\n(#{child.to_s.clone.object_id})"
|
child_text = "#{child}\n(#{child.to_s.clone.object_id})"
|
||||||
end
|
end
|
||||||
|
|
||||||
if type(attr[0].attribute) == :continuous
|
label_text = if type(attr[0].attribute) == :continuous
|
||||||
label_text = "#{key} #{attr[0].threshold}"
|
"#{key} #{attr[0].threshold}"
|
||||||
else
|
else
|
||||||
label_text = key
|
key
|
||||||
end
|
end
|
||||||
|
|
||||||
[parent_text, child_text, label_text]
|
[parent_text, child_text, label_text]
|
||||||
@@ -229,12 +236,12 @@ module DecisionTree
|
|||||||
end
|
end
|
||||||
|
|
||||||
def to_s
|
def to_s
|
||||||
str = ''
|
str = ""
|
||||||
@premises.each do |p|
|
@premises.each do |p|
|
||||||
if p.first.threshold
|
str += if p.first.threshold
|
||||||
str += "#{p.first.attribute} #{p.last} #{p.first.threshold}"
|
"#{p.first.attribute} #{p.last} #{p.first.threshold}"
|
||||||
else
|
else
|
||||||
str += "#{p.first.attribute} = #{p.last}"
|
"#{p.first.attribute} = #{p.last}"
|
||||||
end
|
end
|
||||||
str += "\n"
|
str += "\n"
|
||||||
end
|
end
|
||||||
@@ -245,17 +252,15 @@ module DecisionTree
|
|||||||
verifies = true
|
verifies = true
|
||||||
@premises.each do |p|
|
@premises.each do |p|
|
||||||
if p.first.threshold # Continuous
|
if p.first.threshold # Continuous
|
||||||
if !(p.last == '>=' && test[@attributes.index(p.first.attribute)] >= p.first.threshold) && !(p.last == '<' && test[@attributes.index(p.first.attribute)] < p.first.threshold)
|
if !(p.last == ">=" && test[@attributes.index(p.first.attribute)] >= p.first.threshold) && !(p.last == "<" && test[@attributes.index(p.first.attribute)] < p.first.threshold)
|
||||||
verifies = false
|
verifies = false
|
||||||
break
|
break
|
||||||
end
|
end
|
||||||
else # Discrete
|
elsif test[@attributes.index(p.first.attribute)] != p.last # Discrete
|
||||||
if test[@attributes.index(p.first.attribute)] != p.last
|
|
||||||
verifies = false
|
verifies = false
|
||||||
break
|
break
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
|
||||||
return @conclusion if verifies
|
return @conclusion if verifies
|
||||||
nil
|
nil
|
||||||
end
|
end
|
||||||
@@ -312,7 +317,7 @@ module DecisionTree
|
|||||||
end
|
end
|
||||||
|
|
||||||
def to_s
|
def to_s
|
||||||
str = ''
|
str = ""
|
||||||
@rules.each { |rule| str += "#{rule}\n\n" }
|
@rules.each { |rule| str += "#{rule}\n\n" }
|
||||||
str
|
str
|
||||||
end
|
end
|
||||||
@@ -355,9 +360,8 @@ module DecisionTree
|
|||||||
predictions[p] += accuracy unless p.nil?
|
predictions[p] += accuracy unless p.nil?
|
||||||
end
|
end
|
||||||
return @default, 0.0 if predictions.empty?
|
return @default, 0.0 if predictions.empty?
|
||||||
winner = predictions.sort_by { |_k, v| -v }.first
|
winner = predictions.min_by { |_k, v| -v }
|
||||||
[winner[0], winner[1].to_f / @classifiers.size.to_f]
|
[winner[0], winner[1].to_f / @classifiers.size.to_f]
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -1,19 +1,18 @@
|
|||||||
require 'spec_helper'
|
require "spec_helper"
|
||||||
|
|
||||||
describe describe DecisionTree::ID3Tree do
|
describe describe DecisionTree::ID3Tree do
|
||||||
|
|
||||||
describe "simple discrete case" do
|
describe "simple discrete case" do
|
||||||
Given(:labels) { ["sun", "rain"]}
|
Given(:labels) { ["sun", "rain"] }
|
||||||
Given(:data) do
|
Given(:data) do
|
||||||
[
|
[
|
||||||
[1,0,1],
|
[1, 0, 1],
|
||||||
[0,1,0]
|
[0, 1, 0]
|
||||||
]
|
]
|
||||||
end
|
end
|
||||||
Given(:tree) { DecisionTree::ID3Tree.new(labels, data, 1, :discrete) }
|
Given(:tree) { DecisionTree::ID3Tree.new(labels, data, 1, :discrete) }
|
||||||
When { tree.train }
|
When { tree.train }
|
||||||
Then { expect(tree.predict([1,0])).to eq 1 }
|
Then { expect(tree.predict([1, 0])).to eq 1 }
|
||||||
Then { expect(tree.predict([0,1])).to eq 0 }
|
Then { expect(tree.predict([0, 1])).to eq 0 }
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "discrete attributes" do
|
describe "discrete attributes" do
|
||||||
@@ -84,7 +83,7 @@ describe describe DecisionTree::ID3Tree do
|
|||||||
end
|
end
|
||||||
Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "RED", :discrete) }
|
Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "RED", :discrete) }
|
||||||
When { tree.train }
|
When { tree.train }
|
||||||
Then { expect(tree.predict(["a1","b0","c0"])).to eq "RED" }
|
Then { expect(tree.predict(["a1", "b0", "c0"])).to eq "RED" }
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "numerical labels case" do
|
describe "numerical labels case" do
|
||||||
@@ -109,11 +108,11 @@ describe describe DecisionTree::ID3Tree do
|
|||||||
File.delete("#{FIGURE_FILENAME}.png") if File.file?("#{FIGURE_FILENAME}.png")
|
File.delete("#{FIGURE_FILENAME}.png") if File.file?("#{FIGURE_FILENAME}.png")
|
||||||
end
|
end
|
||||||
|
|
||||||
Given(:labels) { ["sun", "rain"]}
|
Given(:labels) { ["sun", "rain"] }
|
||||||
Given(:data) do
|
Given(:data) do
|
||||||
[
|
[
|
||||||
[1,0,1],
|
[1, 0, 1],
|
||||||
[0,1,0]
|
[0, 1, 0]
|
||||||
]
|
]
|
||||||
end
|
end
|
||||||
Given(:tree) { DecisionTree::ID3Tree.new(labels, data, 1, :discrete) }
|
Given(:tree) { DecisionTree::ID3Tree.new(labels, data, 1, :discrete) }
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
require 'rspec/given'
|
require "rspec/given"
|
||||||
require 'decisiontree'
|
require "decisiontree"
|
||||||
require 'pry'
|
require "pry"
|
||||||
|
|
||||||
FIGURE_FILENAME = "just_a_spec"
|
FIGURE_FILENAME = "just_a_spec"
|
||||||
|
|||||||
Reference in New Issue
Block a user