Skip to content

Commit

Permalink
Merge pull request #101 from sanity/TreeBuilderR
Browse files Browse the repository at this point in the history
Tree builder r
  • Loading branch information
athawk81 committed Mar 8, 2015
2 parents da04812 + 4539118 commit cf5bb08
Show file tree
Hide file tree
Showing 26 changed files with 483 additions and 188 deletions.
4 changes: 3 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
4) *ANY* change to the master branch (ie. when a feature branch is merged) must
be accompanied by a bump in version number, regardless of how minor the change.
-->
<version>0.6.1</version>

<version>0.7.0</version>


<repositories>
<repository>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import quickml.supervised.classifier.decisionTree.TreeBuilder;
import quickml.supervised.classifier.decisionTree.scorers.GiniImpurityScorer;
import quickml.supervised.classifier.decisionTree.scorers.InformationGainScorer;
import quickml.supervised.classifier.decisionTree.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability;
import quickml.supervised.classifier.downsampling.DownsamplingClassifier;
import quickml.supervised.classifier.downsampling.DownsamplingClassifierBuilder;
import quickml.supervised.classifier.randomForest.RandomForest;
Expand Down Expand Up @@ -64,7 +65,7 @@ public static Pair<Map<String, Object>, DownsamplingClassifier> getOptimizedDown
.iterations(3).build();
Map<String, Object> bestParams = optimizer.determineOptimalConfig();

RandomForestBuilder<ClassifierInstance> randomForestBuilder = new RandomForestBuilder<>(new TreeBuilder<>().ignoreAttributeAtNodeProbability(0.7)).numTrees(24);
RandomForestBuilder<ClassifierInstance> randomForestBuilder = new RandomForestBuilder<>(new TreeBuilder<>().attributeIgnoringStrategy(new IgnoreAttributesWithConstantProbability(0.7))).numTrees(24);
DownsamplingClassifierBuilder<ClassifierInstance> downsamplingClassifierBuilder = new DownsamplingClassifierBuilder<>(randomForestBuilder,0.1);
downsamplingClassifierBuilder.updateBuilderConfig(bestParams);

Expand All @@ -91,7 +92,7 @@ private static int getTimeSliceHours(List<ClassifierInstance> trainingData, int
private static Map<String, FieldValueRecommender> createConfig() {
Map<String, FieldValueRecommender> config = Maps.newHashMap();
config.put(MAX_DEPTH, new FixedOrderRecommender(4, 8, 16));//Integer.MAX_VALUE, 2, 3, 5, 6, 9));
config.put(MIN_CAT_ATTR_OCC, new FixedOrderRecommender(7, 14));
config.put(MIN_OCCURRENCES_OF_ATTRIBUTE_VALUE, new FixedOrderRecommender(7, 14));
config.put(MIN_LEAF_INSTANCES, new FixedOrderRecommender(0, 15));
config.put(DownsamplingClassifierBuilder.MINORITY_INSTANCE_PROPORTION, new FixedOrderRecommender(.1, .2));
config.put(DEGREE_OF_GAIN_RATIO_PENALTY, new FixedOrderRecommender(1.0, 0.75));
Expand Down
20 changes: 10 additions & 10 deletions src/main/java/quickml/supervised/classifier/decisionTree/Tree.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
*/
public class Tree extends AbstractClassifier {
static final long serialVersionUID = 56394564395635672L;
public final Node node;
public final Node root;
private Set<Serializable> classifications = new HashSet<>();

protected Tree(Node tree, Set<Serializable> classifications) {
this.node = tree;
protected Tree(Node root, Set<Serializable> classifications) {
this.root = root;
this.classifications = classifications;
}

Expand All @@ -35,19 +35,19 @@ public Set<Serializable> getClassifications() {

@Override
public double getProbability(AttributesMap attributes, Serializable classification) {
Leaf leaf = node.getLeaf(attributes);
Leaf leaf = root.getLeaf(attributes);
return leaf.getProbability(classification);
}

@Override
public double getProbabilityWithoutAttributes(AttributesMap attributes, Serializable classification, Set<String> attributesToIgnore) {
return node.getProbabilityWithoutAttributes(attributes, classification, attributesToIgnore);
return root.getProbabilityWithoutAttributes(attributes, classification, attributesToIgnore);
}


@Override
public PredictionMap predict(AttributesMap attributes) {
Leaf leaf = node.getLeaf(attributes);
Leaf leaf = root.getLeaf(attributes);
Map<Serializable, Double> probsByClassification = Maps.newHashMap();
for (Serializable classification : leaf.getClassifications()) {
probsByClassification.put(classification, leaf.getProbability(classification));
Expand All @@ -66,7 +66,7 @@ public PredictionMap predictWithoutAttributes(AttributesMap attributes, Set<Stri

@Override
public Serializable getClassificationByMaxProb(AttributesMap attributes) {
Leaf leaf = node.getLeaf(attributes);
Leaf leaf = root.getLeaf(attributes);
return leaf.getBestClassification();
}

Expand All @@ -77,20 +77,20 @@ public boolean equals(final Object o) {

final Tree tree = (Tree) o;

if (!node.equals(tree.node)) return false;
if (!root.equals(tree.root)) return false;

return true;
}

@Override
public int hashCode() {
return node.hashCode();
return root.hashCode();
}

@Override
public String toString() {
StringBuilder dump = new StringBuilder();
node.dump(dump);
root.dump(dump);
return dump.toString();
}
}
Loading

0 comments on commit cf5bb08

Please sign in to comment.