From 4da9dcf68eb01cafda2635a7fcd9853d2a5c5195 Mon Sep 17 00:00:00 2001 From: Ilya Grigorik Date: Sat, 14 Sep 2013 11:01:27 -0700 Subject: [PATCH] convert labels to string - closes #13 --- lib/decisiontree/id3_tree.rb | 1 + spec/id3_spec.rb | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/lib/decisiontree/id3_tree.rb b/lib/decisiontree/id3_tree.rb index 8f9a7e8..04c9191 100755 --- a/lib/decisiontree/id3_tree.rb +++ b/lib/decisiontree/id3_tree.rb @@ -42,6 +42,7 @@ module DecisionTree end def train(data=@data, attributes=@attributes, default=@default) + attributes = attributes.map {|e| e.to_s} initialize(attributes, data, default, @type) # Remove samples with same attributes leaving most common classification diff --git a/spec/id3_spec.rb b/spec/id3_spec.rb index 5eb67ba..b86b802 100644 --- a/spec/id3_spec.rb +++ b/spec/id3_spec.rb @@ -89,4 +89,20 @@ describe describe DecisionTree::ID3Tree do Then { tree.predict(["a1","b0","c0"]).should == "RED" } end + describe "numerical labels case" do + Given(:labels) { [1, 2] } + Given(:data) do + [ + [1, 1, true], + [1, 2, false], + [2, 1, false], + [2, 2, true] + ] + end + Given(:tree) { DecisionTree::ID3Tree.new labels, data, nil, :discrete } + When { tree.train } + Then { + lambda { tree.predict([1, 1]) }.should_not raise_error + } + end end