mirror of
https://github.com/dkam/decisiontree.git
synced 2025-12-27 22:54:52 +00:00
Making code easier to read
This commit is contained in:
@@ -14,7 +14,9 @@ class Object
|
||||
end
|
||||
|
||||
class Array
|
||||
def classification; collect { |v| v.last }; end
|
||||
def classification
|
||||
collect { |v| v.last }
|
||||
end
|
||||
|
||||
# calculate information entropy
|
||||
def entropy
|
||||
@@ -22,7 +24,7 @@ class Array
|
||||
|
||||
info = {}
|
||||
total = 0
|
||||
each {|i| info[i] = !info[i] ? 1 : (info[i] + 1); total += 1}
|
||||
each { |i| info[i] = !info[i] ? 1 : (info[i] + 1); total += 1}
|
||||
|
||||
result = 0
|
||||
info.each do |symbol, count|
|
||||
@@ -42,11 +44,11 @@ module DecisionTree
|
||||
end
|
||||
|
||||
def train(data=@data, attributes=@attributes, default=@default)
|
||||
attributes = attributes.map {|e| e.to_s}
|
||||
attributes = attributes.map { |e| e.to_s}
|
||||
initialize(attributes, data, default, @type)
|
||||
|
||||
# Remove samples with same attributes leaving most common classification
|
||||
data2 = data.inject({}) {|hash, d| hash[d.slice(0..-2)] ||= Hash.new(0); hash[d.slice(0..-2)][d.last] += 1; hash }.map{|key,val| key + [val.sort_by{ |k, v| v }.last.first]}
|
||||
data2 = data.inject({}) { |hash, d| hash[d.slice(0..-2)] ||= Hash.new(0); hash[d.slice(0..-2)][d.last] += 1; hash }.map{ |key,val| key + [val.sort_by{ |k, v| v }.last.first]}
|
||||
|
||||
@tree = id3_train(data2, attributes, default)
|
||||
end
|
||||
@@ -57,8 +59,8 @@ module DecisionTree
|
||||
|
||||
def fitness_for(attribute)
|
||||
case type(attribute)
|
||||
when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
|
||||
when :continuous; fitness = proc{|a,b,c| id3_continuous(a,b,c)}
|
||||
when :discrete; fitness = proc { |a,b,c| id3_discrete(a,b,c) }
|
||||
when :continuous; fitness = proc { |a,b,c| id3_continuous(a,b,c) }
|
||||
end
|
||||
end
|
||||
|
||||
@@ -122,7 +124,7 @@ module DecisionTree
|
||||
def id3_discrete(data, attributes, attribute)
|
||||
values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
|
||||
partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } }
|
||||
remainder = partitions.collect {|p| (p.size.to_f / data.size) * p.classification.entropy}.inject(0) {|i,s| s+=i }
|
||||
remainder = partitions.collect { |p| (p.size.to_f / data.size) * p.classification.entropy}.inject(0) { |i,s| s+=i }
|
||||
|
||||
[data.classification.entropy - remainder, attributes.index(attribute)]
|
||||
end
|
||||
@@ -281,7 +283,7 @@ module DecisionTree
|
||||
end
|
||||
end
|
||||
end
|
||||
@rules = @rules.sort_by{|r| -r.accuracy(data)}
|
||||
@rules = @rules.sort_by{ |r| -r.accuracy(data) }
|
||||
end
|
||||
|
||||
def to_s
|
||||
@@ -320,7 +322,7 @@ module DecisionTree
|
||||
predictions[p] += accuracy unless p.nil?
|
||||
end
|
||||
return @default, 0.0 if predictions.empty?
|
||||
winner = predictions.sort_by {|k,v| -v}.first
|
||||
winner = predictions.sort_by { |k,v| -v}.first
|
||||
return winner[0], winner[1].to_f / @classifiers.size.to_f
|
||||
end
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user