diff --git a/src/main/java/quickml/supervised/classifier/decisionTree/TreeBuilder.java b/src/main/java/quickml/supervised/classifier/decisionTree/TreeBuilder.java index 8df12f79..f0b38802 100644 --- a/src/main/java/quickml/supervised/classifier/decisionTree/TreeBuilder.java +++ b/src/main/java/quickml/supervised/classifier/decisionTree/TreeBuilder.java @@ -43,12 +43,14 @@ public final class TreeBuilder implements Predicti private static final int HARD_MINIMUM_INSTANCES_PER_CATEGORICAL_VALUE = 10; public static final String MIN_SPLIT_FRACTION = "minSplitFraction"; public static final String EXEMPT_ATTRIBUTES = "exemptAttributes"; + public static final String IMBALANCE_PENALTY_POWER = "impbalancePenaltyPower"; private Scorer scorer; private int maxDepth = 5; private double minimumScore = 0.00000000000001; private int minDiscreteAttributeValueOccurances = 0; private double minSplitFraction = .005; + private double imbalancePenaltyPower = 0; private Set exemptAttributes = Sets.newHashSet(); private int minLeafInstances = 0; @@ -77,6 +79,10 @@ public TreeBuilder attributeIgnoringStrategy(AttributeIgnoringStrategy attribute this.attributeIgnoringStrategy = attributeIgnoringStrategy; return this; } + public TreeBuilder imbalancePenaltyPower(double imbalancePenaltyPower ) { + this.imbalancePenaltyPower = imbalancePenaltyPower; + return this; + } public TreeBuilder exemptAttributes(Set exemptAttributes) { this.exemptAttributes = exemptAttributes; @@ -108,6 +114,7 @@ public TreeBuilder copy() { copy.fractionOfDataToUseInHoldOutSet = fractionOfDataToUseInHoldOutSet; copy.minSplitFraction = minSplitFraction; copy.exemptAttributes = exemptAttributes; + copy.imbalancePenaltyPower = imbalancePenaltyPower; return copy; } @@ -137,8 +144,9 @@ public void updateBuilderConfig(final Map cfg) { if (cfg.containsKey(ATTRIBUTE_IGNORING_STRATEGY)) attributeIgnoringStrategy((AttributeIgnoringStrategy) cfg.get(ATTRIBUTE_IGNORING_STRATEGY)); if (cfg.containsKey(IGNORE_ATTR_PROB)) - ignoreAttributeAtNodeProbability((Double)cfg.get(IGNORE_ATTR_PROB)); - + ignoreAttributeAtNodeProbability((Double) cfg.get(IGNORE_ATTR_PROB)); + if (cfg.containsKey(IMBALANCE_PENALTY_POWER)) + imbalancePenaltyPower((Double)cfg.get(IMBALANCE_PENALTY_POWER)); penalizeCategoricalSplitsBySplitAttributeIntrinsicValue(cfg.containsKey(PENALIZE_CATEGORICAL_SPLITS) ? (Boolean) cfg.get(PENALIZE_CATEGORICAL_SPLITS) : true); } @@ -434,7 +442,7 @@ private Pair createTwoClassCategoricalNode(Node parent double bestScore = 0; final Pair> valueOutcomeCountsPairs = - ClassificationCounter.getSortedListOfAttributeValuesWithClassificationCounters(instances, attribute, minorityClassification); //returs a list of ClassificationCounterList + ClassificationCounter.getSortedListOfAttributeValuesWithClassificationCounters(instances, attribute, minorityClassification); //returns a list of ClassificationCounterList ClassificationCounter outCounts = new ClassificationCounter(valueOutcomeCountsPairs.getValue0()); //classification counter treating all values the same ClassificationCounter inCounts = new ClassificationCounter(); //the histogram of counts by classification for the in-set @@ -474,8 +482,11 @@ private Pair createTwoClassCategoricalNode(Node parent double thisScore = scorer.scoreSplit(inCounts, outCounts); valuesInTheInset++; if (penalizeCategoricalSplitsBySplitAttributeIntrinsicValue) { - thisScore = thisScore * (1 - degreeOfGainRatioPenalty) + degreeOfGainRatioPenalty * (thisScore / intrinsicValueOfAttribute); } - + thisScore = thisScore * (1 - degreeOfGainRatioPenalty) + degreeOfGainRatioPenalty * (thisScore / intrinsicValueOfAttribute); + } + if (imbalancePenaltyPower!=0) { + thisScore/=Math.pow(Math.min(inCounts.getTotal(), outCounts.getTotal()), imbalancePenaltyPower); + } if (thisScore > bestScore) { bestScore = thisScore; lastValOfInset = valueWithClassificationCounter.attributeValue; @@ -712,6 +723,11 @@ private Pair createNumericBranch(Node parent, final St double thisScore = scorer.scoreSplit(inClassificationCounts, outClassificationCounts); + + if (imbalancePenaltyPower!=0) { + thisScore/=Math.pow(Math.min(inClassificationCounts.getTotal(), outClassificationCounts.getTotal()), imbalancePenaltyPower); + } + if (thisScore > bestScore) { bestScore = thisScore; bestThreshold = threshold; diff --git a/src/test/java/quickml/supervised/predictiveModelOptimizer/PredictiveModelOptimizerIntegrationTest.java b/src/test/java/quickml/supervised/predictiveModelOptimizer/PredictiveModelOptimizerIntegrationTest.java index 5749379e..8acaa1fb 100644 --- a/src/test/java/quickml/supervised/predictiveModelOptimizer/PredictiveModelOptimizerIntegrationTest.java +++ b/src/test/java/quickml/supervised/predictiveModelOptimizer/PredictiveModelOptimizerIntegrationTest.java @@ -35,7 +35,7 @@ public class PredictiveModelOptimizerIntegrationTest { @Before public void setUp() throws Exception { List advertisingInstances = getAdvertisingInstances(); - advertisingInstances = advertisingInstances.subList(0, 3000); + // advertisingInstances = advertisingInstances.subList(0, 3000); optimizer = new PredictiveModelOptimizerBuilder() .modelBuilder(new RandomForestBuilder<>()) .dataCycler(new OutOfTimeData<>(advertisingInstances, 0.2, 12, new OnespotDateTimeExtractor())) @@ -62,17 +62,17 @@ private Map createConfig() { CompositeAttributeIgnoringStrategy compositeAttributeIgnoringStrategy = new CompositeAttributeIgnoringStrategy(Arrays.asList( new IgnoreAttributesWithConstantProbability(0.7), new IgnoreAttributesInSet(attributesToIgnore, probabilityOfDiscardingFromAttributesToIgnore) )); - config.put(ATTRIBUTE_IGNORING_STRATEGY, new FixedOrderRecommender(new IgnoreAttributesWithConstantProbability(0.7), compositeAttributeIgnoringStrategy )); - config.put(NUM_TREES, new MonotonicConvergenceRecommender(asList(2, 4, 8), 0.02)); - config.put(MAX_DEPTH, new FixedOrderRecommender( 4, 8, 16));//Integer.MAX_VALUE, 2, 3, 5, 6, 9)); + config.put(ATTRIBUTE_IGNORING_STRATEGY, new FixedOrderRecommender(new IgnoreAttributesWithConstantProbability(0.7)));//, compositeAttributeIgnoringStrategy )); + config.put(NUM_TREES, new MonotonicConvergenceRecommender(asList(8), 0.02)); + config.put(MAX_DEPTH, new FixedOrderRecommender(8));//, 16));//Integer.MAX_VALUE, 2, 3, 5, 6, 9)); config.put(MIN_SCORE, new FixedOrderRecommender(0.00000000000001));//, Double.MIN_VALUE, 0.0, 0.000001, 0.0001, 0.001, 0.01, 0.1)); - config.put(MIN_OCCURRENCES_OF_ATTRIBUTE_VALUE, new FixedOrderRecommender(2, 11, 16, 30 )); - config.put(MIN_LEAF_INSTANCES, new FixedOrderRecommender(0, 20, 40)); - config.put(SCORER, new FixedOrderRecommender(new InformationGainScorer(), new GiniImpurityScorer())); - config.put(DEGREE_OF_GAIN_RATIO_PENALTY, new FixedOrderRecommender(1.0, 0.75, .5 )); - config.put(MIN_SPLIT_FRACTION, new FixedOrderRecommender(0.01, 0.25, .5 )); - config.put(EXEMPT_ATTRIBUTES, new FixedOrderRecommender(exemptAttributes)); - + config.put(MIN_OCCURRENCES_OF_ATTRIBUTE_VALUE, new FixedOrderRecommender(11));//;, 16, 30 )); + config.put(MIN_LEAF_INSTANCES, new FixedOrderRecommender(20));//, 40)); + config.put(SCORER, new FixedOrderRecommender(new GiniImpurityScorer()));//, new InformationGainScorer())), ; + config.put(DEGREE_OF_GAIN_RATIO_PENALTY, new FixedOrderRecommender(1.0));//, 0.75, .5 )); + config.put(MIN_SPLIT_FRACTION, new FixedOrderRecommender(0.001));// 0.25, .5 )); + // config.put(EXEMPT_ATTRIBUTES, new FixedOrderRecommender(exemptAttributes)); + config.put(IMBALANCE_PENALTY_POWER, new FixedOrderRecommender(0.0, 1.0, 2.0)); return config; }