Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ private void split(Queue<TreeNode> learningQueue, Map<NodeId, TreeNode> nodesToL
if (bestSplit.isPresent())
bestSplit.get().createLeaf(cornerNode);
else {
cornerNode.setImpurity(Double.NEGATIVE_INFINITY);
cornerNode.setImpurity(0.0);
cornerNode.toLeaf(0.0);
}
}
Expand Down Expand Up @@ -383,8 +383,8 @@ private List<TreeNode> defaultNodesToLearnSelectionStrgy(Queue<TreeNode> queue)
* @return true if split is needed.
*/
boolean needSplit(TreeNode parentNode, Optional<NodeSplit> split) {
return split.isPresent() && parentNode.getImpurity() - split.get().getImpurity() > minImpurityDelta &&
parentNode.getDepth() < (maxDepth + 1);
return split.isPresent() && split.get().getGain() > minImpurityDelta &&
parentNode.getDepth() < (maxDepth + 1);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ public class NodeSplit implements Serializable {
/** Impurity at this split point. */
private double impurity;

/** Gain at this split point. */
private double gain;

/**
 * Default no-arg constructor.
 * NOTE(review): presumably required for (de)serialization — the class implements
 * {@code Serializable}; confirm no other callers rely on it.
 */
public NodeSplit() {
}
Expand All @@ -45,11 +48,13 @@ public NodeSplit() {
*
* @param featureId Feature id.
* @param val Feature split value.
* @param gain Gain value.
* @param impurity Impurity value.
*/
public NodeSplit(int featureId, double val, double impurity) {
public NodeSplit(int featureId, double val, double gain, double impurity) {
this.featureId = featureId;
this.val = val;
this.gain = gain;
this.impurity = impurity;
}

Expand Down Expand Up @@ -80,6 +85,11 @@ public double getImpurity() {
return impurity;
}

/**
 * Gets the impurity gain of this split point.
 *
 * @return Gain value at this split point.
 */
public double getGain() {
    return gain;
}

/** */
public double getVal() {
return val;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.ignite.ml.tree.randomforest.data.impurity;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand All @@ -33,6 +34,8 @@
import org.apache.ignite.ml.tree.randomforest.data.NodeSplit;
import org.apache.ignite.ml.tree.randomforest.data.impurity.basic.CountersHistogram;

import static org.apache.commons.math3.util.Precision.EPSILON;

/**
* Class contains implementation of splitting point finding algorithm based on Gini metric (see
* https://en.wikipedia.org/wiki/Gini_coefficient) and represents a set of histograms in according to this metric.
Expand Down Expand Up @@ -104,6 +107,8 @@ public GiniHistogram(int sampleId, Map<Double, Integer> lblMapping, BucketMeta b
double bestImpurity = Double.POSITIVE_INFINITY;
double bestSplitVal = Double.NEGATIVE_INFINITY;
int bestBucketId = -1;
double bestGain = 0;
double nodeImpurity = 0;

List<TreeMap<Integer, Double>> countersDistribPerCls = hists.stream()
.map(ObjectHistogram::computeDistributionFunction)
Expand All @@ -113,6 +118,16 @@ public GiniHistogram(int sampleId, Map<Double, Integer> lblMapping, BucketMeta b
.mapToDouble(x -> x.isEmpty() ? 0.0 : x.lastEntry().getValue())
.toArray();

double totalSampleCnt = Arrays.stream(totalSampleCntPerLb).sum();

for (int lbId = 0; lbId < lblMapping.size(); lbId++) {
double lblProbability = totalSampleCntPerLb[lbId] / totalSampleCnt;
nodeImpurity += (lblProbability * (1 - lblProbability));
}

if (nodeImpurity < EPSILON)
return Optional.empty();

Map<Integer, Double> lastLeftValues = new HashMap<>();
for (int i = 0; i < lblMapping.size(); i++)
lastLeftValues.put(i, 0.0);
Expand Down Expand Up @@ -140,24 +155,35 @@ public GiniHistogram(int sampleId, Map<Double, Integer> lblMapping, BucketMeta b
//count of samples with label [corresponding lblId] to the left of bucket
Double toLeftCnt = countersDistribPerCls.get(lbId).getOrDefault(bucketId, lastLeftValues.get(lbId));

if (toLeftCnt > 0)
leftImpurity += Math.pow(toLeftCnt, 2) / totalToleftCnt;
if (toLeftCnt > 0) {
double lblLeftProbability = toLeftCnt / totalToleftCnt;
leftImpurity += (lblLeftProbability * (1 - lblLeftProbability));
}

//number of samples to the right of bucket = total samples count - toLeftCnt
double toRightCnt = totalSampleCntPerLb[lbId] - toLeftCnt;
if (toRightCnt > 0)
rightImpurity += (Math.pow(toRightCnt, 2)) / totalToRightCnt;

if (toRightCnt > 0) {
double lblRightProbability = toRightCnt / totalToRightCnt;
rightImpurity += (lblRightProbability * (1 - lblRightProbability));
}
}

double impurityInBucket = -(leftImpurity + rightImpurity);
if (impurityInBucket <= bestImpurity) {
bestImpurity = impurityInBucket;
double leftWeight = totalToleftCnt / (totalToleftCnt + totalToRightCnt);
double rightWeight = totalToRightCnt / (totalToleftCnt + totalToRightCnt);

double weightedSplitImpurity = leftImpurity * leftWeight + rightImpurity * rightWeight;

double gain = nodeImpurity - weightedSplitImpurity;

if (gain > bestGain) {
bestSplitVal = bucketMeta.bucketIdToValue(bucketId);
bestBucketId = bucketId;
bestGain = gain;
}
}

return checkAndReturnSplitValue(bestBucketId, bestSplitVal, bestImpurity);
return checkAndReturnSplitValue(bestBucketId, bestSplitVal, bestGain, bestImpurity);
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,15 @@ public ImpurityHistogram(int featureId) {
*
* @param bestBucketId Best bucket id.
* @param bestSplitVal Best split value.
* @param bestImpurity Best impurity.
* @param bestGain Best gain.
* @param impurity Node's impurity.
* @return Best split value.
*/
protected Optional<NodeSplit> checkAndReturnSplitValue(int bestBucketId, double bestSplitVal, double bestImpurity) {
protected Optional<NodeSplit> checkAndReturnSplitValue(int bestBucketId, double bestSplitVal, double bestGain, double impurity) {
if (isLastBucket(bestBucketId))
return Optional.empty();
else
return Optional.of(new NodeSplit(featureId, bestSplitVal, bestImpurity));
return Optional.of(new NodeSplit(featureId, bestSplitVal, bestGain, impurity));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ public NodeId getNodeId() {
*/
public Optional<NodeSplit> findBestSplit() {
return perFeatureStatistics.values().stream()
.flatMap(x -> x.findBestSplit().map(Stream::of).orElse(Stream.empty()))
.min(Comparator.comparingDouble(NodeSplit::getImpurity));
.flatMap(x -> x.findBestSplit().map(Stream::of).orElse(Stream.empty()))
.max(Comparator.comparingDouble(NodeSplit::getGain));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ public MSEHistogram(int sampleId, BucketMeta bucketMeta) {

/** {@inheritDoc} */
@Override public Optional<NodeSplit> findBestSplit() {
double bestImpurity = Double.POSITIVE_INFINITY;
double bestSplitVal = Double.NEGATIVE_INFINITY;
int bestBucketId = -1;
double bestGain = 0;

//counter corresponds to number of samples
//ys corresponds to sumOfLabels
Expand All @@ -112,6 +112,8 @@ public MSEHistogram(int sampleId, BucketMeta bucketMeta) {
double ysMax = ysDistrib.lastEntry().getValue();
double y2sMax = y2sDistrib.lastEntry().getValue();

double nodeImpurity = impurity(cntrMax, ysMax, y2sMax);

double lastLeftCntrVal = 0.0;
double lastLeftYVal = 0.0;
double lastLeftY2Val = 0.0;
Expand All @@ -127,21 +129,23 @@ public MSEHistogram(int sampleId, BucketMeta bucketMeta) {
double rightY = ysMax - leftY;
double rightY2 = y2sMax - leftY2;

double impurity = 0.0;
double childrenImpurity = 0.0;

if (leftCnt > 0)
impurity += impurity(leftCnt, leftY, leftY2);
childrenImpurity += impurity(leftCnt, leftY, leftY2);
if (rightCnt > 0)
impurity += impurity(rightCnt, rightY, rightY2);
childrenImpurity += impurity(rightCnt, rightY, rightY2);

double gain = nodeImpurity - childrenImpurity;

if (impurity < bestImpurity) {
bestImpurity = impurity;
if (gain > bestGain) {
bestGain = gain;
bestSplitVal = bucketMeta.bucketIdToValue(bucketId);
bestBucketId = bucketId;
}
}

return checkAndReturnSplitValue(bestBucketId, bestSplitVal, bestImpurity);
return checkAndReturnSplitValue(bestBucketId, bestSplitVal, bestGain, nodeImpurity);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,29 @@

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.TestUtils;
import org.apache.ignite.ml.common.TrainerTest;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
import org.apache.ignite.ml.dataset.feature.FeatureMeta;
import org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel;
import org.apache.ignite.ml.tree.randomforest.data.TreeNode;
import org.apache.ignite.testframework.GridTestUtils;
import org.junit.Test;

import static org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrategies.SQRT;
import static org.apache.ignite.ml.tree.randomforest.data.TreeNode.Type.LEAF;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;

/**
Expand Down Expand Up @@ -97,4 +107,80 @@ public void testUpdate() {
assertEquals(originalMdl.predict(v), updatedOnSameDS.predict(v), 0.01);
assertEquals(originalMdl.predict(v), updatedOnEmptyDS.predict(v), 0.01);
}

/**
 * Test checks whether the tree has nodes duplication.
 *
 * <p>Trains a single tree (depth 10) on 4 random features with alternating 0/1 labels,
 * then walks the resulting tree asserting that no child node repeats its parent's
 * (featureId, split value) pair.
 */
@Test
public void testDuplicateNodes() {
    int sampleSize = 500;
    Random rnd = new Random(1); // fixed seed for reproducibility
    Map<Integer, LabeledVector<Double>> sample = new HashMap<>();
    for (int i = 0; i < sampleSize; i++) {
        // Four random features; label alternates 0/1 with the sample index.
        sample.put(i, VectorUtils.of(rnd.nextDouble(), rnd.nextDouble(), rnd.nextDouble(), rnd.nextDouble())
            .labeled((double) i % 2));
    }

    ArrayList<FeatureMeta> meta = new ArrayList<>();
    for (int i = 0; i < 4; i++)
        meta.add(new FeatureMeta("", i, false));
    // Single tree so the duplicate check below inspects exactly one model.
    DatasetTrainer<RandomForestModel, Double> trainer = new RandomForestClassifierTrainer(meta)
        .withAmountOfTrees(1)
        .withMaxDepth(10)
        .withFeaturesCountSelectionStrgy(SQRT)
        .withSeed(777)
        .withEnvironmentBuilder(TestUtils.testEnvBuilder());

    ModelsComposition mdl = trainer.fit(sample, parts, new LabeledDummyVectorizer<>());

    List<IgniteModel<Vector, Double>> models = mdl.getModels();

    assertEquals(1, mdl.getModels().size());

    RandomForestTreeModel tree = (RandomForestTreeModel) models.get(0);

    TreeNode repeatingNode = findDuplicatedNode(tree.getRootNode());

    assertNull(repeatingNode);
}

/**
 * Go through the tree and find a node branch that has repeating feature + value.
 *
 * <p>Recursively compares each conditional node with its direct children (left first,
 * then right); a child whose reflectively-read (featureId, val) pair equals its
 * parent's is reported as a duplicate.
 *
 * @param node Subtree root to inspect.
 * @return First duplicated child found in depth-first order, or {@code null} if none.
 */
private static TreeNode findDuplicatedNode(TreeNode node) {
    // Leaves carry no split condition, so there is nothing to duplicate.
    if (node.getType() == LEAF) {
        return null;
    }

    TreeNode left = node.getLeft();
    if (getFeatureId(node) == getFeatureId(left) && getVal(node) == getVal(left)) {
        return left;
    }

    // Descend fully into the left branch before checking the right child.
    TreeNode inLeftBranch = findDuplicatedNode(left);
    if (inLeftBranch != null) {
        return inLeftBranch;
    }

    TreeNode right = node.getRight();
    if (getFeatureId(node) == getFeatureId(right) && getVal(node) == getVal(right)) {
        return right;
    }

    return findDuplicatedNode(right);
}

/**
 * Get node's value.
 *
 * <p>Reads the private {@code val} field of {@link TreeNode} reflectively via test utils.
 *
 * @param node Tree node to read.
 * @return Split value stored in the node.
 */
private static double getVal(TreeNode node) {
    return GridTestUtils.getFieldValue(node, TreeNode.class, "val");
}

/**
 * Get node's feature id.
 *
 * <p>Reads the private {@code featureId} field of {@link TreeNode} reflectively via test utils.
 *
 * @param node Tree node to read.
 * @return Feature id stored in the node.
 */
private static int getFeatureId(TreeNode node) {
    return GridTestUtils.getFieldValue(node, TreeNode.class, "featureId");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class RandomForestTest {
private final int cntOfTrees = 10;

/** Min imp delta. */
private final double minImpDelta = 1.0;
private final double minImpDelta = 0.1;

/** Max depth. */
private final int maxDepth = 1;
Expand Down Expand Up @@ -66,13 +66,13 @@ public class RandomForestTest {
@Test
public void testNeedSplit() {
TreeNode node = new TreeNode(1, 1);
node.setImpurity(1000);
assertTrue(rf.needSplit(node, Optional.of(new NodeSplit(0, 0, node.getImpurity() - minImpDelta * 1.01))));
assertFalse(rf.needSplit(node, Optional.of(new NodeSplit(0, 0, node.getImpurity() - minImpDelta * 0.5))));
assertFalse(rf.needSplit(node, Optional.of(new NodeSplit(0, 0, node.getImpurity()))));
node.setImpurity(1.0);
assertTrue(rf.needSplit(node, Optional.of(new NodeSplit(0, 0, minImpDelta * 1.01, node.getImpurity()))));
assertFalse(rf.needSplit(node, Optional.of(new NodeSplit(0, 0, minImpDelta * 0.5, node.getImpurity()))));
assertFalse(rf.needSplit(node, Optional.of(new NodeSplit(0, 0, 0, node.getImpurity()))));

TreeNode child = node.toConditional(0, 0).get(0);
child.setImpurity(1000);
assertFalse(rf.needSplit(child, Optional.of(new NodeSplit(0, 0, child.getImpurity() - minImpDelta * 1.01))));
child.setImpurity(1.0);
assertFalse(rf.needSplit(child, Optional.of(new NodeSplit(0, 0, child.getImpurity(), minImpDelta * 1.01))));
}
}