mirror of
https://github.com/dkam/decisiontree.git
synced 2025-12-28 07:04:53 +00:00
Merge pull request #32 from cheerfulstoic/master
Performance Improvements
This commit is contained in:
BIN
lib/.DS_Store
vendored
Normal file
BIN
lib/.DS_Store
vendored
Normal file
Binary file not shown.
@@ -1,29 +1,20 @@
|
|||||||
class Array
|
class Array
|
||||||
|
def entropy
|
||||||
|
each_with_object(Hash.new(0)) do |i, result|
|
||||||
|
result[i] += 1
|
||||||
|
end.values.sum do |count|
|
||||||
|
percentage = count.to_f / length
|
||||||
|
|
||||||
|
-percentage * Math.log2(percentage)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
module ArrayClassification
|
||||||
|
refine Array do
|
||||||
def classification
|
def classification
|
||||||
collect(&:last)
|
collect(&:last)
|
||||||
end
|
end
|
||||||
|
|
||||||
# calculate information entropy
|
|
||||||
def entropy
|
|
||||||
return 0 if empty?
|
|
||||||
|
|
||||||
info = {}
|
|
||||||
each do |i|
|
|
||||||
info[i] = !info[i] ? 1 : (info[i] + 1)
|
|
||||||
end
|
|
||||||
|
|
||||||
result(info, length)
|
|
||||||
end
|
|
||||||
|
|
||||||
private
|
|
||||||
|
|
||||||
def result(info, total)
|
|
||||||
final = 0
|
|
||||||
info.each do |_symbol, count|
|
|
||||||
next unless count > 0
|
|
||||||
percentage = count.to_f / total
|
|
||||||
final += -percentage * Math.log(percentage) / Math.log(2.0)
|
|
||||||
end
|
|
||||||
final
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb'
|
|
||||||
require 'core_extensions/object'
|
require 'core_extensions/object'
|
||||||
require 'core_extensions/array'
|
require 'core_extensions/array'
|
||||||
|
require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb'
|
||||||
|
|||||||
@@ -6,6 +6,8 @@
|
|||||||
module DecisionTree
|
module DecisionTree
|
||||||
Node = Struct.new(:attribute, :threshold, :gain)
|
Node = Struct.new(:attribute, :threshold, :gain)
|
||||||
|
|
||||||
|
using ArrayClassification
|
||||||
|
|
||||||
class ID3Tree
|
class ID3Tree
|
||||||
def initialize(attributes, data, default, type)
|
def initialize(attributes, data, default, type)
|
||||||
@used = {}
|
@used = {}
|
||||||
@@ -28,7 +30,7 @@ module DecisionTree
|
|||||||
end
|
end
|
||||||
|
|
||||||
data2 = data2.map do |key, val|
|
data2 = data2.map do |key, val|
|
||||||
key + [val.sort_by { |_k, v| v }.last.first]
|
key + [val.sort_by { |_, v| v }.last.first]
|
||||||
end
|
end
|
||||||
|
|
||||||
@tree = id3_train(data2, attributes, default)
|
@tree = id3_train(data2, attributes, default)
|
||||||
@@ -41,9 +43,9 @@ module DecisionTree
|
|||||||
def fitness_for(attribute)
|
def fitness_for(attribute)
|
||||||
case type(attribute)
|
case type(attribute)
|
||||||
when :discrete
|
when :discrete
|
||||||
proc { |a, b, c| id3_discrete(a, b, c) }
|
proc { |*args| id3_discrete(*args) }
|
||||||
when :continuous
|
when :continuous
|
||||||
proc { |a, b, c| id3_continuous(a, b, c) }
|
proc { |*args| id3_continuous(*args) }
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -66,14 +68,13 @@ module DecisionTree
|
|||||||
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
|
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
|
||||||
tree, l = {best => {}}, ['>=', '<']
|
tree, l = {best => {}}, ['>=', '<']
|
||||||
|
|
||||||
fitness = fitness_for(best.attribute)
|
|
||||||
case type(best.attribute)
|
case type(best.attribute)
|
||||||
when :continuous
|
when :continuous
|
||||||
partitioned_data = data.partition do |d|
|
partitioned_data = data.partition do |d|
|
||||||
d[attributes.index(best.attribute)] >= best.threshold
|
d[attributes.index(best.attribute)] >= best.threshold
|
||||||
end
|
end
|
||||||
partitioned_data.each_with_index do |examples, i|
|
partitioned_data.each_with_index do |examples, i|
|
||||||
tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness)
|
tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0))
|
||||||
end
|
end
|
||||||
when :discrete
|
when :discrete
|
||||||
values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort
|
values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort
|
||||||
@@ -83,7 +84,7 @@ module DecisionTree
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
partitions.each_with_index do |examples, i|
|
partitions.each_with_index do |examples, i|
|
||||||
tree[best][values[i]] = id3_train(examples, attributes - [values[i]], (data.classification.mode rescue 0), &fitness)
|
tree[best][values[i]] = id3_train(examples, attributes - [values[i]], (data.classification.mode rescue 0))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -116,11 +117,18 @@ module DecisionTree
|
|||||||
|
|
||||||
# ID3 for discrete label cases
|
# ID3 for discrete label cases
|
||||||
def id3_discrete(data, attributes, attribute)
|
def id3_discrete(data, attributes, attribute)
|
||||||
values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
|
index = attributes.index(attribute)
|
||||||
partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } }
|
|
||||||
remainder = partitions.collect { |p| (p.size.to_f / data.size) * p.classification.entropy }.inject(0) { |a, e| e += a }
|
|
||||||
|
|
||||||
[data.classification.entropy - remainder, attributes.index(attribute)]
|
values = data.map { |row| row[index] }.uniq
|
||||||
|
remainder = values.sort.sum do |val|
|
||||||
|
classification = data.each_with_object([]) do |row, result|
|
||||||
|
result << row.last if row[index] == val
|
||||||
|
end
|
||||||
|
|
||||||
|
((classification.size.to_f / data.size) * classification.entropy)
|
||||||
|
end
|
||||||
|
|
||||||
|
[data.classification.entropy - remainder, index]
|
||||||
end
|
end
|
||||||
|
|
||||||
def predict(test)
|
def predict(test)
|
||||||
@@ -320,6 +328,7 @@ module DecisionTree
|
|||||||
|
|
||||||
class Bagging
|
class Bagging
|
||||||
attr_accessor :classifiers
|
attr_accessor :classifiers
|
||||||
|
|
||||||
def initialize(attributes, data, default, type)
|
def initialize(attributes, data, default, type)
|
||||||
@classifiers = []
|
@classifiers = []
|
||||||
@type = type
|
@type = type
|
||||||
@@ -329,10 +338,13 @@ module DecisionTree
|
|||||||
end
|
end
|
||||||
|
|
||||||
def train(data = @data, attributes = @attributes, default = @default)
|
def train(data = @data, attributes = @attributes, default = @default)
|
||||||
@classifiers = []
|
@classifiers = 10.times.map do |i|
|
||||||
10.times { @classifiers << Ruleset.new(attributes, data, default, @type) }
|
Ruleset.new(attributes, data, default, @type)
|
||||||
@classifiers.each do |c|
|
end
|
||||||
c.train(data, attributes, default)
|
|
||||||
|
@classifiers.each_with_index do |classifier, index|
|
||||||
|
puts "Processing classifier ##{index + 1}"
|
||||||
|
classifier.train(data, attributes, default)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -348,3 +360,4 @@ module DecisionTree
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user