mirror of
https://github.com/dkam/decisiontree.git
synced 2025-12-28 15:14:52 +00:00
Merge pull request #8 from rustyio/fix-infinite-recursion
Fix infinite recursion.
This commit is contained in:
@@ -69,9 +69,14 @@ module DecisionTree
|
|||||||
# return classification if all examples have the same classification
|
# return classification if all examples have the same classification
|
||||||
return data.first.last if data.classification.uniq.size == 1
|
return data.first.last if data.classification.uniq.size == 1
|
||||||
|
|
||||||
# Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
|
# Choose best attribute:
|
||||||
|
# 1. enumerate all attributes
|
||||||
|
# 2. Pick best attribute
|
||||||
|
# 3. If attributes all score the same, then pick a random one to avoid infinite recursion.
|
||||||
performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
|
performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
|
||||||
max = performance.max { |a,b| a[0] <=> b[0] }
|
max = performance.max { |a,b| a[0] <=> b[0] }
|
||||||
|
min = performance.min { |a,b| a[0] <=> b[0] }
|
||||||
|
max = performance.shuffle.first if max[0] == min[0]
|
||||||
best = Node.new(attributes[performance.index(max)], max[1], max[0])
|
best = Node.new(attributes[performance.index(max)], max[1], max[0])
|
||||||
best.threshold = nil if @type == :discrete
|
best.threshold = nil if @type == :discrete
|
||||||
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
|
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
|
||||||
|
|||||||
@@ -74,4 +74,19 @@ describe describe DecisionTree::ID3Tree do
|
|||||||
Then { tree.predict([2, "blue"]).should == "not angry" }
|
Then { tree.predict([2, "blue"]).should == "not angry" }
|
||||||
end
|
end
|
||||||
|
|
||||||
|
describe "infinite recursion case" do
|
||||||
|
Given(:labels) { [:a, :b, :c] }
|
||||||
|
Given(:data) do
|
||||||
|
[
|
||||||
|
["a1", "b0", "c0", "RED"],
|
||||||
|
["a1", "b1", "c1", "RED"],
|
||||||
|
["a1", "b1", "c0", "BLUE"],
|
||||||
|
["a1", "b0", "c1", "BLUE"]
|
||||||
|
]
|
||||||
|
end
|
||||||
|
Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "RED", :discrete) }
|
||||||
|
When { tree.train }
|
||||||
|
Then { tree.predict(["a1","b0","c0"]).should == "RED" }
|
||||||
|
end
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|||||||
Reference in New Issue
Block a user