diff --git a/lib/core_extensions/array.rb b/lib/core_extensions/array.rb
index 5a70756..c726d5f 100644
--- a/lib/core_extensions/array.rb
+++ b/lib/core_extensions/array.rb
@@ -1,29 +1,19 @@
 class Array
-  def classification
-    collect(&:last)
-  end
-
-  # calculate information entropy
   def entropy
-    return 0 if empty?
-
-    info = {}
-    each do |i|
-      info[i] = !info[i] ? 1 : (info[i] + 1)
+    each_with_object(Hash.new(0)) do |i, result|
+      result[i] += 1
+    end.values.inject(0) do |sum, count|
+      percentage = count.to_f / length
+      sum + -percentage * Math.log2(percentage)
     end
-
-    result(info, length)
-  end
-
-  private
-
-  def result(info, total)
-    final = 0
-    info.each do |_symbol, count|
-      next unless count > 0
-      percentage = count.to_f / total
-      final += -percentage * Math.log(percentage) / Math.log(2.0)
-    end
-    final
   end
 end
+
+module ArrayClassification
+  refine Array do
+    def classification
+      collect(&:last)
+    end
+  end
+end
+
diff --git a/lib/decisiontree.rb b/lib/decisiontree.rb
index 3da0b47..190fc0b 100644
--- a/lib/decisiontree.rb
+++ b/lib/decisiontree.rb
@@ -1,3 +1,3 @@
-require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb'
 require 'core_extensions/object'
 require 'core_extensions/array'
+require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb'
diff --git a/lib/decisiontree/id3_tree.rb b/lib/decisiontree/id3_tree.rb
index 1d3658a..78d8493 100755
--- a/lib/decisiontree/id3_tree.rb
+++ b/lib/decisiontree/id3_tree.rb
@@ -8,6 +8,8 @@ require 'set'
 module DecisionTree
   Node = Struct.new(:attribute, :threshold, :gain)
 
+  using ArrayClassification
+
   class ID3Tree
     def initialize(attributes, data, default, type)
       @used = {}
@@ -119,10 +121,14 @@ module DecisionTree
     def id3_discrete(data, attributes, attribute)
       index = attributes.index(attribute)
 
-      values = Set.new
-      data.each { |d| values << d[index] }
-      partitions = values.to_a.sort.collect { |val| data.select { |d| d[index] == val } }
-      remainder = partitions.collect { |p| (p.size.to_f / data.size) * p.classification.entropy }.inject(0) { |a, e| e += a }
+      values = data.map { |row| row[index] }.uniq
+      remainder = values.sort.inject(0) do |sum, val|
+        classification = data.each_with_object([]) do |row, result|
+          result << row.last if row[index] == val
+        end
+
+        sum + ((classification.size.to_f / data.size) * classification.entropy)
+      end
 
       [data.classification.entropy - remainder, index]
     end
@@ -324,6 +330,7 @@ module DecisionTree
 
   class Bagging
     attr_accessor :classifiers
+
     def initialize(attributes, data, default, type)
       @classifiers = []
       @type = type
@@ -333,10 +340,13 @@ module DecisionTree
     end
 
     def train(data = @data, attributes = @attributes, default = @default)
-      @classifiers = []
-      10.times { @classifiers << Ruleset.new(attributes, data, default, @type) }
-      @classifiers.each do |c|
-        c.train(data, attributes, default)
+      @classifiers = 5.times.map do |i|
+        Ruleset.new(attributes, data, default, @type)
+      end
+
+      @classifiers.each_with_index do |classifier, index|
+        puts "Processing classifier ##{index + 1}"
+        classifier.train(data, attributes, default)
       end
     end
 
@@ -352,3 +362,4 @@ module DecisionTree
     end
   end
 end
+