mirror of
https://github.com/dkam/decisiontree.git
synced 2025-12-28 07:04:53 +00:00
Merge branch 'dvisockas-master'
This commit is contained in:
@@ -2,15 +2,25 @@ require 'rubygems'
|
|||||||
require 'decisiontree'
|
require 'decisiontree'
|
||||||
include DecisionTree
|
include DecisionTree
|
||||||
|
|
||||||
# ---Continuous-----------------------------------------------------------------------------------------
|
# ---Continuous---
|
||||||
|
|
||||||
# Read in the training data
|
# Read in the training data
|
||||||
training, attributes = [], nil
|
training = []
# Declared outside the block on purpose: `attributes ||= data` inside the
# each_line block must assign to this outer local, otherwise the header row
# is lost when the block returns and ID3Tree.new(attributes, ...) raises
# NameError.
attributes = nil
|
||||||
File.open('data/continuous-training.txt','r').each_line { |line|
|
File.open('data/continuous-training.txt', 'r').each_line do |line|
|
||||||
data = line.strip.chomp('.').split(',')
|
data = line.strip.chomp('.').split(',')
|
||||||
attributes ||= data
|
attributes ||= data
|
||||||
training.push(data.collect {|v| (v == 'healthy') || (v == 'colic') ? (v == 'healthy' ? 1 : 0) : v.to_f})
|
training_data = data.collect do |v|
|
||||||
}
|
case v
|
||||||
|
when 'healthy'
|
||||||
|
1
|
||||||
|
when 'colic'
|
||||||
|
0
|
||||||
|
else
|
||||||
|
v.to_f
|
||||||
|
end
|
||||||
|
end
|
||||||
|
training.push(training_data)
|
||||||
|
end
|
||||||
|
|
||||||
# Remove the attribute row from the training data
|
# Remove the attribute row from the training data
|
||||||
training.shift
|
training.shift
|
||||||
@@ -19,15 +29,25 @@ training.shift
|
|||||||
dec_tree = ID3Tree.new(attributes, training, 1, :continuous)
|
dec_tree = ID3Tree.new(attributes, training, 1, :continuous)
|
||||||
dec_tree.train
|
dec_tree.train
|
||||||
|
|
||||||
#---- Test the tree....
|
# ---Test the tree---
|
||||||
|
|
||||||
# Read in the test cases
|
# Read in the test cases
|
||||||
# Note: omit the attribute line (first line), we know the labels from the training data
|
# Note: omit the attribute line (first line), we know the labels from the training data
|
||||||
test = []
|
test = []
|
||||||
File.open('data/continuous-test.txt','r').each_line { |line|
|
File.open('data/continuous-test.txt', 'r').each_line do |line|
|
||||||
data = line.strip.chomp('.').split(',')
|
data = line.strip.chomp('.').split(',')
|
||||||
test.push(data.collect {|v| (v == 'healthy') || (v == 'colic') ? (v == 'healthy' ? 1 : 0) : v.to_f})
|
test_data = data.collect do |v|
|
||||||
}
|
if v == 'healthy' || v == 'colic'
|
||||||
|
v == 'healthy' ? 1 : 0
|
||||||
|
else
|
||||||
|
v.to_f
|
||||||
|
end
|
||||||
|
end
|
||||||
|
test.push(test_data)
|
||||||
|
end
|
||||||
|
|
||||||
# Let the tree predict the output and compare it to the true specified value
|
# Let the tree predict the output and compare it to the true specified value
|
||||||
test.each { |t| predict = dec_tree.predict(t); puts "Predict: #{predict} ... True: #{t.last}"}
|
test.each do |t|
|
||||||
|
predict = dec_tree.predict(t)
|
||||||
|
puts "Predict: #{predict} ... True: #{t.last}"
|
||||||
|
end
|
||||||
|
|||||||
@@ -1,15 +1,25 @@
|
|||||||
require 'rubygems'
|
require 'rubygems'
|
||||||
require 'decisiontree'
|
require 'decisiontree'
|
||||||
|
|
||||||
# ---Discrete-----------------------------------------------------------------------------------------
|
# ---Discrete---
|
||||||
|
|
||||||
# Read in the training data
|
# Read in the training data
|
||||||
training, attributes = [], nil
|
training = []
# Declared outside the block on purpose: `attributes ||= data` inside the
# each_line block must assign to this outer local, otherwise the header row
# is lost when the block returns and ID3Tree.new(attributes, ...) raises
# NameError.
attributes = nil
|
||||||
File.open('data/discrete-training.txt','r').each_line { |line|
|
File.open('data/discrete-training.txt', 'r').each_line do |line|
|
||||||
data = line.strip.split(',')
|
data = line.strip.split(',')
|
||||||
attributes ||= data
|
attributes ||= data
|
||||||
training.push(data.collect {|v| (v == 'will buy') || (v == "won't buy") ? (v == 'will buy' ? 1 : 0) : v})
|
training_data = data.collect do |v|
|
||||||
}
|
case v
|
||||||
|
when 'will buy'
|
||||||
|
1
|
||||||
|
when "won't buy"
|
||||||
|
0
|
||||||
|
else
|
||||||
|
v
|
||||||
|
end
|
||||||
|
end
|
||||||
|
training.push(training_data)
|
||||||
|
end
|
||||||
|
|
||||||
# Remove the attribute row from the training data
|
# Remove the attribute row from the training data
|
||||||
training.shift
|
training.shift
|
||||||
@@ -18,17 +28,31 @@ training.shift
|
|||||||
dec_tree = DecisionTree::ID3Tree.new(attributes, training, 1, :discrete)
|
dec_tree = DecisionTree::ID3Tree.new(attributes, training, 1, :discrete)
|
||||||
dec_tree.train
|
dec_tree.train
|
||||||
|
|
||||||
#---- Test the tree....
|
# ---Test the tree---
|
||||||
|
|
||||||
# Read in the test cases
|
# Read in the test cases
|
||||||
# Note: omit the attribute line (first line), we know the labels from the training data
|
# Note: omit the attribute line (first line), we know the labels from the training data
|
||||||
test = []
|
test = []
|
||||||
File.open('data/discrete-test.txt','r').each_line { |line| data = line.strip.split(',')
|
File.open('data/discrete-test.txt', 'r').each_line do |line|
|
||||||
test.push(data.collect {|v| (v == 'will buy') || (v == "won't buy") ? (v == 'will buy' ? 1 : 0) : v})
|
data = line.strip.split(',')
|
||||||
}
|
test_data = data.collect do |v|
|
||||||
|
case v
|
||||||
|
when 'will buy'
|
||||||
|
1
|
||||||
|
when "won't buy"
|
||||||
|
0
|
||||||
|
else
|
||||||
|
v
|
||||||
|
end
|
||||||
|
end
|
||||||
|
# Collect the parsed case into `test` (not `training`): the prediction loop
# below iterates over `test`, which would otherwise stay empty.
test.push(test_data)
|
||||||
|
end
|
||||||
|
|
||||||
# Let the tree predict the output and compare it to the true specified value
|
# Let the tree predict the output and compare it to the true specified value
|
||||||
test.each { |t| predict = dec_tree.predict(t); puts "Predict: #{predict} ... True: #{t.last}"; }
|
test.each do |t|
|
||||||
|
predict = dec_tree.predict(t)
|
||||||
|
puts "Predict: #{predict} ... True: #{t.last}"
|
||||||
|
end
|
||||||
|
|
||||||
# Graph the tree, save to 'discrete.png'
|
# Graph the tree, save to 'discrete.png'
|
||||||
dec_tree.graph("discrete")
|
dec_tree.graph('discrete')
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ training = [
|
|||||||
[38, 'sick'],
|
[38, 'sick'],
|
||||||
[36.7, 'healthy'],
|
[36.7, 'healthy'],
|
||||||
[40, 'sick'],
|
[40, 'sick'],
|
||||||
[50, 'really sick'],
|
[50, 'really sick']
|
||||||
]
|
]
|
||||||
|
|
||||||
# Instantiate the tree, and train it based on the data (set default to '1')
|
# Instantiate the tree, and train it based on the data (set default to '1')
|
||||||
@@ -20,9 +20,7 @@ dec_tree.train
|
|||||||
test = [37, 'sick']
|
test = [37, 'sick']
|
||||||
|
|
||||||
decision = dec_tree.predict(test)
|
decision = dec_tree.predict(test)
|
||||||
puts "Predicted: #{decision} ... True decision: #{test.last}";
|
puts "Predicted: #{decision} ... True decision: #{test.last}"
|
||||||
|
|
||||||
# Graph the tree, save to 'tree.png'
|
# Graph the tree, save to 'tree.png'
|
||||||
dec_tree.graph("tree")
|
dec_tree.graph('tree')
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
29
lib/core_extensions/array.rb
Normal file
29
lib/core_extensions/array.rb
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
# Core extensions used by the decision tree: helpers for reading class labels
# out of a data set and measuring how mixed those labels are.
class Array
  # Returns the last element of every row, i.e. the classification column of
  # a data set stored as an array of rows.
  def classification
    map(&:last)
  end

  # Calculates the information entropy (in bits) of the receiver's elements.
  #
  # Returns Integer 0 for an empty array, otherwise a Float.
  def entropy
    return 0 if empty?

    # Occurrences per distinct element; Hash.new(0) replaces the manual
    # "nil means first occurrence" check.
    info = Hash.new(0)
    each { |element| info[element] += 1 }

    result(info, length)
  end

  private

  # Sums -p * log2(p) over the occurrence counts in +info+, where p is
  # count / total. Every count stored by #entropy is at least 1, so no
  # zero-count guard is needed.
  def result(info, total)
    info.values.inject(0) do |final, count|
      percentage = count.to_f / total
      final + (-percentage * Math.log(percentage) / Math.log(2.0))
    end
  end
