From 28ac1ce5fcd53426a87cda147be11ba9e38ea4ae Mon Sep 17 00:00:00 2001 From: Igor Belyakov Date: Thu, 29 Feb 2024 14:48:16 -0800 Subject: [PATCH] IGNITE-20139 Random forest stopping criteria check fixed. Gain calculation implemented. --- .../randomforest/RandomForestTrainer.java | 6 +- .../ml/tree/randomforest/data/NodeSplit.java | 12 ++- .../data/impurity/GiniHistogram.java | 42 +++++++-- .../data/impurity/ImpurityHistogram.java | 7 +- .../impurity/ImpurityHistogramsComputer.java | 4 +- .../data/impurity/MSEHistogram.java | 18 ++-- .../RandomForestClassifierTrainerTest.java | 86 +++++++++++++++++++ .../tree/randomforest/RandomForestTest.java | 14 +-- 8 files changed, 158 insertions(+), 31 deletions(-) diff --git a/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java b/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java index 625944183..a133846e5 100644 --- a/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java +++ b/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java @@ -274,7 +274,7 @@ private void split(Queue learningQueue, Map nodesToL if (bestSplit.isPresent()) bestSplit.get().createLeaf(cornerNode); else { - cornerNode.setImpurity(Double.NEGATIVE_INFINITY); + cornerNode.setImpurity(0.0); cornerNode.toLeaf(0.0); } } @@ -383,8 +383,8 @@ private List defaultNodesToLearnSelectionStrgy(Queue queue) * @return true if split is needed. 
*/ boolean needSplit(TreeNode parentNode, Optional split) { - return split.isPresent() && parentNode.getImpurity() - split.get().getImpurity() > minImpurityDelta && - parentNode.getDepth() < (maxDepth + 1); + return split.isPresent() && split.get().getGain() > minImpurityDelta && + parentNode.getDepth() < (maxDepth + 1); } /** diff --git a/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeSplit.java b/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeSplit.java index f195c175b..898da2f6c 100644 --- a/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeSplit.java +++ b/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeSplit.java @@ -36,6 +36,9 @@ public class NodeSplit implements Serializable { /** Impurity at this split point. */ private double impurity; + /** Gain at this split point. */ + private double gain; + /** */ public NodeSplit() { } @@ -45,11 +48,13 @@ public NodeSplit() { * * @param featureId Feature id. * @param val Feature split value. + * @param gain Gain value. * @param impurity Impurity value. 
*/ - public NodeSplit(int featureId, double val, double impurity) { + public NodeSplit(int featureId, double val, double gain, double impurity) { this.featureId = featureId; this.val = val; + this.gain = gain; this.impurity = impurity; } @@ -80,6 +85,11 @@ public double getImpurity() { return impurity; } + /** */ + public double getGain() { + return gain; + } + /** */ public double getVal() { return val; diff --git a/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/GiniHistogram.java b/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/GiniHistogram.java index 864f3e8f9..20817b1f1 100644 --- a/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/GiniHistogram.java +++ b/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/GiniHistogram.java @@ -18,6 +18,7 @@ package org.apache.ignite.ml.tree.randomforest.data.impurity; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -33,6 +34,8 @@ import org.apache.ignite.ml.tree.randomforest.data.NodeSplit; import org.apache.ignite.ml.tree.randomforest.data.impurity.basic.CountersHistogram; +import static org.apache.commons.math3.util.Precision.EPSILON; + /** * Class contains implementation of splitting point finding algorithm based on Gini metric (see * https://en.wikipedia.org/wiki/Gini_coefficient) and represents a set of histograms in according to this metric. 
@@ -104,6 +107,8 @@ public GiniHistogram(int sampleId, Map lblMapping, BucketMeta b double bestImpurity = Double.POSITIVE_INFINITY; double bestSplitVal = Double.NEGATIVE_INFINITY; int bestBucketId = -1; + double bestGain = 0; + double nodeImpurity = 0; List> countersDistribPerCls = hists.stream() .map(ObjectHistogram::computeDistributionFunction) @@ -113,6 +118,16 @@ public GiniHistogram(int sampleId, Map lblMapping, BucketMeta b .mapToDouble(x -> x.isEmpty() ? 0.0 : x.lastEntry().getValue()) .toArray(); + double totalSampleCnt = Arrays.stream(totalSampleCntPerLb).sum(); + + for (int lbId = 0; lbId < lblMapping.size(); lbId++) { + double lblProbability = totalSampleCntPerLb[lbId] / totalSampleCnt; + nodeImpurity += (lblProbability * (1 - lblProbability)); + } + + if (nodeImpurity < EPSILON) + return Optional.empty(); + Map lastLeftValues = new HashMap<>(); for (int i = 0; i < lblMapping.size(); i++) lastLeftValues.put(i, 0.0); @@ -140,24 +155,35 @@ public GiniHistogram(int sampleId, Map lblMapping, BucketMeta b //count of samples with label [corresponding lblId] to the left of bucket Double toLeftCnt = countersDistribPerCls.get(lbId).getOrDefault(bucketId, lastLeftValues.get(lbId)); - if (toLeftCnt > 0) - leftImpurity += Math.pow(toLeftCnt, 2) / totalToleftCnt; + if (toLeftCnt > 0) { + double lblLeftProbability = toLeftCnt / totalToleftCnt; + leftImpurity += (lblLeftProbability * (1 - lblLeftProbability)); + } //number of samples to the right of bucket = total samples count - toLeftCnt double toRightCnt = totalSampleCntPerLb[lbId] - toLeftCnt; - if (toRightCnt > 0) - rightImpurity += (Math.pow(toRightCnt, 2)) / totalToRightCnt; + + if (toRightCnt > 0) { + double lblRightProbability = toRightCnt / totalToRightCnt; + rightImpurity += (lblRightProbability * (1 - lblRightProbability)); + } } - double impurityInBucket = -(leftImpurity + rightImpurity); - if (impurityInBucket <= bestImpurity) { - bestImpurity = impurityInBucket; + double leftWeight = totalToleftCnt / 
(totalToleftCnt + totalToRightCnt); + double rightWeight = totalToRightCnt / (totalToleftCnt + totalToRightCnt); + + double weightedSplitImpurity = leftImpurity * leftWeight + rightImpurity * rightWeight; + + double gain = nodeImpurity - weightedSplitImpurity; + + if (gain > bestGain) { bestSplitVal = bucketMeta.bucketIdToValue(bucketId); bestBucketId = bucketId; + bestGain = gain; } } - return checkAndReturnSplitValue(bestBucketId, bestSplitVal, bestImpurity); + return checkAndReturnSplitValue(bestBucketId, bestSplitVal, bestGain, nodeImpurity); } /** {@inheritDoc} */ diff --git a/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogram.java b/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogram.java index 964c3b215..37533a658 100644 --- a/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogram.java +++ b/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogram.java @@ -51,14 +51,15 @@ public ImpurityHistogram(int featureId) { * * @param bestBucketId Best bucket id. * @param bestSplitVal Best split value. - * @param bestImpurity Best impurity. + * @param bestGain Best gain. + * @param impurity Node's impurity. * @return Best split value. 
*/ - protected Optional checkAndReturnSplitValue(int bestBucketId, double bestSplitVal, double bestImpurity) { + protected Optional checkAndReturnSplitValue(int bestBucketId, double bestSplitVal, double bestGain, double impurity) { if (isLastBucket(bestBucketId)) return Optional.empty(); else - return Optional.of(new NodeSplit(featureId, bestSplitVal, bestImpurity)); + return Optional.of(new NodeSplit(featureId, bestSplitVal, bestGain, impurity)); } /** diff --git a/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java b/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java index 521b42622..baae70d14 100644 --- a/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java +++ b/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java @@ -204,8 +204,8 @@ public NodeId getNodeId() { */ public Optional findBestSplit() { return perFeatureStatistics.values().stream() - .flatMap(x -> x.findBestSplit().map(Stream::of).orElse(Stream.empty())) - .min(Comparator.comparingDouble(NodeSplit::getImpurity)); + .flatMap(x -> x.findBestSplit().map(Stream::of).orElse(Stream.empty())) + .max(Comparator.comparingDouble(NodeSplit::getGain)); } } } diff --git a/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/MSEHistogram.java b/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/MSEHistogram.java index b10f8eeb1..5dc5e8bdd 100644 --- a/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/MSEHistogram.java +++ b/modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/MSEHistogram.java @@ -97,9 +97,9 @@ public MSEHistogram(int sampleId, BucketMeta bucketMeta) { /** {@inheritDoc} */ @Override public Optional 
findBestSplit() { - double bestImpurity = Double.POSITIVE_INFINITY; double bestSplitVal = Double.NEGATIVE_INFINITY; int bestBucketId = -1; + double bestGain = 0; //counter corresponds to number of samples //ys corresponds to sumOfLabels @@ -112,6 +112,8 @@ public MSEHistogram(int sampleId, BucketMeta bucketMeta) { double ysMax = ysDistrib.lastEntry().getValue(); double y2sMax = y2sDistrib.lastEntry().getValue(); + double nodeImpurity = impurity(cntrMax, ysMax, y2sMax); + double lastLeftCntrVal = 0.0; double lastLeftYVal = 0.0; double lastLeftY2Val = 0.0; @@ -127,21 +129,23 @@ public MSEHistogram(int sampleId, BucketMeta bucketMeta) { double rightY = ysMax - leftY; double rightY2 = y2sMax - leftY2; - double impurity = 0.0; + double childrenImpurity = 0.0; if (leftCnt > 0) - impurity += impurity(leftCnt, leftY, leftY2); + childrenImpurity += impurity(leftCnt, leftY, leftY2); if (rightCnt > 0) - impurity += impurity(rightCnt, rightY, rightY2); + childrenImpurity += impurity(rightCnt, rightY, rightY2); + + double gain = nodeImpurity - childrenImpurity; - if (impurity < bestImpurity) { - bestImpurity = impurity; + if (gain > bestGain) { + bestGain = gain; bestSplitVal = bucketMeta.bucketIdToValue(bucketId); bestBucketId = bucketId; } } - return checkAndReturnSplitValue(bestBucketId, bestSplitVal, bestImpurity); + return checkAndReturnSplitValue(bestBucketId, bestSplitVal, bestGain, nodeImpurity); } /** diff --git a/modules/ml-ext/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java b/modules/ml-ext/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java index 4d2ef46fd..ba9dc20c9 100644 --- a/modules/ml-ext/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java +++ b/modules/ml-ext/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java @@ -19,9 +19,13 @@ import java.util.ArrayList; import java.util.HashMap; 
+import java.util.List; import java.util.Map; +import java.util.Random; +import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; +import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator; import org.apache.ignite.ml.dataset.feature.FeatureMeta; import org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer; @@ -29,9 +33,15 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.structures.LabeledVector; import org.apache.ignite.ml.trainers.DatasetTrainer; +import org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel; +import org.apache.ignite.ml.tree.randomforest.data.TreeNode; +import org.apache.ignite.testframework.GridTestUtils; import org.junit.Test; +import static org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrategies.SQRT; +import static org.apache.ignite.ml.tree.randomforest.data.TreeNode.Type.LEAF; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; /** @@ -97,4 +107,80 @@ public void testUpdate() { assertEquals(originalMdl.predict(v), updatedOnSameDS.predict(v), 0.01); assertEquals(originalMdl.predict(v), updatedOnEmptyDS.predict(v), 0.01); } + + /** + * Test checks whether the tree has nodes duplication. 
+ */ + @Test + public void testDuplicateNodes() { + int sampleSize = 500; + Random rnd = new Random(1); + Map> sample = new HashMap<>(); + for (int i = 0; i < sampleSize; i++) { + sample.put(i, VectorUtils.of(rnd.nextDouble(), rnd.nextDouble(), rnd.nextDouble(), rnd.nextDouble()) + .labeled((double) i % 2)); + } + + ArrayList meta = new ArrayList<>(); + for (int i = 0; i < 4; i++) + meta.add(new FeatureMeta("", i, false)); + DatasetTrainer trainer = new RandomForestClassifierTrainer(meta) + .withAmountOfTrees(1) + .withMaxDepth(10) + .withFeaturesCountSelectionStrgy(SQRT) + .withSeed(777) + .withEnvironmentBuilder(TestUtils.testEnvBuilder()); + + ModelsComposition mdl = trainer.fit(sample, parts, new LabeledDummyVectorizer<>()); + + List> models = mdl.getModels(); + + assertEquals(1, mdl.getModels().size()); + + RandomForestTreeModel tree = (RandomForestTreeModel) models.get(0); + + TreeNode repeatingNode = findDuplicatedNode(tree.getRootNode()); + + assertNull(repeatingNode); + } + + /** + * Go through the tree and find a node branch that has repeating feature + value. 
+ */ + private static TreeNode findDuplicatedNode(TreeNode node) { + if (node.getType() == LEAF) { + return null; + } + + TreeNode left = node.getLeft(); + if (getFeatureId(node) == getFeatureId(left) && getVal(node) == getVal(left)) { + return left; + } + + TreeNode inLeftBranch = findDuplicatedNode(left); + if (inLeftBranch != null) { + return inLeftBranch; + } + + TreeNode right = node.getRight(); + if (getFeatureId(node) == getFeatureId(right) && getVal(node) == getVal(right)) { + return right; + } + + return findDuplicatedNode(right); + } + + /** + * Get node's value + */ + private static double getVal(TreeNode node) { + return GridTestUtils.getFieldValue(node, TreeNode.class, "val"); + } + + /** + * Get node's feature id + */ + private static int getFeatureId(TreeNode node) { + return GridTestUtils.getFieldValue(node, TreeNode.class, "featureId"); + } } diff --git a/modules/ml-ext/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestTest.java b/modules/ml-ext/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestTest.java index eb81b36a7..76a766586 100644 --- a/modules/ml-ext/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestTest.java +++ b/modules/ml-ext/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestTest.java @@ -37,7 +37,7 @@ public class RandomForestTest { private final int cntOfTrees = 10; /** Min imp delta. */ - private final double minImpDelta = 1.0; + private final double minImpDelta = 0.1; /** Max depth. 
*/ private final int maxDepth = 1; @@ -66,13 +66,13 @@ public class RandomForestTest { @Test public void testNeedSplit() { TreeNode node = new TreeNode(1, 1); - node.setImpurity(1000); - assertTrue(rf.needSplit(node, Optional.of(new NodeSplit(0, 0, node.getImpurity() - minImpDelta * 1.01)))); - assertFalse(rf.needSplit(node, Optional.of(new NodeSplit(0, 0, node.getImpurity() - minImpDelta * 0.5)))); - assertFalse(rf.needSplit(node, Optional.of(new NodeSplit(0, 0, node.getImpurity())))); + node.setImpurity(1.0); + assertTrue(rf.needSplit(node, Optional.of(new NodeSplit(0, 0, minImpDelta * 1.01, node.getImpurity())))); + assertFalse(rf.needSplit(node, Optional.of(new NodeSplit(0, 0, minImpDelta * 0.5, node.getImpurity())))); + assertFalse(rf.needSplit(node, Optional.of(new NodeSplit(0, 0, 0, node.getImpurity())))); TreeNode child = node.toConditional(0, 0).get(0); - child.setImpurity(1000); - assertFalse(rf.needSplit(child, Optional.of(new NodeSplit(0, 0, child.getImpurity() - minImpDelta * 1.01)))); + child.setImpurity(1.0); + assertFalse(rf.needSplit(child, Optional.of(new NodeSplit(0, 0, minImpDelta * 1.01, child.getImpurity())))); } }