diff --git a/lib/decisiontree.rb b/lib/decisiontree.rb index 5583923..99f893f 100644 --- a/lib/decisiontree.rb +++ b/lib/decisiontree.rb @@ -1 +1,2 @@ require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb' +require 'core_extension/array.rb' diff --git a/lib/decisiontree/id3_tree.rb b/lib/decisiontree/id3_tree.rb index 937ce09..187b97e 100755 --- a/lib/decisiontree/id3_tree.rb +++ b/lib/decisiontree/id3_tree.rb @@ -13,32 +13,6 @@ class Object end end -class Array - def classification - collect(&:last) - end - - # calculate information entropy - def entropy - return 0 if empty? - - info = {} - total = 0 - each do |i| - info[i] = !info[i] ? 1 : (info[i] + 1) - total += 1 - end - - result = 0 - info.each do |_symbol, count| - if count > 0 - result += -count.to_f / total * Math.log(count.to_f / total) / Math.log(2.0) - end - end - result - end -end - module DecisionTree Node = Struct.new(:attribute, :threshold, :gain) @@ -105,9 +79,10 @@ module DecisionTree fitness = fitness_for(best.attribute) case type(best.attribute) when :continuous - data.partition do |d| + partitioned_data = data.partition do |d| d[attributes.index(best.attribute)] >= best.threshold - end.each_with_index do |examples, i| + end + partitioned_data.each_with_index do |examples, i| tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0), &fitness) end when :discrete