end
|
||||||
9
lib/core_extensions/object.rb
Normal file
9
lib/core_extensions/object.rb
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# Marshal-based persistence helpers made available on every object.
class Object
  # Serializes the receiver with Marshal and writes it to +filename+,
  # truncating any existing file.
  #
  # The file is opened in binary mode: Marshal output is raw bytes, and text
  # mode may apply newline/encoding translation on some platforms.
  def save_to_file(filename)
    File.open(filename, 'wb') { |f| f << Marshal.dump(self) }
  end

  # Reads +filename+ (as raw bytes) and returns the object marshalled in it.
  #
  # NOTE(review): Marshal.load can instantiate arbitrary objects; only call
  # this on files that come from a trusted source.
  def self.load_from_file(filename)
    Marshal.load(File.binread(filename))
  end
end
|
||||||
@@ -1 +1,3 @@
|
|||||||
require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb'
|
require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb'
|
||||||
|
require 'core_extensions/object'
|
||||||
|
require 'core_extensions/array'
|
||||||
|
|||||||
@@ -3,50 +3,33 @@
|
|||||||
### Copyright (c) 2007 Ilya Grigorik <ilya AT igvita DOT com>
|
### Copyright (c) 2007 Ilya Grigorik <ilya AT igvita DOT com>
|
||||||
### Modified at 2007 by José Ignacio Fernández <joseignacio.fernandez AT gmail DOT com>
|
### Modified at 2007 by José Ignacio Fernández <joseignacio.fernandez AT gmail DOT com>
|
||||||
|
|
||||||
class Object
|
|
||||||
def save_to_file(filename)
|
|
||||||
File.open(filename, 'w+' ) { |f| f << Marshal.dump(self) }
|
|
||||||
end
|
|
||||||
|
|
||||||
def self.load_from_file(filename)
|
|
||||||
Marshal.load( File.read( filename ) )
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
class Array
|
|
||||||
def classification; collect { |v| v.last }; end
|
|
||||||
|
|
||||||
# calculate information entropy
|
|
||||||
def entropy
|
|
||||||
return 0 if empty?
|
|
||||||
|
|
||||||
info = {}
|
|
||||||
total = 0
|
|
||||||
each {|i| info[i] = !info[i] ? 1 : (info[i] + 1); total += 1}
|
|
||||||
|
|
||||||
result = 0
|
|
||||||
info.each do |symbol, count|
|
|
||||||
result += -count.to_f/total*Math.log(count.to_f/total)/Math.log(2.0) if (count > 0)
|
|
||||||
end
|
|
||||||
result
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
module DecisionTree
|
module DecisionTree
|
||||||
Node = Struct.new(:attribute, :threshold, :gain)
|
Node = Struct.new(:attribute, :threshold, :gain)
|
||||||
|
|
||||||
class ID3Tree
|
class ID3Tree
|
||||||
def initialize(attributes, data, default, type)
|
def initialize(attributes, data, default, type)
|
||||||
@used, @tree, @type = {}, {}, type
|
@used = {}
|
||||||
@data, @attributes, @default = data, attributes, default
|
@tree = {}
|
||||||
|
@type = type
|
||||||
|
@data = data
|
||||||
|
@attributes = attributes
|
||||||
|
@default = default
|
||||||
end
|
end
|
||||||
|
|
||||||
def train(data=@data, attributes=@attributes, default=@default)
|
def train(data = @data, attributes = @attributes, default = @default)
|
||||||
attributes = attributes.map {|e| e.to_s}
|
attributes = attributes.map(&:to_s)
|
||||||
initialize(attributes, data, default, @type)
|
initialize(attributes, data, default, @type)
|
||||||
|
|
||||||
# Remove samples with same attributes leaving most common classification
|
# Remove samples with same attributes leaving most common classification
|
||||||
data2 = data.inject({}) {|hash, d| hash[d.slice(0..-2)] ||= Hash.new(0); hash[d.slice(0..-2)][d.last] += 1; hash }.map{|key,val| key + [val.sort_by{ |k, v| v }.last.first]}
|
data2 = data.inject({}) do |hash, d|
|
||||||
|
hash[d.slice(0..-2)] ||= Hash.new(0)
|
||||||
|
hash[d.slice(0..-2)][d.last] += 1
|
||||||
|
hash
|
||||||
|
end
|
||||||
|
|
||||||
|
data2 = data2.map do |key, val|
|
||||||
|
key + [val.sort_by { |_k, v| v }.last.first]
|
||||||
|
end
|
||||||
|
|
||||||
@tree = id3_train(data2, attributes, default)
|
@tree = id3_train(data2, attributes, default)
|
||||||
end
|
end
|
||||||
@@ -57,12 +40,14 @@ module DecisionTree
|
|||||||
|
|
||||||
def fitness_for(attribute)
|
def fitness_for(attribute)
|
||||||
case type(attribute)
|
case type(attribute)
|
||||||
when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
|
when :discrete
|
||||||
when :continuous; fitness = proc{|a,b,c| id3_continuous(a,b,c)}
|
proc { |a, b, c| id3_discrete(a, b, c) }
|
||||||
|
when :continuous
|
||||||
|
proc { |a, b, c| id3_continuous(a, b, c) }
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def id3_train(data, attributes, default, used={})
|
def id3_train(data, attributes, default, _used={})
|
||||||
return default if data.empty?
|
return default if data.empty?
|
||||||
|
|
||||||
# return classification if all examples have the same classification
|
# return classification if all examples have the same classification
|
||||||
@@ -75,7 +60,7 @@ module DecisionTree
|
|||||||
performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
|
performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
|
||||||
max = performance.max { |a,b| a[0] <=> b[0] }
|
max = performance.max { |a,b| a[0] <=> b[0] }
|
||||||
min = performance.min { |a,b| a[0] <=> b[0] }
|
min = performance.min { |a,b| a[0] <=> b[0] }
|
||||||
max = performance.shuffle.first if max[0] == min[0]
|
max = performance.sample if max[0] == min[0]
|
||||||
best = Node.new(attributes[performance.index(max)], max[1], max[0])
|
best = Node.new(attributes[performance.index(max)], max[1], max[0])
|
||||||
best.threshold = nil if @type == :discrete
|
best.threshold = nil if @type == :discrete
|
||||||
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
|
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
|
||||||
@@ -84,15 +69,22 @@ module DecisionTree
|
|||||||
fitness = fitness_for(best.attribute)
|
fitness = fitness_for(best.attribute)
|
||||||
case type(best.attribute)
|
case type(best.attribute)
|
||||||
when :continuous
|
when :continuous
|
||||||
data.partition { |d| d[attributes.index(best.attribute)] >= best.threshold }.each_with_index { |examples, i|
|
partitioned_data = data.partition do |d|
|
||||||
|
d[attributes.index(best.attribute)] >= best.threshold
|
||||||
|
end
|
||||||
|
partitioned_data.each_with_index do |examples, i|
|
||||||
tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
|
tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
|
||||||
}
|
end
|
||||||
when :discrete
|
when :discrete
|
||||||
values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort
|
values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort
|
||||||
partitions = values.collect { |val| data.select { |d| d[attributes.index(best.attribute)] == val } }
|
partitions = values.collect do |val|
|
||||||
partitions.each_with_index { |examples, i|
|
data.select do |d|
|
||||||
tree[best][values[i]] = id3_train(examples, attributes-[values[i]], (data.classification.mode rescue 0), &fitness)
|
d[attributes.index(best.attribute)] == val
|
||||||
}
|
end
|
||||||
|
end
|
||||||
|
partitions.each_with_index do |examples, i|
|
||||||
|
tree[best][values[i]] = id3_train(examples, attributes - [values[i]], (data.classification.mode rescue 0), &fitness)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
tree
|
tree
|
||||||
@@ -100,19 +92,23 @@ module DecisionTree
|
|||||||
|
|
||||||
# ID3 for binary classification of continuous variables (e.g. healthy / sick based on temperature thresholds)
|
# ID3 for binary classification of continuous variables (e.g. healthy / sick based on temperature thresholds)
|
||||||
def id3_continuous(data, attributes, attribute)
|
def id3_continuous(data, attributes, attribute)
|
||||||
values, thresholds = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort, []
|
values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
|
||||||
|
thresholds = []
|
||||||
return [-1, -1] if values.size == 1
|
return [-1, -1] if values.size == 1
|
||||||
values.each_index { |i| thresholds.push((values[i]+(values[i+1].nil? ? values[i] : values[i+1])).to_f / 2) }
|
values.each_index do |i|
|
||||||
|
thresholds.push((values[i] + (values[i + 1].nil? ? values[i] : values[i + 1])).to_f / 2)
|
||||||
|
end
|
||||||
thresholds.pop
|
thresholds.pop
|
||||||
#thresholds -= used[attribute] if used.has_key? attribute
|
#thresholds -= used[attribute] if used.has_key? attribute
|
||||||
|
|
||||||
gain = thresholds.collect { |threshold|
|
gain = thresholds.collect do |threshold|
|
||||||
sp = data.partition { |d| d[attributes.index(attribute)] >= threshold }
|
sp = data.partition { |d| d[attributes.index(attribute)] >= threshold }
|
||||||
pos = (sp[0].size).to_f / data.size
|
pos = (sp[0].size).to_f / data.size
|
||||||
neg = (sp[1].size).to_f / data.size
|
neg = (sp[1].size).to_f / data.size
|
||||||
|
|
||||||
[data.classification.entropy - pos*sp[0].classification.entropy - neg*sp[1].classification.entropy, threshold]
|
[data.classification.entropy - pos * sp[0].classification.entropy - neg * sp[1].classification.entropy, threshold]
|
||||||
}.max { |a,b| a[0] <=> b[0] }
|
end
|
||||||
|
gain = gain.max { |a, b| a[0] <=> b[0] }
|
||||||
|
|
||||||
return [-1, -1] if gain.size == 0
|
return [-1, -1] if gain.size == 0
|
||||||
gain
|
gain
|
||||||
@@ -122,7 +118,7 @@ module DecisionTree
|
|||||||
def id3_discrete(data, attributes, attribute)
|
def id3_discrete(data, attributes, attribute)
|
||||||
values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
|
values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
|
||||||
partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } }
|
partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } }
|
||||||
remainder = partitions.collect {|p| (p.size.to_f / data.size) * p.classification.entropy}.inject(0) {|i,s| s+=i }
|
remainder = partitions.collect { |p| (p.size.to_f / data.size) * p.classification.entropy }.inject(0) { |a, e| e += a }
|
||||||
|
|
||||||
[data.classification.entropy - remainder, attributes.index(attribute)]
|
[data.classification.entropy - remainder, attributes.index(attribute)]
|
||||||
end
|
end
|
||||||
@@ -131,7 +127,7 @@ module DecisionTree
|
|||||||
descend(@tree, test)
|
descend(@tree, test)
|
||||||
end
|
end
|
||||||
|
|
||||||
def graph(filename, file_type = "png")
|
def graph(filename, file_type = 'png')
|
||||||
require 'graphr'
|
require 'graphr'
|
||||||
dgp = DotGraphPrinter.new(build_tree)
|
dgp = DotGraphPrinter.new(build_tree)
|
||||||
dgp.write_to_file("#{filename}.#{file_type}", file_type)
|
dgp.write_to_file("#{filename}.#{file_type}", file_type)
|
||||||
@@ -143,12 +139,12 @@ module DecisionTree
|
|||||||
rs
|
rs
|
||||||
end
|
end
|
||||||
|
|
||||||
def build_rules(tree=@tree)
|
def build_rules(tree = @tree)
|
||||||
attr = tree.to_a.first
|
attr = tree.to_a.first
|
||||||
cases = attr[1].to_a
|
cases = attr[1].to_a
|
||||||
rules = []
|
rules = []
|
||||||
cases.each do |c,child|
|
cases.each do |c, child|
|
||||||
if child.is_a?(Hash) then
|
if child.is_a?(Hash)
|
||||||
build_rules(child).each do |r|
|
build_rules(child).each do |r|
|
||||||
r2 = r.clone
|
r2 = r.clone
|
||||||
r2.premises.unshift([attr.first, c])
|
r2.premises.unshift([attr.first, c])
|
||||||
@@ -162,42 +158,46 @@ module DecisionTree
|
|||||||
end
|
end
|
||||||
|
|
||||||
private
|
private
|
||||||
|
|
||||||
def descend(tree, test)
|
def descend(tree, test)
|
||||||
attr = tree.to_a.first
|
attr = tree.to_a.first
|
||||||
return @default if !attr
|
return @default unless attr
|
||||||
if type(attr.first.attribute) == :continuous
|
if type(attr.first.attribute) == :continuous
|
||||||
return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
||||||
return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
||||||
return descend(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
return descend(attr[1]['>='], test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
|
||||||
return descend(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
return descend(attr[1]['<'], test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
|
||||||
else
|
else
|
||||||
return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
|
return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
|
||||||
return descend(attr[1][test[@attributes.index(attr[0].attribute)]],test)
|
return descend(attr[1][test[@attributes.index(attr[0].attribute)]], test)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def build_tree(tree = @tree)
|
def build_tree(tree = @tree)
|
||||||
return [] unless tree.is_a?(Hash)
|
return [] unless tree.is_a?(Hash)
|
||||||
return [["Always", @default]] if tree.empty?
|
return [['Always', @default]] if tree.empty?
|
||||||
|
|
||||||
attr = tree.to_a.first
|
attr = tree.to_a.first
|
||||||
|
|
||||||
links = attr[1].keys.collect do |key|
|
links = attr[1].keys.collect do |key|
|
||||||
parent_text = "#{attr[0].attribute}\n(#{attr[0].object_id})"
|
parent_text = "#{attr[0].attribute}\n(#{attr[0].object_id})"
|
||||||
if attr[1][key].is_a?(Hash) then
|
if attr[1][key].is_a?(Hash)
|
||||||
child = attr[1][key].to_a.first[0]
|
child = attr[1][key].to_a.first[0]
|
||||||
child_text = "#{child.attribute}\n(#{child.object_id})"
|
child_text = "#{child.attribute}\n(#{child.object_id})"
|
||||||
else
|
else
|
||||||
child = attr[1][key]
|
child = attr[1][key]
|
||||||
child_text = "#{child}\n(#{child.to_s.clone.object_id})"
|
child_text = "#{child}\n(#{child.to_s.clone.object_id})"
|
||||||
end
|
end
|
||||||
label_text = "#{key} #{type(attr[0].attribute) == :continuous ? attr[0].threshold : ""}"
|
# Edge label: the branch key, followed by the split threshold for continuous
# attributes only. Interpolation converts the numeric threshold to text
# (String#gsub! would raise TypeError on a non-String replacement), and
# discrete labels keep just the trailing space instead of a literal ''.
threshold_text = type(attr[0].attribute) == :continuous ? attr[0].threshold : ''
label_text = "#{key} #{threshold_text}"
|
||||||
|
|
||||||
[parent_text, child_text, label_text]
|
[parent_text, child_text, label_text]
|
||||||
end
|
end
|
||||||
attr[1].keys.each { |key| links += build_tree(attr[1][key]) }
|
attr[1].keys.each { |key| links += build_tree(attr[1][key]) }
|
||||||
|
|
||||||
return links
|
links
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -206,48 +206,56 @@ module DecisionTree
|
|||||||
attr_accessor :conclusion
|
attr_accessor :conclusion
|
||||||
attr_accessor :attributes
|
attr_accessor :attributes
|
||||||
|
|
||||||
def initialize(attributes,premises=[],conclusion=nil)
|
def initialize(attributes, premises = [], conclusion = nil)
|
||||||
@attributes, @premises, @conclusion = attributes, premises, conclusion
|
@attributes = attributes
|
||||||
|
@premises = premises
|
||||||
|
@conclusion = conclusion
|
||||||
end
|
end
|
||||||
|
|
||||||
def to_s
|
def to_s
|
||||||
str = ''
|
str = ''
|
||||||
@premises.each do |p|
|
@premises.each do |p|
|
||||||
str += "#{p.first.attribute} #{p.last} #{p.first.threshold}" if p.first.threshold
|
if p.first.threshold
|
||||||
str += "#{p.first.attribute} = #{p.last}" if !p.first.threshold
|
str += "#{p.first.attribute} #{p.last} #{p.first.threshold}"
|
||||||
|
else
|
||||||
|
str += "#{p.first.attribute} = #{p.last}"
|
||||||
|
end
|
||||||
str += "\n"
|
str += "\n"
|
||||||
end
|
end
|
||||||
str += "=> #{@conclusion} (#{accuracy})"
|
str += "=> #{@conclusion} (#{accuracy})"
|
||||||
end
|
end
|
||||||
|
|
||||||
def predict(test)
|
def predict(test)
|
||||||
verifies = true;
|
verifies = true
|
||||||
@premises.each do |p|
|
@premises.each do |p|
|
||||||
if p.first.threshold then # Continuous
|
if p.first.threshold # Continuous
|
||||||
if !(p.last == '>=' && test[@attributes.index(p.first.attribute)] >= p.first.threshold) && !(p.last == '<' && test[@attributes.index(p.first.attribute)] < p.first.threshold) then
|
if !(p.last == '>=' && test[@attributes.index(p.first.attribute)] >= p.first.threshold) && !(p.last == '<' && test[@attributes.index(p.first.attribute)] < p.first.threshold)
|
||||||
verifies = false; break
|
verifies = false
|
||||||
|
break
|
||||||
end
|
end
|
||||||
else # Discrete
|
else # Discrete
|
||||||
if test[@attributes.index(p.first.attribute)] != p.last then
|
if test[@attributes.index(p.first.attribute)] != p.last
|
||||||
verifies = false; break
|
verifies = false
|
||||||
|
break
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
return @conclusion if verifies
|
return @conclusion if verifies
|
||||||
return nil
|
nil
|
||||||
end
|
end
|
||||||
|
|
||||||
def get_accuracy(data)
|
def get_accuracy(data)
|
||||||
correct = 0; total = 0
|
correct = 0
|
||||||
|
total = 0
|
||||||
data.each do |d|
|
data.each do |d|
|
||||||
prediction = predict(d)
|
prediction = predict(d)
|
||||||
correct += 1 if d.last == prediction
|
correct += 1 if d.last == prediction
|
||||||
total += 1 if !prediction.nil?
|
total += 1 unless prediction.nil?
|
||||||
end
|
end
|
||||||
(correct.to_f + 1) / (total.to_f + 2)
|
(correct.to_f + 1) / (total.to_f + 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
def accuracy(data=nil)
|
def accuracy(data = nil)
|
||||||
data.nil? ? @accuracy : @accuracy = get_accuracy(data)
|
data.nil? ? @accuracy : @accuracy = get_accuracy(data)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -256,14 +264,16 @@ module DecisionTree
|
|||||||
attr_accessor :rules
|
attr_accessor :rules
|
||||||
|
|
||||||
def initialize(attributes, data, default, type)
|
def initialize(attributes, data, default, type)
|
||||||
@attributes, @default, @type = attributes, default, type
|
@attributes = attributes
|
||||||
mixed_data = data.sort_by {rand}
|
@default = default
|
||||||
|
@type = type
|
||||||
|
mixed_data = data.sort_by { rand }
|
||||||
cut = (mixed_data.size.to_f * 0.67).to_i
|
cut = (mixed_data.size.to_f * 0.67).to_i
|
||||||
@train_data = mixed_data.slice(0..cut-1)
|
@train_data = mixed_data.slice(0..cut - 1)
|
||||||
@prune_data = mixed_data.slice(cut..-1)
|
@prune_data = mixed_data.slice(cut..-1)
|
||||||
end
|
end
|
||||||
|
|
||||||
def train(train_data=@train_data, attributes=@attributes, default=@default)
|
def train(train_data = @train_data, attributes = @attributes, default = @default)
|
||||||
dec_tree = DecisionTree::ID3Tree.new(attributes, train_data, default, @type)
|
dec_tree = DecisionTree::ID3Tree.new(attributes, train_data, default, @type)
|
||||||
dec_tree.train
|
dec_tree.train
|
||||||
@rules = dec_tree.build_rules
|
@rules = dec_tree.build_rules
|
||||||
@@ -271,21 +281,23 @@ module DecisionTree
|
|||||||
prune
|
prune
|
||||||
end
|
end
|
||||||
|
|
||||||
def prune(data=@prune_data)
|
def prune(data = @prune_data)
|
||||||
@rules.each do |r|
|
@rules.each do |r|
|
||||||
(1..r.premises.size).each do
|
(1..r.premises.size).each do
|
||||||
acc1 = r.accuracy(data)
|
acc1 = r.accuracy(data)
|
||||||
p = r.premises.pop
|
p = r.premises.pop
|
||||||
if acc1 > r.get_accuracy(data) then
|
if acc1 > r.get_accuracy(data)
|
||||||
r.premises.push(p); break
|
r.premises.push(p)
|
||||||
|
break
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@rules = @rules.sort_by{|r| -r.accuracy(data)}
|
@rules = @rules.sort_by { |r| -r.accuracy(data) }
|
||||||
end
|
end
|
||||||
|
|
||||||
def to_s
|
def to_s
|
||||||
str = ''; @rules.each { |rule| str += "#{rule}\n\n" }
|
str = ''
|
||||||
|
@rules.each { |rule| str += "#{rule}\n\n" }
|
||||||
str
|
str
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -294,18 +306,21 @@ module DecisionTree
|
|||||||
prediction = r.predict(test)
|
prediction = r.predict(test)
|
||||||
return prediction, r.accuracy unless prediction.nil?
|
return prediction, r.accuracy unless prediction.nil?
|
||||||
end
|
end
|
||||||
return @default, 0.0
|
[@default, 0.0]
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
class Bagging
|
class Bagging
|
||||||
attr_accessor :classifiers
|
attr_accessor :classifiers
|
||||||
def initialize(attributes, data, default, type)
|
def initialize(attributes, data, default, type)
|
||||||
@classifiers, @type = [], type
|
@classifiers = []
|
||||||
@data, @attributes, @default = data, attributes, default
|
@type = type
|
||||||
|
@data = data
|
||||||
|
@attributes = attributes
|
||||||
|
@default = default
|
||||||
end
|
end
|
||||||
|
|
||||||
def train(data=@data, attributes=@attributes, default=@default)
|
def train(data = @data, attributes = @attributes, default = @default)
|
||||||
@classifiers = []
|
@classifiers = []
|
||||||
10.times { @classifiers << Ruleset.new(attributes, data, default, @type) }
|
10.times { @classifiers << Ruleset.new(attributes, data, default, @type) }
|
||||||
@classifiers.each do |c|
|
@classifiers.each do |c|
|
||||||
@@ -320,8 +335,8 @@ module DecisionTree
|
|||||||
predictions[p] += accuracy unless p.nil?
|
predictions[p] += accuracy unless p.nil?
|
||||||
end
|
end
|
||||||
return @default, 0.0 if predictions.empty?
|
return @default, 0.0 if predictions.empty?
|
||||||
winner = predictions.sort_by {|k,v| -v}.first
|
winner = predictions.sort_by { |_k, v| -v }.first
|
||||||
return winner[0], winner[1].to_f / @classifiers.size.to_f
|
[winner[0], winner[1].to_f / @classifiers.size.to_f]
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|||||||
Reference in New Issue
Block a user