diff --git a/lib/decisiontree/id3_tree.rb b/lib/decisiontree/id3_tree.rb index 0e8aaa7..8221962 100755 --- a/lib/decisiontree/id3_tree.rb +++ b/lib/decisiontree/id3_tree.rb @@ -127,12 +127,12 @@ module DecisionTree index = attributes.index(attribute) values = data.map { |row| row[index] }.uniq - remainder = values.sort.inject(0, :+) do |val| + remainder = values.sort.inject(0) do |sum, val| classification = data.each_with_object([]) do |row, result| result << row.last if row[index] == val end - ((classification.size.to_f / data.size) * classification.entropy) + sum + ((classification.size.to_f / data.size) * classification.entropy) end [data.classification.entropy - remainder, index]