From 15660ead3da3f498c374fa849196eeb80bdc8fb7 Mon Sep 17 00:00:00 2001 From: Erik Osheim Date: Wed, 9 Dec 2015 17:37:59 -0800 Subject: [PATCH 01/11] Make all splits into binary splits. This commit removes the multi-way split support Brushfire has. In practice splits end up being binary in most important cases, and we end up with a lot of extra unnecessary predicates and indirection. As of this commit, all the tests pass and things more or less work. There are some things that can still be cleaned up and commented better, but this believed to be a moment where things work. --- .../com/stripe/brushfire/AnnotatedTree.scala | 269 +++++++++--------- .../com/stripe/brushfire/Brushfire.scala | 14 +- .../com/stripe/brushfire/Dispatched.scala | 6 +- .../com/stripe/brushfire/Evaluators.scala | 50 ++-- .../com/stripe/brushfire/Injections.scala | 78 +++-- .../com/stripe/brushfire/Predicate.scala | 25 +- .../com/stripe/brushfire/Splitters.scala | 14 +- .../scala/com/stripe/brushfire/Tree.scala | 39 ++- .../com/stripe/brushfire/TreeTraversal.scala | 6 +- .../stripe/brushfire/JsonInjectionsSpec.scala | 3 +- .../com/stripe/brushfire/PredicateSpec.scala | 6 +- .../com/stripe/brushfire/TreeGenerators.scala | 4 +- .../stripe/brushfire/TreeTraversalSpec.scala | 28 +- .../stripe/brushfire/scalding/Trainer.scala | 26 +- 14 files changed, 271 insertions(+), 297 deletions(-) diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala index 6b82821..387ca9f 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala @@ -2,24 +2,37 @@ package com.stripe.brushfire import com.twitter.algebird._ -sealed trait Node[K, V, T, A] { +import java.lang.Math.{abs, max} + +sealed abstract class Node[K, V, T, A] { def annotation: A -} -case class SplitNode[K, V, T, A: Semigroup](children: Seq[(K, Predicate[V], Node[K, V, T, A])]) extends Node[K, V, T, A] { - require(children.nonEmpty) + def renumber(nextId: Int): (Int, Node[K, V, T, A]) = + this match { + case SplitNode(p, k, lc0, rc0, a) => + val (n1, lc) = lc0.renumber(nextId) + val (n2, rc) = rc0.renumber(n1) + (n2, SplitNode(p, k, lc, rc, a)) + case LeafNode(_, target, annotation) => + (nextId + 1, LeafNode(nextId, target, annotation)) + } +} - lazy val annotation: A = - Semigroup.sumOption(children.map(_._3.annotation)).get +case class SplitNode[K, V, T, A](predicate: Predicate[V], key: K, leftChild: Node[K, V, T, A], rightChild: Node[K, V, T, A], annotation: A) extends Node[K, V, T, A] { + def evaluate(row: Map[K, V]): List[Node[K, V, T, A]] = + predicate(row.get(key)) match { + case Some(true) => leftChild :: Nil + case Some(false) => rightChild :: Nil + case None => leftChild :: rightChild :: Nil + } +} - def findChildren(row: Map[K, V]): List[Node[K, V, T, A]] = - children.collect { case (k, p, n) if p(row.get(k)) => n }(collection.breakOut) +object SplitNode { + def apply[K, V, T, A: Semigroup](p: Predicate[V], k: K, lc: Node[K, V, T, A], rc: Node[K, V, T, A]): SplitNode[K, V, T, A] = + SplitNode(p, k, lc, rc, Semigroup.plus(lc.annotation, rc.annotation)) } -case class LeafNode[K, V, T, A]( - index: Int, - target: T, - annotation: A) extends Node[K, V, T, A] +case class LeafNode[K, V, T, A](index: Int, target: T, annotation: A) extends Node[K, V, T, A] object LeafNode { def apply[K, V, T, A: Monoid](index: Int, target: T): LeafNode[K, V, T, A] = @@ -27,51 +40,53 @@ object LeafNode { } case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { + + /** + * Transform the splits of a tree (the predicates and keys of the + * nodes) while leaving everything else alone. + */ private def mapSplits[K0, V0](f: (K, Predicate[V]) => (K0, Predicate[V0])): AnnotatedTree[K0, V0, T, A] = { - def recur(node: Node[K, V, T, A]): Node[K0, V0, T, A] = node match { - case SplitNode(children) => - SplitNode(children.map { - case (key, pred, child) => - val (key0, pred0) = f(key, pred) - (key0, pred0, recur(child)) - }) - - case LeafNode(index, target, annotation) => - LeafNode(index, target, annotation) + def recur(node: Node[K, V, T, A]): Node[K0, V0, T, A] = { + node match { + case SplitNode(p, k, lc, rc, _) => + val (key, pred) = f(k, p) + SplitNode(pred, key, recur(lc), recur(rc)) + case LeafNode(index, target, a) => + LeafNode(index, target, a) + } } - AnnotatedTree(recur(root)) } + /** + * Transform the leaves of a tree (the target and annotation) while + * leaving the structure alone. + */ private def mapLeaves[T0, A0: Semigroup](f: (T, A) => (T0, A0)): AnnotatedTree[K, V, T0, A0] = { def recur(node: Node[K, V, T, A]): Node[K, V, T0, A0] = node match { - case SplitNode(children) => - SplitNode(children.map { - case (key, pred, child) => - (key, pred, recur(child)) - }) - - case LeafNode(index, target, annotation) => - val (target1, annotation1) = f(target, annotation) - LeafNode(index, target1, annotation1) + case SplitNode(p, k, lc, rc, _) => + SplitNode(p, k, recur(lc), recur(rc)) + case LeafNode(index, target, a) => + val (target1, a1) = f(target, a) + LeafNode(index, target1, a1) } - AnnotatedTree(recur(root)) } /** - * Annotate the tree by mapping the leaf target distributions to some - * annotation for the leaves, then bubbling the annotations up the tree using - * the `Semigroup` for the annotation type. + * Annotate the tree by mapping the leaf target distributions to + * some annotation for the leaves, then bubbling the annotations up + * the tree using the `Semigroup` for the annotation type. */ def annotate[A1: Semigroup](f: T => A1): AnnotatedTree[K, V, T, A1] = mapLeaves { (t, _) => (t, f(t)) } /** * Re-annotate the leaves of this tree using `f` to transform the - * annotations. This will then bubble the annotations up the tree using the - * `Semigroup` for `A1`. If `f` is a semigroup homomorphism, then this - * (semantically) just transforms the annotation at each node using `f`. + * annotations. This will then bubble the annotations up the tree + * using the `Semigroup` for `A1`. If `f` is a semigroup + * homomorphism, then this (semantically) just transforms the + * annotation at each node using `f`. */ def mapAnnotation[A1: Semigroup](f: A => A1): AnnotatedTree[K, V, T, A1] = mapLeaves { (t, a) => (t, f(a)) } @@ -83,8 +98,8 @@ case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { mapSplits { (k, p) => (f(k), p) } /** - * Maps the [[Predicate]]s in the `Tree` using `f`. Note, this will only - * produce a valid `Tree` if `f` preserves the ordering (ie if + * Maps the [[Predicate]]s in the `Tree` using `f`. Note, this will + * only produce a valid `Tree` if `f` preserves the ordering (ie if * `a.compare(b) == f(a).compare(f(b))`). */ def mapPredicates[V1: Ordering](f: V => V1): AnnotatedTree[K, V1, T, A] = @@ -93,84 +108,84 @@ case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { /** * Returns the leaf with index `leafIndex` by performing a DFS. */ - def leafAt(leafIndex: Int): Option[LeafNode[K, V, T, A]] = leafAt(leafIndex, root) + def leafAt(leafIndex: Int): Option[LeafNode[K, V, T, A]] = + leafAt(leafIndex, root) /** - * Returns the leaf with index `leafIndex` that is a descendant of `start` by - * performing a DFS starting with `start`. + * Returns the leaf with index `leafIndex` that is a descendant of + * `start` by performing a DFS starting with `start`. */ - def leafAt(leafIndex: Int, start: Node[K, V, T, A]): Option[LeafNode[K, V, T, A]] = { + def leafAt(leafIndex: Int, start: Node[K, V, T, A]): Option[LeafNode[K, V, T, A]] = start match { - case leaf @ LeafNode(_, _, _) => - if (leaf.index == leafIndex) Some(leaf) else None - case SplitNode(children) => - children - .flatMap { case (_, _, child) => leafAt(leafIndex, child) } - .headOption + case SplitNode(p, k, lc, rc, _) => + leafAt(leafIndex, lc) match { + case None => leafAt(leafIndex, rc) + case some => some + } + case leaf @ LeafNode(index, _, _) => + if (index == leafIndex) Some(leaf) else None } - } /** * Prune a tree to minimize validation error. * - * Recursively replaces each split with a leaf when it would have a lower error than the sum of the child leaves - * errors. + * Recursively replaces each split with a leaf when it would have a + * lower error than the sum of the child leaves errors. * * @param validationData A Map from leaf index to validation data. * @param voter to create predictions from target distributions. * @param error to calculate an error statistic given observations (validation) and predictions (training). * @return The new, pruned tree. */ - def prune[P, E](validationData: Map[Int, T], voter: Voter[T, P], error: Error[T, P, E])(implicit targetMonoid: Monoid[T], errorOrdering: Ordering[E]): AnnotatedTree[K, V, T, A] = { - AnnotatedTree(pruneNode(validationData, this.root, voter, error)._2).renumberLeaves - } + def prune[P, E: Ordering](validationData: Map[Int, T], voter: Voter[T, P], error: Error[T, P, E])(implicit m: Monoid[T]): AnnotatedTree[K, V, T, A] = + AnnotatedTree(pruneNode(validationData, root, voter, error)._2) /** - * Prune a tree to minimize validation error, starting from given root node. + * Prune a tree to minimize validation error, starting from given + * root node. * - * This method recursively traverses the tree from the root, branching on splits, until it finds leaves, then goes back - * down the tree combining leaves when such a combination would reduce validation error. + * This method recursively traverses the tree from the root, + * branching on splits, until it finds leaves, then goes back down + * the tree combining leaves when such a combination would reduce + * validation error. * * @param validationData Map from leaf index to validation data. * @param start The root node of the tree. * @return A node at the root of the new, pruned tree. */ - def pruneNode[P, E](validationData: Map[Int, T], start: Node[K, V, T, A], voter: Voter[T, P], error: Error[T, P, E])(implicit targetMonoid: Monoid[T], errorOrdering: Ordering[E]): (Map[Int, T], Node[K, V, T, A]) = { - type ChildSeqType = (K, Predicate[V], Node[K, V, T, A]) + def pruneNode[P, E: Ordering](validationData: Map[Int, T], start: Node[K, V, T, A], voter: Voter[T, P], error: Error[T, P, E])(implicit m: Monoid[T]): (Map[Int, T], Node[K, V, T, A]) = { start match { case leaf @ LeafNode(_, _, _) => // Bounce at the bottom and start back up the tree. (validationData, leaf) - case SplitNode(children) => - // Call pruneNode on each child, accumulating modified children and - // additions to the validation data along the way. - val (newData, newChildren) = - children.foldLeft((validationData, Seq[ChildSeqType]())) { - case ((vData, childSeq), (k, p, child)) => - pruneNode(vData, child, voter, error) match { - case (v, c) => (v, childSeq :+ (k, p, c)) - } - } - // Now that we've taken care of the children, prune the current level. - val childLeaves = newChildren.collect { case (k, v, s @ LeafNode(_, _, _)) => (k, v, s) } - - if (childLeaves.size == newChildren.size) { - // If all children are leaves, we can potentially prune. - val parent = SplitNode(newChildren) - pruneLevel(parent, childLeaves, newData, voter, error) - } else { - // If any children are SplitNodes, we can't prune. - (newData, SplitNode(newChildren)) + case SplitNode(p, k, lc0, rc0, _) => + val (vData, lc1) = pruneNode(validationData, lc0, voter, error) + val (newData, rc1) = pruneNode(vData, rc0, voter, error) + + // If all the children are leaves, we can potentially + // prune. Otherwise we definitely cannot prune. + lc1 match { + case lc2 @ LeafNode(_, _, _) => + rc1 match { + case rc2 @ LeafNode(_, _, _) => + pruneLevel(SplitNode(p, k, lc2, rc2), lc2, rc2, newData, voter, error) + case _ => + (newData, SplitNode(p, k, lc1, rc1)) + } + case _ => + (newData, SplitNode(p, k, lc1, rc1)) } } } /** - * Test conditions and optionally replace parent with a new leaf that combines children. + * Test conditions and optionally replace parent with a new leaf + * that combines children. * - * Also merges validation data for any combined leaves. This relies on a hack that assumes no leaves have negative - * indices to start out. + * Also merges validation data for any combined leaves. This relies + * on a hack that assumes no leaves have negative indices to start + * out. * * @param parent * @param children @@ -178,31 +193,33 @@ case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { */ def pruneLevel[P, E]( parent: SplitNode[K, V, T, A], - children: Seq[(K, Predicate[V], LeafNode[K, V, T, A])], + leftChild: LeafNode[K, V, T, A], + rightChild: LeafNode[K, V, T, A], validationData: Map[Int, T], voter: Voter[T, P], - error: Error[T, P, E])(implicit targetMonoid: Monoid[T], - errorOrdering: Ordering[E]): (Map[Int, T], Node[K, V, T, A]) = { + error: Error[T, P, E])(implicit targetMonoid: Monoid[T], errorOrdering: Ordering[E]): (Map[Int, T], Node[K, V, T, A]) = { - // Get training and validation data and validation error for each leaf. - val (targets, validations, errors) = children.unzip3 { - case (k, p, leaf) => - val trainingTarget = leaf.target - val validationTarget = validationData.getOrElse(leaf.index, targetMonoid.zero) - val leafError = error.create(validationTarget, voter.combine(Some(trainingTarget))) - (trainingTarget, validationTarget, leafError) - } + def v(leaf: LeafNode[K, V, T, A]): T = + validationData.getOrElse(leaf.index, targetMonoid.zero) + + def e(leaf: LeafNode[K, V, T, A]): E = + error.create(v(leaf), voter.combine(Some(leaf.target))) - val targetSum = targetMonoid.sum(targets) // Combined training targets to create the training data of the potential combined node. + val targetSum = targetMonoid.plus(leftChild.target, rightChild.target) + val validationSum = targetMonoid.plus(v(leftChild), v(rightChild)) + val sumOfErrors = error.semigroup.plus(e(leftChild), e(rightChild)) + + // Get training and validation data and validation error for each leaf. val targetPrediction = voter.combine(Some(targetSum)) // Generate prediction from combined target. - val validationSum = targetMonoid.sum(validations) // Combined validation target for combined node. val errorOfSums = error.create(validationSum, targetPrediction) // Error of potential combined node. - val sumOfErrors = error.semigroup.sumOption(errors) // Sum of errors of leaves. + // Compare sum of errors and error of sums (and lower us out of the sum of errors Option). - val doCombine = sumOfErrors.exists { sumOE => errorOrdering.gteq(sumOE, errorOfSums) } - if (doCombine) { // Create a new leaf from the combination of the children. + val doCombine = errorOrdering.gteq(sumOfErrors, errorOfSums) + + if (doCombine) { + // Create a new leaf from the combination of the children. // Find a unique (negative) index for the new leaf: - val newIndex = -1 * children.map { case (k, p, leaf) => Math.abs(leaf.index) }.max + val newIndex = -1 * max(abs(leftChild.index), abs(rightChild.index)) val node = LeafNode[K, V, T, A](newIndex, targetSum, parent.annotation) (validationData + (newIndex -> validationSum), node) } else { @@ -211,48 +228,35 @@ case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { } def leafFor(row: Map[K, V], id: Option[String] = None)(implicit traversal: TreeTraversal[K, V, T, A]): Option[LeafNode[K, V, T, A]] = - traversal.find(root, row, id).headOption + traversal.find(this, row, id).headOption def leafIndexFor(row: Map[K, V], id: Option[String] = None)(implicit traversal: TreeTraversal[K, V, T, A]): Option[Int] = leafFor(row, id).map(_.index) def targetFor(row: Map[K, V], id: Option[String] = None)(implicit traversal: TreeTraversal[K, V, T, A], semigroup: Semigroup[T]): Option[T] = - semigroup.sumOption(traversal.find(root, row, id).map(_.target)) + semigroup.sumOption(traversal.find(this, row, id).map(_.target)) /** * For each leaf, this may convert the leaf to a [[SplitNode]] whose children * have the target distribution returned by calling `fn` with the leaf's * index. If `fn` returns an empty `Seq`, then the leaf is left as-is. */ - def growByLeafIndex(fn: Int => Seq[(K, Predicate[V], T, A)]): AnnotatedTree[K, V, T, A] = { - var newIndex = -1 - def incrIndex(): Int = { - newIndex += 1 - newIndex - } - - def growFrom(start: Node[K, V, T, A]): Node[K, V, T, A] = { + def growByLeafIndex(fn: Int => Option[SplitNode[K, V, T, A]]): AnnotatedTree[K, V, T, A] = { + def growFrom(nextIndex: Int, start: Node[K, V, T, A]): (Int, Node[K, V, T, A]) = { start match { case LeafNode(index, target, annotation) => - val newChildren = fn(index) - if (newChildren.isEmpty) - LeafNode[K, V, T, A](incrIndex(), target, annotation) - else - SplitNode[K, V, T, A](newChildren.map { - case (feature, predicate, target, childAnnotation) => - val child = LeafNode[K, V, T, A](incrIndex(), target, childAnnotation) - (feature, predicate, child) - }) - - case SplitNode(children) => - SplitNode[K, V, T, A](children.map { - case (feature, predicate, child) => - (feature, predicate, growFrom(child)) - }) + fn(index) match { + case None => (nextIndex + 1, LeafNode(nextIndex, target, annotation)) + case Some(split) => split.renumber(nextIndex) + } + case SplitNode(p, k, lc0, rc0, _) => + val (n1, lc) = growFrom(nextIndex, lc0) + val (n2, rc) = growFrom(n1, rc0) + (n2, SplitNode(p, k, lc, rc)) } } - AnnotatedTree(growFrom(root)) + AnnotatedTree(growFrom(0, root)._2) } /** @@ -266,11 +270,8 @@ case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { start match { case LeafNode(index, _, _) => fn(index).getOrElse(start) - case SplitNode(children) => - SplitNode[K, V, T, A](children.map { - case (feature, predicate, child) => - (feature, predicate, updateFrom(child)) - }) + case SplitNode(p, k, lc0, rc0, _) => + SplitNode(p, k, updateFrom(lc0), updateFrom(rc0)) } } @@ -283,5 +284,5 @@ case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { * @return A new tree with leaves renumbered. */ def renumberLeaves: AnnotatedTree[K, V, T, A] = - this.growByLeafIndex { i => Nil } + AnnotatedTree(root.renumber(0)._2) } diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala index 05d9b1a..cca141f 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala @@ -42,14 +42,22 @@ trait Splitter[V, T] { } /** Candidate split for a tree node */ -trait Split[V, T] { - def predicates: Iterable[(Predicate[V], T)] +sealed trait Split[V, T] { + def predicate: Predicate[V] + def leftDistribution: T + def rightDistribution: T + def distributions: List[T] = leftDistribution :: rightDistribution :: Nil + + def createSplitNode[K](feature: K): SplitNode[K, V, T, Unit] = + SplitNode(predicate, feature, LeafNode(-1, leftDistribution), LeafNode(-1, rightDistribution)) } +case class BinarySplit[V, T](predicate: Predicate[V], leftDistribution: T, rightDistribution: T) extends Split[V, T] + /** Evaluates the goodness of a candidate split */ trait Evaluator[V, T] { /** returns a (possibly transformed) version of the input split, and a numeric goodness score */ - def evaluate(split: Split[V, T]): (Split[V, T], Double) + def evaluate(split: Split[V, T]): Option[(Split[V, T], Double)] } /** Provides stopping conditions which guide when splits will be attempted */ diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Dispatched.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Dispatched.scala index da29548..cc50f76 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Dispatched.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Dispatched.scala @@ -59,11 +59,7 @@ object Dispatched { def wrapSplits[X, T, A: Ordering, B, C: Ordering, D](splits: Iterable[Split[X, T]])(fn: X => Dispatched[A, B, C, D]) = { splits.map { split => - new Split[Dispatched[A, B, C, D], T] { - def predicates = split.predicates.map { - case (pred, p) => (pred.map(fn), p) - } - } + BinarySplit(split.predicate.map(fn), split.leftDistribution, split.rightDistribution) } } } diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Evaluators.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Evaluators.scala index 8704d7e..d46f311 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Evaluators.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Evaluators.scala @@ -4,8 +4,8 @@ import com.twitter.algebird._ case class ChiSquaredEvaluator[V, L, W](implicit weightMonoid: Monoid[W], weightDouble: W => Double) extends Evaluator[V, Map[L, W]] { - def evaluate(split: Split[V, Map[L, W]]) = { - val rows = split.predicates.map { _._2 }.filter { _.nonEmpty } + def evaluate(split: Split[V, Map[L, W]]): Option[(Split[V, Map[L, W]], Double)] = { + val rows = split.distributions.filter { _.nonEmpty } if (rows.size > 1) { val n = weightMonoid.sum(rows.flatMap { _.values }) val rowTotals = rows.map { row => weightMonoid.sum(row.values) }.toList @@ -20,43 +20,29 @@ case class ChiSquaredEvaluator[V, L, W](implicit weightMonoid: Monoid[W], weight val delta = observed - expected (delta * delta) / expected }).sum - (split, testStatistic) - } else - (EmptySplit[V, Map[L, W]](), Double.NegativeInfinity) + Some((split, testStatistic)) + } else { + None + } } } case class MinWeightEvaluator[V, L, W: Monoid](minWeight: W => Boolean, wrapped: Evaluator[V, Map[L, W]]) extends Evaluator[V, Map[L, W]] { - def evaluate(split: Split[V, Map[L, W]]) = { - val (baseSplit, baseScore) = wrapped.evaluate(split) - if (baseSplit.predicates.forall { - case (pred, freq) => - val totalWeight = Monoid.sum(freq.values) - minWeight(totalWeight) - }) - (baseSplit, baseScore) - else - (EmptySplit[V, Map[L, W]](), Double.NegativeInfinity) - } -} -case class EmptySplit[V, P]() extends Split[V, P] { - val predicates = Nil -} + private[this] def test(dist: Map[L, W]): Boolean = minWeight(Monoid.sum(dist.values)) -case class ErrorEvaluator[V, T, P, E](error: Error[T, P, E], voter: Voter[T, P])(fn: E => Double) - extends Evaluator[V, T] { - def evaluate(split: Split[V, T]) = { - val totalErrorOption = - error.semigroup.sumOption( - split - .predicates - .map { case (_, target) => error.create(target, voter.combine(Some(target))) }) - - totalErrorOption match { - case Some(totalError) => (split, -fn(totalError)) - case None => (EmptySplit[V, T](), Double.NegativeInfinity) + def evaluate(split: Split[V, Map[L, W]]): Option[(Split[V, Map[L, W]], Double)] = + wrapped.evaluate(split).filter { case (baseSplit, _) => + test(baseSplit.leftDistribution) && test(baseSplit.rightDistribution) } +} + +case class ErrorEvaluator[V, T, P, E](error: Error[T, P, E], voter: Voter[T, P])(fn: E => Double) extends Evaluator[V, T] { + def evaluate(split: Split[V, T]): Option[(Split[V, T], Double)] = { + def e(t: T): E = error.create(t, voter.combine(Some(t))) + val e0 = e(split.leftDistribution) + val e1 = e(split.rightDistribution) + Some((split, -fn(error.semigroup.plus(e0, e1)))) } } diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Injections.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Injections.scala index 3a4b748..f3632e7 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Injections.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Injections.scala @@ -110,7 +110,7 @@ object JsonInjections { val predNode = n.get("exists") if (predNode.isNull) Success(IsPresent[V](None)) else fromJsonNode[Predicate[V]](predNode).map(p => IsPresent(Some(p))) - case _ => sys.error("Not a predicate node") + case _ => sys.error("Not a predicate node: " + n) } } } @@ -124,57 +124,49 @@ object JsonInjections { obj.put("distribution", toJsonNode(target)) obj - case SplitNode(children) => - val ary = JsonNodeFactory.instance.arrayNode - children.foreach { - case (feature, predicate, child) => - val obj = JsonNodeFactory.instance.objectNode - obj.put("feature", toJsonNode(feature)) - obj.put("predicate", toJsonNode(predicate)) - obj.put("display", toJsonNode(Predicate.display(predicate))) - obj.put("children", toJsonNode(child)(nodeJsonNodeInjection)) - ary.add(obj) - } - ary + case SplitNode(p, k, lc, rc, _) => + val obj = JsonNodeFactory.instance.objectNode + obj.put("predicate", toJsonNode(p)(predicateJsonNodeInjection)) + obj.put("key", toJsonNode(k)) + obj.put("left", toJsonNode(lc)(nodeJsonNodeInjection)) + obj.put("right", toJsonNode(rc)(nodeJsonNodeInjection)) + obj } - def tryChild(node: JsonNode, property: String) = Try { + def tryChild(node: JsonNode, property: String): Try[JsonNode] = + Try { + val child = node.get(property) + assert(child != null, property + " != null") + child + } + + def tryLoad[T: JsonNodeInjection](node: JsonNode, property: String): Try[T] = { val child = node.get(property) - assert(child != null, property + " != null") - child + if (child == null) Failure(new IllegalArgumentException(property + " != null")) + else fromJsonNode[T](child) } - override def invert(n: JsonNode) = { - Option(n.get("leaf")) match { - case Some(indexNode) => fromJsonNode[Int](indexNode).flatMap { index => - fromJsonNode[T](n.get("distribution")).map { target => - LeafNode(index, target) - } - } - - case None => - val children = n.getElements.asScala.map { c => - for ( - featureNode <- tryChild(c, "feature"); - feature <- fromJsonNode[K](featureNode); - predicateNode <- tryChild(c, "predicate"); - predicate <- fromJsonNode[Predicate[V]](predicateNode); - childNode <- tryChild(c, "children"); - child <- fromJsonNode[Node[K, V, T, Unit]](childNode) - ) yield (feature, predicate, child) - }.toList - - children.find { _.isFailure } match { - case Some(Failure(e)) => Failure(InversionFailure(n, e)) - case _ => Success(SplitNode[K, V, T, Unit](children.map { _.get })) - } + override def invert(n: JsonNode): Try[Node[K, V, T, Unit]] = + if (n.has("leaf")) { + for { + index <- tryLoad[Int](n, "leaf") + target <- tryLoad[T](n, "distribution") + } yield LeafNode(index, target) + } else { + for { + p <- tryLoad[Predicate[V]](n, "predicate") + k <- tryLoad[K](n, "key") + left <- tryLoad[Node[K, V, T, Unit]](n, "left")(nodeJsonNodeInjection) + right <- tryLoad[Node[K, V, T, Unit]](n, "right")(nodeJsonNodeInjection) + } yield SplitNode(p, k, left, right) } - } } new AbstractJsonNodeInjection[Tree[K, V, T]] { - def apply(tree: Tree[K, V, T]) = toJsonNode(tree.root) - override def invert(n: JsonNode) = fromJsonNode[Node[K, V, T, Unit]](n).map { root => Tree(root) } + def apply(tree: Tree[K, V, T]): JsonNode = + toJsonNode(tree.root) + override def invert(n: JsonNode): Try[Tree[K, V, T]] = + fromJsonNode[Node[K, V, T, Unit]](n).map(root => Tree(root)) } } diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Predicate.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Predicate.scala index 2ed1d62..c4a9efb 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Predicate.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Predicate.scala @@ -5,7 +5,13 @@ package com.stripe.brushfire * `true` or `false`. It is generally used within a [[Tree]] to decide which * branch to follow while trying to classify a row/feature vector. */ -sealed trait Predicate[V] extends (Option[V] => Boolean) { +sealed trait Predicate[V] extends (Option[V] => Option[Boolean]) { + + def run(o: Option[V]): Boolean = + apply(o) match { + case Some(b) => b + case None => true + } /** * Map the value types of this [[Predicate]] using `f`. @@ -24,7 +30,7 @@ sealed trait Predicate[V] extends (Option[V] => Boolean) { * defined and is equal to `value` (according to the input's `equals` method). */ case class EqualTo[V](value: V) extends Predicate[V] { - def apply(v: Option[V]) = v.isEmpty || (v.get == value) + def apply(v: Option[V]): Option[Boolean] = v.map(_ == value) } /** @@ -33,7 +39,7 @@ case class EqualTo[V](value: V) extends Predicate[V] { * of type `V` to handle the comparison. */ case class LessThan[V](value: V)(implicit ord: Ordering[V]) extends Predicate[V] { - def apply(v: Option[V]) = v.isEmpty || ord.lt(v.get, value) + def apply(v: Option[V]): Option[Boolean] = v.map(x => ord.lt(x, value)) } /** @@ -41,7 +47,7 @@ case class LessThan[V](value: V)(implicit ord: Ordering[V]) extends Predicate[V] * `false` if `pred` returns `true`. */ case class Not[V](pred: Predicate[V]) extends Predicate[V] { - def apply(v: Option[V]) = v.isEmpty || !pred(v) + def apply(v: Option[V]): Option[Boolean] = pred(v).map(!_) } /** @@ -49,7 +55,10 @@ case class Not[V](pred: Predicate[V]) extends Predicate[V] { * returns `true`. */ case class AnyOf[V](preds: Seq[Predicate[V]]) extends Predicate[V] { - def apply(v: Option[V]) = preds.exists { p => p(v) } + def apply(v: Option[V]): Option[Boolean] = { + val it = preds.iterator.map(_(v)).flatten + if (!it.hasNext) None else Some(it.exists(_ == true)) + } } /** @@ -63,7 +72,11 @@ case class AnyOf[V](preds: Seq[Predicate[V]]) extends Predicate[V] { * value is present (not missing). */ case class IsPresent[V](pred: Option[Predicate[V]]) extends Predicate[V] { - def apply(v: Option[V]) = v.isDefined && pred.fold(true)(_(v)) + def apply(v: Option[V]): Option[Boolean] = + pred match { + case None => Some(v.isDefined) + case Some(pred) => Some(v.isDefined && pred(v) == Some(true)) + } } object Predicate { diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Splitters.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Splitters.scala index 9fb9633..0c915a6 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Splitters.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Splitters.scala @@ -2,8 +2,7 @@ package com.stripe.brushfire import com.twitter.algebird._ -case class BinarySplitter[V, T: Monoid](partition: V => Predicate[V]) - extends Splitter[V, T] { +case class BinarySplitter[V, T: Monoid](partition: V => Predicate[V]) extends Splitter[V, T] { type S = Map[V, T] def create(value: V, target: T) = Map(value -> target) @@ -13,7 +12,7 @@ case class BinarySplitter[V, T: Monoid](partition: V => Predicate[V]) def split(parent: T, stats: Map[V, T]) = { stats.keys.map { v => val predicate = partition(v) - val (trues, falses) = stats.partition { case (v, d) => predicate(Some(v)) } + val (trues, falses) = stats.partition { case (v, d) => predicate.run(Some(v)) } BinarySplit(predicate, Monoid.sum(trues.values), Monoid.sum(falses.values)) } } @@ -93,12 +92,3 @@ case class SpaceSaverSplitter[V, L](capacity: Int = 1000) } } } - -case class BinarySplit[V, T]( - predicate: Predicate[V], - leftDistribution: T, - rightDistribution: T) - extends Split[V, T] { - def predicates = - List(predicate -> leftDistribution, Not(predicate) -> rightDistribution) -} diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala index 5f78d92..dea2161 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala @@ -8,7 +8,8 @@ object Tree { def apply[K, V, T](node: Node[K, V, T, Unit]): Tree[K, V, T] = AnnotatedTree(node) - def singleton[K, V, T](t: T): Tree[K, V, T] = Tree(LeafNode(0, t, ())) + def singleton[K, V, T](t: T): Tree[K, V, T] = + AnnotatedTree(LeafNode(0, t, ())) def expand[K, V, T: Monoid](times: Int, treeIndex: Int, leaf: LeafNode[K, V, T, Unit], splitter: Splitter[V, T], evaluator: Evaluator[V, T], stopper: Stopper[T], sampler: Sampler[K], instances: Iterable[Instance[K, V, T]]): Node[K, V, T, Unit] = { if (times > 0 && stopper.shouldSplit(leaf.target)) { @@ -19,33 +20,25 @@ object Tree { if(sampler.includeFeature(f, treeIndex, leaf.index)) Map(f -> splitter.create(v, instance.target)) else - Map.empty[K,splitter.S] + Map.empty[K, splitter.S] } }).flatMap { featureMap => - val splits = featureMap.toList.flatMap { - case (f, s) => - splitter.split(leaf.target, s).map { x => f -> evaluator.evaluate(x) } - } - if(splits.isEmpty) - None - else { - val (splitFeature, (split, _)) = splits.maxBy { case (f, (x, s)) => s } - val edges = split.predicates.toList.map { - case (pred, _) => - val newInstances = instances.filter { inst => pred(inst.features.get(splitFeature)) } - val target = Monoid.sum(newInstances.map { _.target }) - (pred, target, newInstances) - } + val splits = for { + (f, s) <- featureMap.toList + split <- splitter.split(leaf.target, s) + tpl <- evaluator.evaluate(split) + } yield (f, tpl) - if (edges.count { case (_, _, newInstances) => newInstances.nonEmpty } > 1) { - Some(SplitNode(edges.map { - case (pred, target, newInstances) => - (splitFeature, pred, expand[K, V, T](times - 1, treeIndex, LeafNode(0, target), splitter, evaluator, stopper, sampler, newInstances)) - })) - } else { - None + if (splits.isEmpty) None else { + val (splitFeature, (split, _)) = splits.maxBy { case (f, (x, s)) => s } + val pred = split.predicate + def ex(dist: T): Node[K, V, T, Unit] = { + val newInstances = instances.filter { inst => pred.run(inst.features.get(splitFeature)) } + val target = Monoid.sum(newInstances.map(_.target)) + expand(times - 1, treeIndex, LeafNode(0, target), splitter, evaluator, stopper, sampler, newInstances) } + Some(SplitNode(pred, splitFeature, ex(split.leftDistribution), ex(split.rightDistribution))) } }.getOrElse(leaf) } else { diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala index 43c19d5..80bb433 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala @@ -148,10 +148,10 @@ case class DepthFirstTreeTraversal[K, V, T, A](order: (Random, List[Node[K, V, T Stream.empty case (leaf @ LeafNode(_, _, _)) :: rest => leaf #:: loop0(rest) - case (split @ SplitNode(_)) :: rest => - val newStack = split.findChildren(row) match { + case (split @ SplitNode(_, _, _, _, _)) :: rest => + val newStack = split.evaluate(row) match { case Nil => rest - case node :: Nil => node :: rest + //case node :: Nil => node :: rest case candidates => order(rng, candidates) ::: rest } loop(newStack) diff --git a/brushfire-core/src/test/scala/com/stripe/brushfire/JsonInjectionsSpec.scala b/brushfire-core/src/test/scala/com/stripe/brushfire/JsonInjectionsSpec.scala index 7e8b1ce..c8d7b76 100644 --- a/brushfire-core/src/test/scala/com/stripe/brushfire/JsonInjectionsSpec.scala +++ b/brushfire-core/src/test/scala/com/stripe/brushfire/JsonInjectionsSpec.scala @@ -15,7 +15,8 @@ class JsonInjectionsSpec extends WordSpec with Matchers with Checkers { "treeJsonInjection" should { "round-trip" in { check { (tree: Tree[String, Double, Map[String, Long]]) => - val haggard = fromJsonNode[Tree[String, Double, Map[String, Long]]](toJsonNode(tree)) + val json = toJsonNode(tree) + val haggard = fromJsonNode[Tree[String, Double, Map[String, Long]]](json) Try(tree) == haggard } } diff --git a/brushfire-core/src/test/scala/com/stripe/brushfire/PredicateSpec.scala b/brushfire-core/src/test/scala/com/stripe/brushfire/PredicateSpec.scala index e0dbc8b..7732356 100644 --- a/brushfire-core/src/test/scala/com/stripe/brushfire/PredicateSpec.scala +++ b/brushfire-core/src/test/scala/com/stripe/brushfire/PredicateSpec.scala @@ -10,15 +10,15 @@ class PredicateSpec extends WordSpec with Matchers with Checkers { "allow missing values in all but IsPresent" in { check { (pred: Predicate[Int]) => pred match { - case IsPresent(_) => pred(None) == false - case _ => pred(None) == true + case IsPresent(_) => pred.run(None) == false + case _ => pred.run(None) == true } } } "Not negates the predicate" in { check { (pred: Predicate[Int], value: Int) => - !pred(Some(value)) == Not(pred)(Some(value)) + !pred.run(Some(value)) == Not(pred).run(Some(value)) } } diff --git a/brushfire-core/src/test/scala/com/stripe/brushfire/TreeGenerators.scala b/brushfire-core/src/test/scala/com/stripe/brushfire/TreeGenerators.scala index 2b5414b..4a90637 100644 --- a/brushfire-core/src/test/scala/com/stripe/brushfire/TreeGenerators.scala +++ b/brushfire-core/src/test/scala/com/stripe/brushfire/TreeGenerators.scala @@ -39,9 +39,7 @@ object TreeGenerators { key <- genK left <- genNode(index, maxDepth - 1) right <- genNode(index | (1 << (maxDepth - 1)), maxDepth - 1) - } yield SplitNode(List( - (key, pred, left), - (key, Not(pred), right))) + } yield SplitNode(pred, key, left, right) def genNode(index: Int, maxDepth: Int): Gen[Node[K, V, T, Unit]] = if (maxDepth > 1) { diff --git a/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala b/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala index daefc1c..93b6fc5 100644 --- a/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala +++ b/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala @@ -13,19 +13,25 @@ import org.scalatest.prop.Checkers class TreeTraversalSpec extends WordSpec with Matchers with Checkers { import TreeGenerators._ + def split[T, A: Semigroup](key: String, pred: Predicate[Double], left: Node[String, Double, T, A], right: Node[String, Double, T, A]): SplitNode[String, Double, T, A] = + SplitNode(pred, key, left, right) + "depthFirst" should { "always choose the left side of a split in a binary tree" in { val simpleTreeGen = genBinaryTree(arbitrary[String], arbitrary[Double], arbitrary[Map[String, Long]], 2) .filter(_.root match { - case SplitNode(children) => - children.collect { case (_, IsPresent(_), _) => true }.isEmpty + case SplitNode(p, _, _, _, _) => + p match { + case IsPresent(_) => false + case _ => true + } case _ => false }) check(Prop.forAll(simpleTreeGen) { tree => (tree.root: @unchecked) match { - case SplitNode(children) => - TreeTraversal.depthFirst.find(tree, Map.empty[String, Double], None).headOption == Some(children.head._3) + case SplitNode(_, _, lc, rc, _) => + TreeTraversal.depthFirst.find(tree, Map.empty[String, Double], None).headOption == Some(lc) } }) } @@ -44,18 +50,15 @@ class TreeTraversalSpec extends WordSpec with Matchers with Checkers { } } - def split[T, A: Semigroup](key: String, pred: Predicate[Double], left: Node[String, Double, T, A], right: Node[String, Double, T, A]): SplitNode[String, Double, T, A] = - SplitNode((key, pred, left) :: (key, Not(pred), right) :: Nil) - "weightedDepthFirst" should { implicit val traversal = TreeTraversal.weightedDepthFirst[String, Double, Double, Int] "choose the heaviest node in a split" in { val split1 = split("f1", LessThan(0D), LeafNode(0, -1D, 12), LeafNode(0, 1D, 9)) - AnnotatedTree(split1).leafFor(Map.empty[String, Double]) shouldBe Some(split1.children.head._3) + AnnotatedTree(split1).leafFor(Map.empty[String, Double]) shouldBe Some(split1.leftChild) val split2 = split("f2", LessThan(0D), LeafNode(0, -1D, -3), LeafNode(0, 1D, 9)) - AnnotatedTree(split2).leafFor(Map.empty[String, Double]) shouldBe Some(split2.children.last._3) + AnnotatedTree(split2).leafFor(Map.empty[String, Double]) shouldBe Some(split2.rightChild) } "choose heaviest path in a tree" in { @@ -77,11 +80,8 @@ class TreeTraversalSpec extends WordSpec with Matchers with Checkers { def collectLeafs[K, V, T, A](node: Node[K, V, T, A]): Set[LeafNode[K, V, T, A]] = node match { - case SplitNode(children) => - children.collect { - case (_, p, n) if p(None) => collectLeafs(n) - }.flatten.toSet - + case SplitNode(p, _, lc, rc, _) => + (if (p.run(None)) List(lc, rc) else Nil).flatMap(collectLeafs).toSet case leaf @ LeafNode(_, _, _) => Set(leaf) } diff --git a/brushfire-scalding/src/main/scala/com/stripe/brushfire/scalding/Trainer.scala b/brushfire-scalding/src/main/scala/com/stripe/brushfire/scalding/Trainer.scala index 36bc518..dce9a97 100644 --- a/brushfire-scalding/src/main/scala/com/stripe/brushfire/scalding/Trainer.scala +++ b/brushfire-scalding/src/main/scala/com/stripe/brushfire/scalding/Trainer.scala @@ -130,12 +130,11 @@ case class Trainer[K: Ordering, V, T: Monoid]( .flatMap { case ((treeIndex, leafIndex, feature), target) => treeMap(treeIndex).leafAt(leafIndex).toList.flatMap { leaf => - splitter - .split(leaf.target, target) - .map { rawSplit => - val (split, goodness) = evaluator.evaluate(rawSplit) + splitter.split(leaf.target, target).flatMap { rawSplit => + evaluator.evaluate(rawSplit).map { case (split, goodness) => treeIndex -> Map(leafIndex -> (feature, split, goodness)) } + } } } @@ -146,18 +145,15 @@ case class Trainer[K: Ordering, V, T: Monoid]( .group .withReducers(reducers) .sum - .map { - case (treeIndex, map) => - val newTree = - treeMap(treeIndex) - .growByLeafIndex { index => - for ( - (feature, split, _) <- map.get(index).toList; - (predicate, target) <- split.predicates - ) yield (feature, predicate, target, ()) - } + .map { case (treeIndex, map) => + val newTree = + treeMap(treeIndex).growByLeafIndex { index => + map.get(index).map { case (feature, split, _) => + split.createSplitNode(feature) + } + } - treeIndex -> newTree + treeIndex -> newTree.renumberLeaves }.writeThrough(TreeSource(path)) } } From 09213c8fe2d2d20c7067d8cf91e62501d6388ade Mon Sep 17 00:00:00 2001 From: Erik Osheim Date: Wed, 9 Dec 2015 17:50:02 -0800 Subject: [PATCH 02/11] Make Split concrete. This removes BinarySplit, since now all splits are binary. It is a bit nicer to be working with Split as a concrete case class. --- .../scala/com/stripe/brushfire/Brushfire.scala | 15 ++++++++------- .../scala/com/stripe/brushfire/Dispatched.scala | 7 ++----- .../scala/com/stripe/brushfire/Evaluators.scala | 12 ++++++------ .../scala/com/stripe/brushfire/Splitters.scala | 8 ++++---- .../main/scala/com/stripe/brushfire/TDigest.scala | 4 ++-- .../main/scala/com/stripe/brushfire/Tree.scala | 5 ++--- 6 files changed, 24 insertions(+), 27 deletions(-) diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala index cca141f..da6799f 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala @@ -42,17 +42,18 @@ trait Splitter[V, T] { } /** Candidate split for a tree node */ -sealed trait Split[V, T] { - def predicate: Predicate[V] - def leftDistribution: T - def rightDistribution: T - def distributions: List[T] = leftDistribution :: rightDistribution :: Nil +case class Split[V, T](predicate: Predicate[V], leftDistribution: T, rightDistribution: T) { + /** + * Given a feature key, create a SplitNode from this Split. + * + * Note that the leaves of this node will likely need to be + * renumbered if this node is put into a larger tree. + */ def createSplitNode[K](feature: K): SplitNode[K, V, T, Unit] = - SplitNode(predicate, feature, LeafNode(-1, leftDistribution), LeafNode(-1, rightDistribution)) + SplitNode(predicate, feature, LeafNode(0, leftDistribution), LeafNode(1, rightDistribution)) } -case class BinarySplit[V, T](predicate: Predicate[V], leftDistribution: T, rightDistribution: T) extends Split[V, T] /** Evaluates the goodness of a candidate split */ trait Evaluator[V, T] { diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Dispatched.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Dispatched.scala index cc50f76..50073e5 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Dispatched.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Dispatched.scala @@ -57,9 +57,6 @@ object Dispatched { def continuous[C](c: C) = Continuous(c) def sparse[D](d: D) = Sparse(d) - def wrapSplits[X, T, A: Ordering, B, C: Ordering, D](splits: Iterable[Split[X, T]])(fn: X => Dispatched[A, B, C, D]) = { - splits.map { split => - BinarySplit(split.predicate.map(fn), split.leftDistribution, split.rightDistribution) - } - } + def wrapSplits[X, T, A: Ordering, B, C: Ordering, D](splits: Iterable[Split[X, T]])(fn: X => Dispatched[A, B, C, D]) = + splits.map { case Split(p, left, right) => Split(p.map(fn), left, right) } } diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Evaluators.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Evaluators.scala index d46f311..5e54320 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Evaluators.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Evaluators.scala @@ -5,7 +5,8 @@ import com.twitter.algebird._ case class ChiSquaredEvaluator[V, L, W](implicit weightMonoid: Monoid[W], weightDouble: W => Double) extends Evaluator[V, Map[L, W]] { def evaluate(split: Split[V, Map[L, W]]): Option[(Split[V, Map[L, W]], Double)] = { - val rows = split.distributions.filter { _.nonEmpty } + val Split(_, left, right) = split + val rows = (left :: right :: Nil).filter(_.nonEmpty) if (rows.size > 1) { val n = weightMonoid.sum(rows.flatMap { _.values }) val rowTotals = rows.map { row => weightMonoid.sum(row.values) }.toList @@ -33,16 +34,15 @@ case class MinWeightEvaluator[V, L, W: Monoid](minWeight: W => Boolean, wrapped: private[this] def test(dist: Map[L, W]): Boolean = minWeight(Monoid.sum(dist.values)) def evaluate(split: Split[V, Map[L, W]]): Option[(Split[V, Map[L, W]], Double)] = - wrapped.evaluate(split).filter { case (baseSplit, _) => - test(baseSplit.leftDistribution) && test(baseSplit.rightDistribution) + wrapped.evaluate(split).filter { case (Split(_, left, right), _) => + test(left) && test(right) } } case class ErrorEvaluator[V, T, P, E](error: Error[T, P, E], voter: Voter[T, P])(fn: E => Double) extends Evaluator[V, T] { def evaluate(split: Split[V, T]): Option[(Split[V, T], Double)] = { + val Split(_, left, right) = split def e(t: T): E = error.create(t, voter.combine(Some(t))) - val e0 = e(split.leftDistribution) - val e1 = e(split.rightDistribution) - Some((split, -fn(error.semigroup.plus(e0, e1)))) + Some((split, -fn(error.semigroup.plus(e(left), e(right))))) } } diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Splitters.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Splitters.scala index 0c915a6..2e37faa 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Splitters.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Splitters.scala @@ -13,7 +13,7 @@ case class BinarySplitter[V, T: Monoid](partition: V => Predicate[V]) extends Sp stats.keys.map { v => val predicate = partition(v) val (trues, falses) = stats.partition { case (v, d) => predicate.run(Some(v)) } - BinarySplit(predicate, Monoid.sum(trues.values), Monoid.sum(falses.values)) + Split(predicate, Monoid.sum(trues.values), Monoid.sum(falses.values)) } } } @@ -48,7 +48,7 @@ case class QTreeSplitter[T: Monoid](k: Int) val predicate = LessThan(threshold) val leftDist = stats.rangeSumBounds(stats.lowerBound, threshold)._1 val rightDist = stats.rangeSumBounds(threshold, stats.upperBound)._1 - BinarySplit(predicate, leftDist, rightDist) + Split(predicate, leftDist, rightDist) } } @@ -69,7 +69,7 @@ case class SparseSplitter[V, T: Group]() extends Splitter[V, T] { def create(value: V, target: T) = target val semigroup = implicitly[Semigroup[T]] def split(parent: T, stats: T) = - BinarySplit(IsPresent[V](None), stats, Group.minus(parent, stats)) :: Nil + Split(IsPresent[V](None), stats, Group.minus(parent, stats)) :: Nil } case class SpaceSaverSplitter[V, L](capacity: Int = 1000) @@ -88,7 +88,7 @@ case class SpaceSaverSplitter[V, L](capacity: Int = 1000) .flatMap { _.counters.keys }.toSet .map { v: V => val mins = stats.mapValues { ss => ss.frequency(v).min } - BinarySplit(EqualTo(v), mins, Group.minus(parent, mins)) + Split(EqualTo(v), mins, Group.minus(parent, mins)) } } } diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/TDigest.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/TDigest.scala index 8f60384..46cf6e0 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/TDigest.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/TDigest.scala @@ -79,7 +79,7 @@ case class TDigestSplitter[L](k: Int = 25, compression: Double = 100.0) extends // and so they can be discarded immediately if left.nonEmpty || right.nonEmpty } yield { - BinarySplit(LessThan(q), left, right) + Split(LessThan(q), left, right) } // if the input is not continuous or has too few examples we will end up @@ -101,4 +101,4 @@ case class TDigestSplitter[L](k: Int = 25, compression: Double = 100.0) extends td } } -} \ No newline at end of file +} diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala index dea2161..ec15c8a 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala @@ -31,14 +31,13 @@ object Tree { } yield (f, tpl) if (splits.isEmpty) None else { - val (splitFeature, (split, _)) = splits.maxBy { case (f, (x, s)) => s } - val pred = split.predicate + val (splitFeature, (Split(pred, left, right), _)) = splits.maxBy { case (f, (x, s)) => s } def ex(dist: T): Node[K, V, T, Unit] = { val newInstances = instances.filter { inst => pred.run(inst.features.get(splitFeature)) } val target = Monoid.sum(newInstances.map(_.target)) expand(times - 1, treeIndex, LeafNode(0, target), splitter, evaluator, stopper, sampler, newInstances) } - Some(SplitNode(pred, splitFeature, ex(split.leftDistribution), ex(split.rightDistribution))) + Some(SplitNode(pred, splitFeature, ex(left), ex(right))) } }.getOrElse(leaf) } else { From 7610672ede5f8a870455ecaa8fd48427ef2ad928 Mon Sep 17 00:00:00 2001 From: Erik Osheim Date: Thu, 10 Dec 2015 11:09:44 -0800 Subject: [PATCH 03/11] Address the first round of review comments. This should fix all the issues raised by @tixxit. --- .../com/stripe/brushfire/AnnotatedTree.scala | 36 +++++++++---------- .../com/stripe/brushfire/Brushfire.scala | 2 +- .../com/stripe/brushfire/Injections.scala | 19 +++++----- .../scala/com/stripe/brushfire/Tree.scala | 2 +- .../com/stripe/brushfire/TreeTraversal.scala | 2 +- .../com/stripe/brushfire/TreeGenerators.scala | 4 +-- .../stripe/brushfire/TreeTraversalSpec.scala | 6 ++-- 7 files changed, 35 insertions(+), 36 deletions(-) diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala index 387ca9f..b401ea3 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala @@ -9,16 +9,16 @@ sealed abstract class Node[K, V, T, A] { def renumber(nextId: Int): (Int, Node[K, V, T, A]) = this match { - case SplitNode(p, k, lc0, rc0, a) => + case SplitNode(k, p, lc0, rc0, a) => val (n1, lc) = lc0.renumber(nextId) val (n2, rc) = rc0.renumber(n1) - (n2, SplitNode(p, k, lc, rc, a)) + (n2, SplitNode(k, p, lc, rc, a)) case LeafNode(_, target, annotation) => (nextId + 1, LeafNode(nextId, target, annotation)) } } -case class SplitNode[K, V, T, A](predicate: Predicate[V], key: K, leftChild: Node[K, V, T, A], rightChild: Node[K, V, T, A], annotation: A) extends Node[K, V, T, A] { +case class SplitNode[K, V, T, A](key: K, predicate: Predicate[V], leftChild: Node[K, V, T, A], rightChild: Node[K, V, T, A], annotation: A) extends Node[K, V, T, A] { def evaluate(row: Map[K, V]): List[Node[K, V, T, A]] = predicate(row.get(key)) match { case Some(true) => leftChild :: Nil @@ -28,8 +28,8 @@ case class SplitNode[K, V, T, A](predicate: Predicate[V], key: K, leftChild: Nod } object SplitNode { - def apply[K, V, T, A: Semigroup](p: Predicate[V], k: K, lc: Node[K, V, T, A], rc: Node[K, V, T, A]): SplitNode[K, V, T, A] = - SplitNode(p, k, lc, rc, Semigroup.plus(lc.annotation, rc.annotation)) + def apply[K, V, T, A: Semigroup](k: K, p: Predicate[V], lc: Node[K, V, T, A], rc: Node[K, V, T, A]): SplitNode[K, V, T, A] = + SplitNode(k, p, lc, rc, Semigroup.plus(lc.annotation, rc.annotation)) } case class LeafNode[K, V, T, A](index: Int, target: T, annotation: A) extends Node[K, V, T, A] @@ -48,9 +48,9 @@ case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { private def mapSplits[K0, V0](f: (K, Predicate[V]) => (K0, Predicate[V0])): AnnotatedTree[K0, V0, T, A] = { def recur(node: Node[K, V, T, A]): Node[K0, V0, T, A] = { node match { - case SplitNode(p, k, lc, rc, _) => + case SplitNode(k, p, lc, rc, _) => val (key, pred) = f(k, p) - SplitNode(pred, key, recur(lc), recur(rc)) + SplitNode(key, pred, recur(lc), recur(rc)) case LeafNode(index, target, a) => LeafNode(index, target, a) } @@ -64,8 +64,8 @@ case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { */ private def mapLeaves[T0, A0: Semigroup](f: (T, A) => (T0, A0)): AnnotatedTree[K, V, T0, A0] = { def recur(node: Node[K, V, T, A]): Node[K, V, T0, A0] = node match { - case SplitNode(p, k, lc, rc, _) => - SplitNode(p, k, recur(lc), recur(rc)) + case SplitNode(k, p, lc, rc, _) => + SplitNode(k, p, recur(lc), recur(rc)) case LeafNode(index, target, a) => val (target1, a1) = f(target, a) LeafNode(index, target1, a1) @@ -117,7 +117,7 @@ case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { */ def leafAt(leafIndex: Int, start: Node[K, V, T, A]): Option[LeafNode[K, V, T, A]] = start match { - case SplitNode(p, k, lc, rc, _) => + case SplitNode(k, p, lc, rc, _) => leafAt(leafIndex, lc) match { case None => leafAt(leafIndex, rc) case some => some @@ -159,7 +159,7 @@ case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { // Bounce at the bottom and start back up the tree. (validationData, leaf) - case SplitNode(p, k, lc0, rc0, _) => + case SplitNode(k, p, lc0, rc0, _) => val (vData, lc1) = pruneNode(validationData, lc0, voter, error) val (newData, rc1) = pruneNode(vData, rc0, voter, error) @@ -169,12 +169,12 @@ case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { case lc2 @ LeafNode(_, _, _) => rc1 match { case rc2 @ LeafNode(_, _, _) => - pruneLevel(SplitNode(p, k, lc2, rc2), lc2, rc2, newData, voter, error) + pruneLevel(SplitNode(k, p, lc2, rc2), lc2, rc2, newData, voter, error) case _ => - (newData, SplitNode(p, k, lc1, rc1)) + (newData, SplitNode(k, p, lc1, rc1)) } case _ => - (newData, SplitNode(p, k, lc1, rc1)) + (newData, SplitNode(k, p, lc1, rc1)) } } } @@ -249,10 +249,10 @@ case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { case None => (nextIndex + 1, LeafNode(nextIndex, target, annotation)) case Some(split) => split.renumber(nextIndex) } - case SplitNode(p, k, lc0, rc0, _) => + case SplitNode(k, p, lc0, rc0, _) => val (n1, lc) = growFrom(nextIndex, lc0) val (n2, rc) = growFrom(n1, rc0) - (n2, SplitNode(p, k, lc, rc)) + (n2, SplitNode(k, p, lc, rc)) } } @@ -270,8 +270,8 @@ case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { start match { case LeafNode(index, _, _) => fn(index).getOrElse(start) - case SplitNode(p, k, lc0, rc0, _) => - SplitNode(p, k, updateFrom(lc0), updateFrom(rc0)) + case SplitNode(k, p, lc0, rc0, _) => + SplitNode(k, p, updateFrom(lc0), updateFrom(rc0)) } } diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala index da6799f..ac27458 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala @@ -51,7 +51,7 @@ case class Split[V, T](predicate: Predicate[V], leftDistribution: T, rightDistri * renumbered if this node is put into a larger tree. */ def createSplitNode[K](feature: K): SplitNode[K, V, T, Unit] = - SplitNode(predicate, feature, LeafNode(0, leftDistribution), LeafNode(1, rightDistribution)) + SplitNode(feature, predicate, LeafNode(0, leftDistribution), LeafNode(1, rightDistribution)) } diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Injections.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Injections.scala index f3632e7..82fbf79 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Injections.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Injections.scala @@ -124,21 +124,20 @@ object JsonInjections { obj.put("distribution", toJsonNode(target)) obj - case SplitNode(p, k, lc, rc, _) => + case SplitNode(k, p, lc, rc, _) => val obj = JsonNodeFactory.instance.objectNode - obj.put("predicate", toJsonNode(p)(predicateJsonNodeInjection)) obj.put("key", toJsonNode(k)) + obj.put("predicate", toJsonNode(p)(predicateJsonNodeInjection)) obj.put("left", toJsonNode(lc)(nodeJsonNodeInjection)) obj.put("right", toJsonNode(rc)(nodeJsonNodeInjection)) obj } - def tryChild(node: JsonNode, property: String): Try[JsonNode] = - Try { - val child = node.get(property) - assert(child != null, property + " != null") - child - } + def tryChild(node: JsonNode, property: String): Try[JsonNode] = { + val child = node.get(property) + if (child == null) Failure(new IllegalArgumentException(property + " != null")) + else Success(child) + } def tryLoad[T: JsonNodeInjection](node: JsonNode, property: String): Try[T] = { val child = node.get(property) @@ -154,11 +153,11 @@ object JsonInjections { } yield LeafNode(index, target) } else { for { - p <- tryLoad[Predicate[V]](n, "predicate") k <- tryLoad[K](n, "key") + p <- tryLoad[Predicate[V]](n, "predicate") left <- tryLoad[Node[K, V, T, Unit]](n, "left")(nodeJsonNodeInjection) right <- tryLoad[Node[K, V, T, Unit]](n, "right")(nodeJsonNodeInjection) - } yield SplitNode(p, k, left, right) + } yield SplitNode(k, p, left, right) } } diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala index ec15c8a..271de5d 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala @@ -37,7 +37,7 @@ object Tree { val target = Monoid.sum(newInstances.map(_.target)) expand(times - 1, treeIndex, LeafNode(0, target), splitter, evaluator, stopper, sampler, newInstances) } - Some(SplitNode(pred, splitFeature, ex(left), ex(right))) + Some(SplitNode(splitFeature, pred, ex(left), ex(right))) } }.getOrElse(leaf) } else { diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala index 80bb433..7973f0b 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala @@ -151,7 +151,7 @@ case class DepthFirstTreeTraversal[K, V, T, A](order: (Random, List[Node[K, V, T case (split @ SplitNode(_, _, _, _, _)) :: rest => val newStack = split.evaluate(row) match { case Nil => rest - //case node :: Nil => node :: rest + case node :: Nil => node :: rest case candidates => order(rng, candidates) ::: rest } loop(newStack) diff --git a/brushfire-core/src/test/scala/com/stripe/brushfire/TreeGenerators.scala b/brushfire-core/src/test/scala/com/stripe/brushfire/TreeGenerators.scala index 4a90637..5a7afb8 100644 --- a/brushfire-core/src/test/scala/com/stripe/brushfire/TreeGenerators.scala +++ b/brushfire-core/src/test/scala/com/stripe/brushfire/TreeGenerators.scala @@ -35,11 +35,11 @@ object TreeGenerators { genT.map(LeafNode(index, _)) def genSplit(index: Int, maxDepth: Int) = for { - pred <- genPredicate(genV) key <- genK + pred <- genPredicate(genV) left <- genNode(index, maxDepth - 1) right <- genNode(index | (1 << (maxDepth - 1)), maxDepth - 1) - } yield SplitNode(pred, key, left, right) + } yield SplitNode(key, pred, left, right) def genNode(index: Int, maxDepth: Int): Gen[Node[K, V, T, Unit]] = if (maxDepth > 1) { diff --git a/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala b/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala index 93b6fc5..a988918 100644 --- a/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala +++ b/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala @@ -14,13 +14,13 @@ class TreeTraversalSpec extends WordSpec with Matchers with Checkers { import TreeGenerators._ def split[T, A: Semigroup](key: String, pred: Predicate[Double], left: Node[String, Double, T, A], right: Node[String, Double, T, A]): SplitNode[String, Double, T, A] = - SplitNode(pred, key, left, right) + SplitNode(key, pred, left, right) "depthFirst" should { "always choose the left side of a split in a binary tree" in { val simpleTreeGen = genBinaryTree(arbitrary[String], arbitrary[Double], arbitrary[Map[String, Long]], 2) .filter(_.root match { - case SplitNode(p, _, _, _, _) => + case SplitNode(_, p, _, _, _) => p match { case IsPresent(_) => false case _ => true @@ -80,7 +80,7 @@ class TreeTraversalSpec extends WordSpec with Matchers with Checkers { def collectLeafs[K, V, T, A](node: Node[K, V, T, A]): Set[LeafNode[K, V, T, A]] = node match { - case SplitNode(p, _, lc, rc, _) => + case SplitNode(_, p, lc, rc, _) => (if (p.run(None)) List(lc, rc) else Nil).flatMap(collectLeafs).toSet case leaf @ LeafNode(_, _, _) => Set(leaf) From f9ff2a11ce8a9419c6c8ec5805d61426f15184f3 Mon Sep 17 00:00:00 2001 From: Erik Osheim Date: Tue, 15 Dec 2015 17:07:15 -0800 Subject: [PATCH 04/11] Initial Bonsai integration. The tests pass. This is definitely not in the best form, but it's a start! I need to do some benchmarking to try to figure out what kind of performance impact this change has. --- brushfire-core/build.sbt | 1 + .../com/stripe/brushfire/AnnotatedTree.scala | 51 +++++++- .../scala/com/stripe/brushfire/Tree.scala | 4 + .../com/stripe/brushfire/TreeTraversal.scala | 122 +++++++++++------- .../scala/com/stripe/brushfire/Types.scala | 19 +++ .../scala/com/stripe/brushfire/Voter.scala | 5 +- .../com/stripe/brushfire/local/Example.scala | 13 +- .../com/stripe/brushfire/local/Trainer.scala | 31 +++-- .../scala/com/stripe/brushfire/package.scala | 1 - .../stripe/brushfire/TreeTraversalSpec.scala | 88 +++++++------ .../stripe/brushfire/scalding/Trainer.scala | 10 +- project/Deps.scala | 2 + 12 files changed, 228 insertions(+), 119 deletions(-) create mode 100644 brushfire-core/src/main/scala/com/stripe/brushfire/Types.scala diff --git a/brushfire-core/build.sbt b/brushfire-core/build.sbt index c507e6f..29d15d6 100644 --- a/brushfire-core/build.sbt +++ b/brushfire-core/build.sbt @@ -5,6 +5,7 @@ libraryDependencies ++= { Seq( algebirdCore, bijectionJson, + bonsai, chillBijection, jacksonMapper, jacksonXC, diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala index b401ea3..9b4a4cc 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala @@ -1,10 +1,14 @@ package com.stripe.brushfire import com.twitter.algebird._ +import com.stripe.bonsai.{ FullBinaryTree, FullBinaryTreeOps } import java.lang.Math.{abs, max} +import Types._ + sealed abstract class Node[K, V, T, A] { + def annotation: A def renumber(nextId: Int): (Int, Node[K, V, T, A]) = @@ -16,6 +20,12 @@ sealed abstract class Node[K, V, T, A] { case LeafNode(_, target, annotation) => (nextId + 1, LeafNode(nextId, target, annotation)) } + + def fold[B](f: (Node[K, V, T, A], Node[K, V, T, A], (K, Predicate[V], A)) => B, g: Tuple3[Int, T, A] => B): B = + this match { + case SplitNode(k, p, lc, rc, a) => f(lc, rc, (k, p, a)) + case LeafNode(index, target, a) => g((index, target, a)) + } } case class SplitNode[K, V, T, A](key: K, predicate: Predicate[V], leftChild: Node[K, V, T, A], rightChild: Node[K, V, T, A], annotation: A) extends Node[K, V, T, A] { @@ -25,6 +35,8 @@ case class SplitNode[K, V, T, A](key: K, predicate: Predicate[V], leftChild: Nod case Some(false) => rightChild :: Nil case None => leftChild :: rightChild :: Nil } + + def splitLabel: (K, Predicate[V], A) = (key, predicate, annotation) } object SplitNode { @@ -32,7 +44,9 @@ object SplitNode { SplitNode(k, p, lc, rc, Semigroup.plus(lc.annotation, rc.annotation)) } -case class LeafNode[K, V, T, A](index: Int, target: T, annotation: A) extends Node[K, V, T, A] +case class LeafNode[K, V, T, A](index: Int, target: T, annotation: A) extends Node[K, V, T, A] { + def leafLabel: (Int, T, A) = (index, target, annotation) +} object LeafNode { def apply[K, V, T, A: Monoid](index: Int, target: T): LeafNode[K, V, T, A] = @@ -41,6 +55,8 @@ object LeafNode { case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { + import AnnotatedTree.AnnotatedTreeTraversal + /** * Transform the splits of a tree (the predicates and keys of the * nodes) while leaving everything else alone. @@ -227,14 +243,14 @@ case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { } } - def leafFor(row: Map[K, V], id: Option[String] = None)(implicit traversal: TreeTraversal[K, V, T, A]): Option[LeafNode[K, V, T, A]] = - traversal.find(this, row, id).headOption + def leafFor(row: Map[K, V], id: Option[String] = None)(implicit traversal: AnnotatedTreeTraversal[K, V, T, A]): Option[(Int, T, A)] = + traversal.search(this, row, id).headOption - def leafIndexFor(row: Map[K, V], id: Option[String] = None)(implicit traversal: TreeTraversal[K, V, T, A]): Option[Int] = - leafFor(row, id).map(_.index) + def leafIndexFor(row: Map[K, V], id: Option[String] = None)(implicit traversal: AnnotatedTreeTraversal[K, V, T, A]): Option[Int] = + leafFor(row, id).map(_._1) - def targetFor(row: Map[K, V], id: Option[String] = None)(implicit traversal: TreeTraversal[K, V, T, A], semigroup: Semigroup[T]): Option[T] = - semigroup.sumOption(traversal.find(this, row, id).map(_.target)) + def targetFor(row: Map[K, V], id: Option[String] = None)(implicit traversal: AnnotatedTreeTraversal[K, V, T, A], semigroup: Semigroup[T]): Option[T] = + semigroup.sumOption(leafFor(row, id).map(_._2)) /** * For each leaf, this may convert the leaf to a [[SplitNode]] whose children @@ -285,4 +301,25 @@ case class AnnotatedTree[K, V, T, A: Semigroup](root: Node[K, V, T, A]) { */ def renumberLeaves: AnnotatedTree[K, V, T, A] = AnnotatedTree(root.renumber(0)._2) + + def compress: FullBinaryTree[(K, Predicate[V], A), (Int, T, A)] = + FullBinaryTree(this) +} + +object AnnotatedTree { + + type AnnotatedTreeTraversal[K, V, T, A] = TreeTraversal[AnnotatedTree[K, V, T, A], K, V, T, A] + + implicit def fullBinaryTreeOpsForAnnotatedTree[K, V, T, A]: FullBinaryTreeOps[AnnotatedTree[K, V, T, A], (K, Predicate[V], A), (Int, T, A)] = + new FullBinaryTreeOpsForAnnotatedTree[K, V, T, A] +} + +class FullBinaryTreeOpsForAnnotatedTree[K, V, T, A] extends FullBinaryTreeOps[AnnotatedTree[K, V, T, A], (K, Predicate[V], A), (Int, T, A)] { + type Node = com.stripe.brushfire.Node[K, V, T, A] + + def root(t: AnnotatedTree[K, V, T, A]): Option[Node] = Some(t.root) + + def foldNode[B](node: Node)( + f: (Node, Node, (K, Predicate[V], A)) => B, + g: Tuple3[Int, T, A] => B): B = node.fold(f, g) } diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala index 271de5d..65bffe6 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala @@ -1,5 +1,6 @@ package com.stripe.brushfire +import com.stripe.bonsai.FullBinaryTreeOps import com.twitter.algebird._ // type Tree[K, V, T] = AnnotatedTree[K, V, T, Unit] @@ -44,4 +45,7 @@ object Tree { leaf } } + + implicit def fullBinaryTreeOpsForTree[K, V, T]: FullBinaryTreeOps[Tree[K, V, T], (K, Predicate[V], Unit), (Int, T, Unit)] = + new FullBinaryTreeOpsForAnnotatedTree[K, V, T, Unit] } diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala index 7973f0b..9fb2371 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala @@ -7,6 +7,9 @@ import scala.util.Random import scala.util.hashing.MurmurHash3 import com.twitter.algebird._ +import com.stripe.bonsai._ + +import Types._ /** * A `TreeTraversal` provides a way to find all of the leaves in a tree that @@ -15,12 +18,14 @@ import com.twitter.algebird._ * features). A tree traversal chooses which paths to go down (which may be all * of them) and the order in which they are traversed. */ -trait TreeTraversal[K, V, T, A] { +trait TreeTraversal[Tree, K, V, T, A] { + + val treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]] /** * Limit the maximum number of leaves returned from `find` to `n`. */ - def limitTo(n: Int): TreeTraversal[K, V, T, A] = + def limitTo(n: Int): TreeTraversal[Tree, K, V, T, A] = LimitedTreeTraversal(this, n) /** @@ -34,8 +39,11 @@ trait TreeTraversal[K, V, T, A] { * @param row the row/instance we're trying to match with a leaf node * @return the leaf nodes that best match the row */ - def find(tree: AnnotatedTree[K, V, T, A], row: Map[K, V], id: Option[String]): Stream[LeafNode[K, V, T, A]] = - find(tree.root, row, id) + def search(tree: Tree, row: Map[K, V], id: Option[String]): Stream[LeafLabel[T, A]] = + treeOps.root(tree) match { + case Some(root) => searchNode(root, row, id) + case None => Stream.empty + } /** * Find the [[LeafNode]]s that best fit `row` in the tree. Generally, the @@ -48,20 +56,47 @@ trait TreeTraversal[K, V, T, A] { * @param row the row/instance we're trying to match with a leaf node * @return the leaf nodes that match the row */ - def find(node: Node[K, V, T, A], row: Map[K, V], id: Option[String]): Stream[LeafNode[K, V, T, A]] + def searchNode(node: treeOps.Node, row: Map[K, V], id: Option[String]): Stream[LeafLabel[T, A]] +} + +trait Reorder[A] { + def apply[N](r: Random, ns: List[N], f: N => A): List[N] +} + +object Reorder { + def unchanged[A]: Reorder[A] = + new Reorder[A] { + def apply[N](r: Random, ns: List[N], f: N => A): List[N] = ns + } + + def shuffled[A]: Reorder[A] = + new Reorder[A] { + def apply[N](r: Random, ns: List[N], f: N => A): List[N] = r.shuffle(ns) + } + + def weightedDepthFirst[A](implicit ev: Ordering[A]): Reorder[A] = + new Reorder[A] { + def apply[N](r: Random, ns: List[N], f: N => A): List[N] = ns.sortBy(f)(ev.reverse) + } + + def probabilisticWeightedDepthFirst[A](conversion: A => Double): Reorder[A] = + new Reorder[A] { + def apply[N](r: Random, ns: List[N], f: N => A): List[N] = + TreeTraversal.probabilisticShuffle(r, ns)(n => conversion(f(n))) + } } object TreeTraversal { - def find[K, V, T, A](tree: AnnotatedTree[K, V, T, A], row: Map[K, V], id: Option[String] = None)(implicit traversal: TreeTraversal[K, V, T, A]): Stream[LeafNode[K, V, T, A]] = - traversal.find(tree, row, id) + def search[Tree, K, V, T, A](tree: Tree, row: Map[K, V], id: Option[String] = None)(implicit ev: TreeTraversal[Tree, K, V, T, A]): Stream[LeafLabel[T, A]] = + ev.search(tree, row, id) /** * Performs a depth-first traversal of the tree, returning all matching leaf * nodes. */ - implicit def depthFirst[K, V, T, A]: TreeTraversal[K, V, T, A] = - DepthFirstTreeTraversal((_, xs) => xs) + implicit def depthFirst[Tree, K, V, T, A](implicit treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]]): TreeTraversal[Tree, K, V, T, A] = + DepthFirstTreeTraversal(Reorder.unchanged) /** * A depth first search for matching leaves, randomly choosing the order of @@ -70,8 +105,8 @@ object TreeTraversal { * descend from that node are traversed before moving onto the node's * sibling. */ - def randomDepthFirst[K, V, T, A]: TreeTraversal[K, V, T, A] = - DepthFirstTreeTraversal(_ shuffle _) + def randomDepthFirst[Tree, K, V, T, A](implicit treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]]): TreeTraversal[Tree, K, V, T, A] = + DepthFirstTreeTraversal(Reorder.shuffled) /** * A depth-first search for matching leaves, where the candidate child nodes @@ -79,8 +114,8 @@ object TreeTraversal { * annotations. This means that if we have multiple valid candidate children, * we will traverse the child with the largest annotation first. */ - def weightedDepthFirst[K, V, T, A: Ordering]: TreeTraversal[K, V, T, A] = - DepthFirstTreeTraversal((_, xs) => xs.sortBy(_.annotation)(Ordering[A].reverse)) + def weightedDepthFirst[Tree, K, V, T, A: Ordering](implicit treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]]): TreeTraversal[Tree, K, V, T, A] = + DepthFirstTreeTraversal(Reorder.weightedDepthFirst) /** * A depth-first search for matching leaves, where the candidate child leaves @@ -93,8 +128,8 @@ object TreeTraversal { * proportional to its probability of being sampled, relative to all the * other elements still in the set. */ - def probabilisticWeightedDepthFirst[K, V, T, A](implicit conversion: A => Double): TreeTraversal[K, V, T, A] = - DepthFirstTreeTraversal(probabilisticShuffle(_, _)(_.annotation)) + def probabilisticWeightedDepthFirst[Tree, K, V, T, A](implicit treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]], conversion: A => Double): TreeTraversal[Tree, K, V, T, A] = + DepthFirstTreeTraversal(Reorder.probabilisticWeightedDepthFirst(conversion)) // Given a weighted set `xs`, this creates an ordered list of all the elements // in `xs` by sampling without replacement from the set, but giving each @@ -127,44 +162,41 @@ object TreeTraversal { } def mkRandom(id: String): Random = - new Random(MurmurHash3.stringHash(id)) + new Random(MurmurHash3.stringHash(id)) } -case class DepthFirstTreeTraversal[K, V, T, A](order: (Random, List[Node[K, V, T, A]]) => List[Node[K, V, T, A]]) - extends TreeTraversal[K, V, T, A] { +case class DepthFirstTreeTraversal[Tree, K, V, T, A](reorder: Reorder[A])(implicit val treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]]) extends TreeTraversal[Tree, K, V, T, A] { + + def searchNode(start: treeOps.Node, row: Map[K, V], id: Option[String]): Stream[LeafLabel[T, A]] = { - def find(start: Node[K, V, T, A], row: Map[K, V], id: Option[String]): Stream[LeafNode[K, V, T, A]] = { // Lazy to avoid creation in the fast case. lazy val rng: Random = id.fold[Random](Random)(TreeTraversal.mkRandom) - // A little indirection makes scalac happy to eliminate some tailcalls in loop. - def loop0(stack: List[Node[K, V, T, A]]): Stream[LeafNode[K, V, T, A]] = - loop(stack) - - @tailrec - def loop(stack: List[Node[K, V, T, A]]): Stream[LeafNode[K, V, T, A]] = - stack match { - case Nil => - Stream.empty - case (leaf @ LeafNode(_, _, _)) :: rest => - leaf #:: loop0(rest) - case (split @ SplitNode(_, _, _, _, _)) :: rest => - val newStack = split.evaluate(row) match { - case Nil => rest - case node :: Nil => node :: rest - case candidates => order(rng, candidates) ::: rest - } - loop(newStack) - } - - loop0(start :: Nil) + val Empty: Stream[LeafLabel[T, A]] = Stream.empty + + val getAnnotation: treeOps.Node => A = + n => treeOps.foldNode(n)((_, _, bl) => bl._3, _._3) + + def loop(node: treeOps.Node): Stream[LeafLabel[T, A]] = + treeOps.foldNode(node)({ case (lc, rc, bl) => + val (k, p, a) = bl + p(row.get(k)) match { + case Some(true) => + loop(lc) + case Some(false) => + loop(rc) + case None => + val cs = reorder(rng, lc :: rc :: Nil, getAnnotation).toStream + cs.flatMap(loop) + } + }, ll => ll #:: Empty) + loop(start) } } -case class LimitedTreeTraversal[K, V, T, A](traversal: TreeTraversal[K, V, T, A], limit: Int) - extends TreeTraversal[K, V, T, A] { +case class LimitedTreeTraversal[Tree, K, V, T, A](traversal: TreeTraversal[Tree, K, V, T, A], limit: Int) extends TreeTraversal[Tree, K, V, T, A] { require(limit > 0, "limit must be greater than 0") - - def find(node: Node[K, V, T, A], row: Map[K, V], id: Option[String]): Stream[LeafNode[K, V, T, A]] = - traversal.find(node, row, id).take(limit) + val treeOps: traversal.treeOps.type = traversal.treeOps + def searchNode(node: treeOps.Node, row: Map[K, V], id: Option[String]): Stream[LeafLabel[T, A]] = + traversal.searchNode(node, row, id).take(limit) } diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Types.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Types.scala new file mode 100644 index 0000000..c021ce5 --- /dev/null +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Types.scala @@ -0,0 +1,19 @@ +package com.stripe.brushfire + +import scala.annotation.tailrec +import scala.collection.SortedMap +import scala.math.Ordering +import scala.util.Random +import scala.util.hashing.MurmurHash3 + +import com.twitter.algebird._ +import com.stripe.bonsai.FullBinaryTreeOps + +object Types { + type BranchLabel[K, V, A] = (K, Predicate[V], A) + type LeafLabel[T, A] = (Int, T, A) + + //type Ops[Tree] = FullBinaryTreeOps[Tree] + //type Aux[Tree, K, V, T, A] = FullBinaryTreeOps.Aux[Tree, BranchLabel[K, V, A], LeafLabel[T, A]] + //type AnnotatedTreeTraversal[K, V, T, A] = TreeTraversal[AnnotatedTree[K, V, T, A], K, V, T, A] +} diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Voter.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Voter.scala index ef6df1a..36e3453 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Voter.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Voter.scala @@ -2,6 +2,9 @@ package com.stripe.brushfire import com.twitter.algebird._ +import Types._ +import AnnotatedTree.AnnotatedTreeTraversal + /** Combines multiple targets into a single prediction **/ trait Voter[T, P] { self => @@ -26,7 +29,7 @@ trait Voter[T, P] { self => f(self.combine(targets)) } - final def predict[K, V, A](trees: Iterable[AnnotatedTree[K, V, T, A]], row: Map[K, V])(implicit traversal: TreeTraversal[K, V, T, A], semigroup: Semigroup[T]): P = + final def predict[K, V, A](trees: Iterable[AnnotatedTree[K, V, T, A]], row: Map[K, V])(implicit traversal: AnnotatedTreeTraversal[K, V, T, A], semigroup: Semigroup[T]): P = combine(trees.flatMap(_.targetFor(row))) } diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/local/Example.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/local/Example.scala index 2a8ff13..2c856a6 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/local/Example.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/local/Example.scala @@ -1,33 +1,36 @@ package com.stripe.brushfire.local +import com.stripe.bonsai._ import com.stripe.brushfire._ import com.twitter.algebird._ +import AnnotatedTree.{AnnotatedTreeTraversal, fullBinaryTreeOpsForAnnotatedTree} + object Example extends Defaults { def main(args: Array[String]) { val cols = args.toList - + val trainingData = io.Source.stdin.getLines.map { line => val parts = line.split(",").reverse.toList val label = parts.head val values = parts.tail.map { s => s.toDouble } Instance(line, 0L, Map(cols.zip(values): _*), Map(label -> 1L)) }.toList - + var trainer = Trainer(trainingData, KFoldSampler(4)) .updateTargets - + println(trainer.validate(AccuracyError())) println(trainer.validate(BrierScoreError())) - + 1.to(10).foreach { i => trainer = trainer.expand(1) println(trainer.validate(AccuracyError())) println(trainer.validate(BrierScoreError())) } - + implicit val ord = Ordering.by[AveragedValue, Double] { _.value } trainer = trainer.prune(BrierScoreError()) println(trainer.validate(AccuracyError())) diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/local/Trainer.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/local/Trainer.scala index 5cf23b9..41e3033 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/local/Trainer.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/local/Trainer.scala @@ -1,14 +1,18 @@ -package com.stripe.brushfire.local +package com.stripe.brushfire +package local import com.stripe.brushfire._ import com.twitter.algebird._ +import Types._ +import AnnotatedTree.AnnotatedTreeTraversal + case class Trainer[K: Ordering, V, T: Monoid]( trainingData: Iterable[Instance[K, V, T]], sampler: Sampler[K], - trees: List[Tree[K, V, T]]) { + trees: List[Tree[K, V, T]])(implicit traversal: AnnotatedTreeTraversal[K, V, T, Unit]) { - private def updateTrees(fn: (Tree[K, V, T], Int, Map[LeafNode[K, V, T, Unit], Iterable[Instance[K, V, T]]]) => Tree[K, V, T]): Trainer[K, V, T] = { + private def updateTrees(fn: (Tree[K, V, T], Int, Map[(Int, T, Unit), Iterable[Instance[K, V, T]]]) => Tree[K, V, T]): Trainer[K, V, T] = { val newTrees = trees.zipWithIndex.par.map { case (tree, index) => val byLeaf = @@ -16,7 +20,7 @@ case class Trainer[K: Ordering, V, T: Monoid]( val repeats = sampler.timesInTrainingSet(instance.id, instance.timestamp, index) if (repeats > 0) { tree.leafFor(instance.features).map { leaf => - 1.to(repeats).toList.map { i => (instance, leaf) } + (1 to repeats).toList.map { i => (instance, leaf) } }.getOrElse(Nil) } else { Nil @@ -28,12 +32,13 @@ case class Trainer[K: Ordering, V, T: Monoid]( copy(trees = newTrees.toList) } - private def updateLeaves(fn: (Int, LeafNode[K, V, T, Unit], Iterable[Instance[K, V, T]]) => Node[K, V, T, Unit]): Trainer[K, V, T] = { + private def updateLeaves(fn: (Int, (Int, T, Unit), Iterable[Instance[K, V, T]]) => Node[K, V, T, Unit]): Trainer[K, V, T] = { updateTrees { case (tree, treeIndex, byLeaf) => val newNodes = byLeaf.map { case (leaf, instances) => - leaf.index -> fn(treeIndex, leaf, instances) + val (index, _, _) = leaf + index -> fn(treeIndex, leaf, instances) } tree.updateByLeafIndex(newNodes.lift) @@ -42,23 +47,23 @@ case class Trainer[K: Ordering, V, T: Monoid]( def updateTargets: Trainer[K, V, T] = updateLeaves { - case (treeIndex, leaf, instances) => + case (treeIndex, (index, _, annotation), instances) => val target = implicitly[Monoid[T]].sum(instances.map { _.target }) - leaf.copy(target = target) + LeafNode(index, target, annotation) } def expand(times: Int)(implicit splitter: Splitter[V, T], evaluator: Evaluator[V, T], stopper: Stopper[T]): Trainer[K, V, T] = updateLeaves { - case (treeIndex, leaf, instances) => - Tree.expand(times, treeIndex, leaf, splitter, evaluator, stopper, sampler, instances) + case (treeIndex, (index, target, annotation), instances) => + Tree.expand(times, treeIndex, LeafNode(index, target, annotation), splitter, evaluator, stopper, sampler, instances) } def prune[P, E](error: Error[T, P, E])(implicit voter: Voter[T, P], ord: Ordering[E]): Trainer[K, V, T] = updateTrees { case (tree, treeIndex, byLeaf) => val byLeafIndex = byLeaf.map { - case (l, instances) => - l.index -> implicitly[Monoid[T]].sum(instances.map { _.target }) + case ((index, _, _), instances) => + index -> implicitly[Monoid[T]].sum(instances.map { _.target }) } tree.prune(byLeafIndex, voter, error) } @@ -81,7 +86,7 @@ case class Trainer[K: Ordering, V, T: Monoid]( } object Trainer { - def apply[K: Ordering, V, T: Monoid](trainingData: Iterable[Instance[K, V, T]], sampler: Sampler[K]): Trainer[K, V, T] = { + def apply[K: Ordering, V, T: Monoid](trainingData: Iterable[Instance[K, V, T]], sampler: Sampler[K])(implicit traversal: AnnotatedTreeTraversal[K, V, T, Unit]): Trainer[K, V, T] = { val empty = 0.until(sampler.numTrees).toList.map { i => Tree.singleton[K, V, T](Monoid.zero) } Trainer(trainingData, sampler, empty) } diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/package.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/package.scala index 9a452c2..35375fa 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/package.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/package.scala @@ -3,4 +3,3 @@ package com.stripe package object brushfire { type Tree[K, V, T] = AnnotatedTree[K, V, T, Unit] } - diff --git a/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala b/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala index a988918..42414f9 100644 --- a/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala +++ b/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala @@ -4,7 +4,7 @@ import scala.util.Random import com.twitter.algebird._ -import org.scalacheck.Prop +import org.scalacheck.{ Gen, Prop } import org.scalacheck.Prop._ import org.scalacheck.Arbitrary.arbitrary import org.scalatest.{ WordSpec, Matchers } @@ -13,34 +13,42 @@ import org.scalatest.prop.Checkers class TreeTraversalSpec extends WordSpec with Matchers with Checkers { import TreeGenerators._ + type TreeSDM[A] = AnnotatedTree[String, Double, Map[String, Long], A] + type TreeSDD[A] = AnnotatedTree[String, Double, Double, A] + def split[T, A: Semigroup](key: String, pred: Predicate[Double], left: Node[String, Double, T, A], right: Node[String, Double, T, A]): SplitNode[String, Double, T, A] = SplitNode(key, pred, left, right) "depthFirst" should { "always choose the left side of a split in a binary tree" in { - val simpleTreeGen = genBinaryTree(arbitrary[String], arbitrary[Double], arbitrary[Map[String, Long]], 2) - .filter(_.root match { - case SplitNode(_, p, _, _, _) => - p match { - case IsPresent(_) => false - case _ => true - } - case _ => - false - }) - check(Prop.forAll(simpleTreeGen) { tree => + val simpleTreeGen: Gen[TreeSDM[Unit]] = + genBinaryTree(arbitrary[String], arbitrary[Double], arbitrary[Map[String, Long]], 2) + .filter(_.root match { + case SplitNode(_, p, _, _, _) => + p match { + case IsPresent(_) => false + case _ => true + } + case _ => + false + }) + check(Prop.forAll(simpleTreeGen) { (tree: TreeSDM[Unit]) => (tree.root: @unchecked) match { - case SplitNode(_, _, lc, rc, _) => - TreeTraversal.depthFirst.find(tree, Map.empty[String, Double], None).headOption == Some(lc) + case SplitNode(_, _, LeafNode(li, lt, la), LeafNode(ri, rt, ra), _) => + TreeTraversal.depthFirst[TreeSDM[Unit], String, Double, Map[String, Long], Unit] + .search(tree, Map.empty[String, Double], None).headOption == Some((li, lt, la)) } }) } "traverse in order" in { - check { (tree: Tree[String, Double, Map[String, Long]]) => - TreeTraversal.depthFirst - .find(tree, Map.empty[String, Double], None) - .map(_.index) + check { (tree: TreeSDM[Unit]) => + val leaves: Stream[(Int, Map[String, Long], Unit)] = + TreeTraversal.depthFirst[TreeSDM[Unit], String, Double, Map[String, Long], Unit] + .search(tree, Map.empty[String, Double], None) + + leaves + .map { case (index, _, _) => index } .sliding(2) .forall { case Seq(_) => true @@ -51,14 +59,14 @@ class TreeTraversalSpec extends WordSpec with Matchers with Checkers { } "weightedDepthFirst" should { - implicit val traversal = TreeTraversal.weightedDepthFirst[String, Double, Double, Int] + implicit val traversal = TreeTraversal.weightedDepthFirst[TreeSDD[Int], String, Double, Double, Int] "choose the heaviest node in a split" in { val split1 = split("f1", LessThan(0D), LeafNode(0, -1D, 12), LeafNode(0, 1D, 9)) - AnnotatedTree(split1).leafFor(Map.empty[String, Double]) shouldBe Some(split1.leftChild) + AnnotatedTree(split1).leafFor(Map.empty[String, Double]) shouldBe Some((0, -1D, 12)) val split2 = split("f2", LessThan(0D), LeafNode(0, -1D, -3), LeafNode(0, 1D, 9)) - AnnotatedTree(split2).leafFor(Map.empty[String, Double]) shouldBe Some(split2.rightChild) + AnnotatedTree(split2).leafFor(Map.empty[String, Double]) shouldBe Some((0, 1D, 9)) } "choose heaviest path in a tree" in { @@ -71,10 +79,10 @@ class TreeTraversalSpec extends WordSpec with Matchers with Checkers { LeafNode(2, -100D, 77), LeafNode(3, 333D, 19)))) - tree.leafFor(Map.empty) shouldBe Some(LeafNode(2, -100D, 77)) - tree.leafFor(Map("f1" -> -1D)) shouldBe Some(LeafNode(1, 1D, 53)) - tree.leafFor(Map("f2" -> 1D)) shouldBe Some(LeafNode(3, 333D, 19)) - tree.leafFor(Map("f1" -> -1D, "f2" -> -1D)) shouldBe Some(LeafNode(0, -1D, 33)) + tree.leafFor(Map.empty) shouldBe Some((2, -100D, 77)) + tree.leafFor(Map("f1" -> -1D)) shouldBe Some((1, 1D, 53)) + tree.leafFor(Map("f2" -> 1D)) shouldBe Some((3, 333D, 19)) + tree.leafFor(Map("f1" -> -1D, "f2" -> -1D)) shouldBe Some((0, -1D, 33)) } } @@ -93,24 +101,24 @@ class TreeTraversalSpec extends WordSpec with Matchers with Checkers { // seed = "c", 0.28329487121994257, 0.6209168529111834, 0.896329641267321, 0.4295933416724764, 0.28764827423012496 // seed = "z", 0.3982796920464978, 0.09452393725015651, 0.2645674766831084, 0.45670292528849277, 0.4919659310853626 - val traversal = TreeTraversal.probabilisticWeightedDepthFirst[String, Double, Double, Int] + val traversal = TreeTraversal.probabilisticWeightedDepthFirst[TreeSDD[Int], String, Double, Double, Int] "choose predictable weighted probabilistic node in a split" in { val split1 = split("f1", LessThan(0D), LeafNode(0, -1D, 2703), LeafNode(1, 1D, 10000 - 2703)) - traversal.find(AnnotatedTree(split1), Map.empty[String, Double], Some("b")) shouldBe - Seq(LeafNode(1, 1D, 10000 - 2703), LeafNode(0, -1D, 2703)) + traversal.search(AnnotatedTree(split1), Map.empty[String, Double], Some("b")) shouldBe + Seq((1, 1D, 10000 - 2703), (0, -1D, 2703)) val split2 = split("f1", LessThan(0D), LeafNode(0, -1D, 2704), LeafNode(1, 1D, 10000 - 2704)) - traversal.find(AnnotatedTree(split2), Map.empty[String, Double], Some("b")) shouldBe - Seq(LeafNode(0, -1D, 2704), LeafNode(1, 1D, 10000 - 2704)) + traversal.search(AnnotatedTree(split2), Map.empty[String, Double], Some("b")) shouldBe + Seq((0, -1D, 2704), (1, 1D, 10000 - 2704)) val split3 = split("f1", LessThan(0D), LeafNode(0, -1D, 3982), LeafNode(1, 1D, 10000 - 3982)) - traversal.find(AnnotatedTree(split3), Map.empty[String, Double], Some("z")) shouldBe - Seq(LeafNode(1, 1D, 10000 - 3982), LeafNode(0, -1D, 3982)) + traversal.search(AnnotatedTree(split3), Map.empty[String, Double], Some("z")) shouldBe + Seq((1, 1D, 10000 - 3982), (0, -1D, 3982)) val split4 = split("f1", LessThan(0D), LeafNode(0, -1D, 3983), LeafNode(1, 1D, 10000 - 3983)) - traversal.find(AnnotatedTree(split4), Map.empty[String, Double], Some("z")) shouldBe - Seq(LeafNode(0, -1D, 3983), LeafNode(1, 1D, 10000 - 3983)) + traversal.search(AnnotatedTree(split4), Map.empty[String, Double], Some("z")) shouldBe + Seq((0, -1D, 3983), (1, 1D, 10000 - 3983)) } "choose predictable weighted probabilistic path in a tree" in { @@ -123,14 +131,10 @@ class TreeTraversalSpec extends WordSpec with Matchers with Checkers { LeafNode(2, 3D, 24), LeafNode(3, 4D, 46)))) - traversal.find(tree, Map.empty[String, Double], Some("c")).headOption shouldBe - Some(LeafNode(1, 2D, 21)) - traversal.find(tree, Map.empty[String, Double], Some("z")).headOption shouldBe - Some(LeafNode(2, 3D, 24)) - traversal.find(tree, Map("f1" -> 1D), Some("z")).headOption shouldBe - Some(LeafNode(3, 4D, 46)) - traversal.find(tree, Map("f1" -> -1D), Some("b")).headOption shouldBe - Some(LeafNode(0, 1D, 9)) + traversal.search(tree, Map.empty[String, Double], Some("c")).headOption shouldBe Some((1, 2D, 21)) + traversal.search(tree, Map.empty[String, Double], Some("z")).headOption shouldBe Some((2, 3D, 24)) + traversal.search(tree, Map("f1" -> 1D), Some("z")).headOption shouldBe Some((3, 4D, 46)) + traversal.search(tree, Map("f1" -> -1D), Some("b")).headOption shouldBe Some((0, 1D, 9)) } } } diff --git a/brushfire-scalding/src/main/scala/com/stripe/brushfire/scalding/Trainer.scala b/brushfire-scalding/src/main/scala/com/stripe/brushfire/scalding/Trainer.scala index dce9a97..d073798 100644 --- a/brushfire-scalding/src/main/scala/com/stripe/brushfire/scalding/Trainer.scala +++ b/brushfire-scalding/src/main/scala/com/stripe/brushfire/scalding/Trainer.scala @@ -118,9 +118,9 @@ case class Trainer[K: Ordering, V, T: Monoid]( for ( (treeIndex, tree) <- treeMap; i <- 1.to(sampler.timesInTrainingSet(instance.id, instance.timestamp, treeIndex)).toList; - leaf <- tree.leafFor(instance.features).toList if stopper.shouldSplit(leaf.target) && stopper.shouldSplitDistributed(leaf.target); - (feature, stats) <- features if (sampler.includeFeature(feature, treeIndex, leaf.index)) - ) yield (treeIndex, leaf.index, feature) -> stats + (index, target, annotation) <- tree.leafFor(instance.features).toList if stopper.shouldSplit(target) && stopper.shouldSplitDistributed(target); + (feature, stats) <- features if (sampler.includeFeature(feature, treeIndex, index)) + ) yield (treeIndex, index, feature) -> stats } val splits = @@ -280,8 +280,8 @@ case class Trainer[K: Ordering, V, T: Monoid]( for ( (treeIndex, tree) <- treeMap; i <- 1.to(sampler.timesInTrainingSet(instance.id, instance.timestamp, treeIndex)).toList; - leaf <- tree.leafFor(instance.features).toList if stopper.shouldSplit(leaf.target) && (r.nextDouble < stopper.samplingRateToSplitLocally(leaf.target)) - ) yield (treeIndex, leaf.index) -> instance + (index, target, _) <- tree.leafFor(instance.features).toList if stopper.shouldSplit(target) && (r.nextDouble < stopper.samplingRateToSplitLocally(target)) + ) yield (treeIndex, index) -> instance } .group .forceToReducers diff --git a/project/Deps.scala b/project/Deps.scala index 354cb8e..62c013a 100644 --- a/project/Deps.scala +++ b/project/Deps.scala @@ -9,6 +9,7 @@ object Deps { val algebird = "0.9.0" val jackson = "1.9.13" val bijection = "0.7.0" + val bonsai = "0.1.3-SNAPSHOT" val tDigest = "3.1" val hadoopClient = "2.5.2" @@ -23,6 +24,7 @@ object Deps { val algebirdCore = "com.twitter" %% "algebird-core" % V.algebird val bijectionJson = "com.twitter" %% "bijection-json" % V.bijection + val bonsai = "com.stripe" %% "bonsai-core" % V.bonsai val chillBijection = "com.twitter" %% "chill-bijection" % V.chill val jacksonMapper = "org.codehaus.jackson" % "jackson-mapper-asl" % V.jackson val jacksonXC = "org.codehaus.jackson" % "jackson-xc" % V.jackson From 047701a1ed061d1dc0daf6399e0be30f6f0f07f1 Mon Sep 17 00:00:00 2001 From: Erik Osheim Date: Tue, 12 Jan 2016 14:06:54 -0500 Subject: [PATCH 05/11] Simplified and optimized reordering. --- .../com/stripe/brushfire/TreeTraversal.scala | 125 ++++++++++-------- 1 file changed, 71 insertions(+), 54 deletions(-) diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala index 9fb2371..43e9cf9 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala @@ -59,30 +59,65 @@ trait TreeTraversal[Tree, K, V, T, A] { def searchNode(node: treeOps.Node, row: Map[K, V], id: Option[String]): Stream[LeafLabel[T, A]] } +/** + * Simple data type that provides rules to order nodes during + * traversal. + * + * In some cases subtypes of Reorder will also wraps RNG state, for + * instances that need to randomly select instances. Thus, Reorder is + * not guaranteed to be referentially-transparent. Fresh instances + * should be used with each traversal. + */ trait Reorder[A] { - def apply[N](r: Random, ns: List[N], f: N => A): List[N] + def setSeed(seed: Option[String]): Unit + def apply[N](n1: N, n2: N, f: N => A): (N, N) } object Reorder { + + // Traverse into the left node first. def unchanged[A]: Reorder[A] = new Reorder[A] { - def apply[N](r: Random, ns: List[N], f: N => A): List[N] = ns + def setSeed(seed: Option[String]): Unit = () + def apply[N](n1: N, n2: N, f: N => A): (N, N) = + (n1, n2) } + // Traverse into a random node first. Each node has equal + // probability of being selected. def shuffled[A]: Reorder[A] = new Reorder[A] { - def apply[N](r: Random, ns: List[N], f: N => A): List[N] = r.shuffle(ns) + val r = new Random() + def setSeed(seed: Option[String]): Unit = + seed.foreach(s => r.setSeed(MurmurHash3.stringHash(s))) + def apply[N](n1: N, n2: N, f: N => A): (N, N) = + if (r.nextBoolean) (n1, n2) else (n2, n1) } + // Traverse into the node with the higher weight first. def weightedDepthFirst[A](implicit ev: Ordering[A]): Reorder[A] = new Reorder[A] { - def apply[N](r: Random, ns: List[N], f: N => A): List[N] = ns.sortBy(f)(ev.reverse) + def setSeed(seed: Option[String]): Unit = () + def apply[N](n1: N, n2: N, f: N => A): (N, N) = + if (ev.compare(f(n1), f(n2)) >= 0) (n1, n2) else (n2, n1) } + // Traverse into a random node, but choose the random node based on + // the ratio between its weight and the total weight of both nodes. + // + // If the left node's weight was 10, and the right node's weight was + // 5, the left node would be picked 2/3 of the time. def probabilisticWeightedDepthFirst[A](conversion: A => Double): Reorder[A] = new Reorder[A] { - def apply[N](r: Random, ns: List[N], f: N => A): List[N] = - TreeTraversal.probabilisticShuffle(r, ns)(n => conversion(f(n))) + val r = new Random() + def setSeed(seed: Option[String]): Unit = + seed.foreach(s => r.setSeed(MurmurHash3.stringHash(s))) + def apply[N](n1: N, n2: N, f: N => A): (N, N) = { + val w1 = conversion(f(n1)) + val w2 = conversion(f(n2)) + val sum = w1 + w2 + if (r.nextDouble * sum < w1) (n1, n2) else (n2, n1) + } } } @@ -130,67 +165,49 @@ object TreeTraversal { */ def probabilisticWeightedDepthFirst[Tree, K, V, T, A](implicit treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]], conversion: A => Double): TreeTraversal[Tree, K, V, T, A] = DepthFirstTreeTraversal(Reorder.probabilisticWeightedDepthFirst(conversion)) - - // Given a weighted set `xs`, this creates an ordered list of all the elements - // in `xs` by sampling without replacement from the set, but giving each - // element a probability of being picked that is equal to its weight / total - // weight of all elements remaining in the set. - private[brushfire] def probabilisticShuffle[A](rng: Random, xs: List[A])(getWeight: A => Double): List[A] = { - - @tailrec - def loop(sum: Double, acc: SortedMap[Double, Int], order: Vector[A], as: List[A]): Vector[A] = - as match { - case a :: tail => - val sum0 = sum + getWeight(a) - val newHead = - if (acc.isEmpty) None - else acc.from(rng.nextDouble * sum0).headOption - newHead match { - case Some((k, i)) => - val acc0 = acc + (sum0 -> i) + (k -> order.size) - loop(sum0, acc0, order.updated(i, a) :+ order(i), tail) - case None => - val acc0 = acc + (sum0 -> order.length) - loop(sum0, acc0, order :+ a, tail) - } - - case Nil => - order.reverse - } - - loop(0D, SortedMap.empty, Vector.empty, xs).toList - } - - def mkRandom(id: String): Random = - new Random(MurmurHash3.stringHash(id)) } case class DepthFirstTreeTraversal[Tree, K, V, T, A](reorder: Reorder[A])(implicit val treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]]) extends TreeTraversal[Tree, K, V, T, A] { - def searchNode(start: treeOps.Node, row: Map[K, V], id: Option[String]): Stream[LeafLabel[T, A]] = { + import treeOps.{Node, foldNode} + + def searchNode(start: Node, row: Map[K, V], id: Option[String]): Stream[LeafLabel[T, A]] = { - // Lazy to avoid creation in the fast case. - lazy val rng: Random = id.fold[Random](Random)(TreeTraversal.mkRandom) + // this will be a noop unless we have an id and our reorder + // instance requires randomness. + reorder.setSeed(id) - val Empty: Stream[LeafLabel[T, A]] = Stream.empty + // pull the A value out of a branch or leaf. + val getAnnotation: Node => A = + n => foldNode(n)((_, _, bl) => bl._3, ll => ll._3) - val getAnnotation: treeOps.Node => A = - n => treeOps.foldNode(n)((_, _, bl) => bl._3, _._3) + // construct a singleton stream from a leaf + //val Empty: Stream[LeafLabel[T, A]] = Stream.empty + val leafF: LeafLabel[T, A] => Stream[LeafLabel[T, A]] = + _ #:: Stream.empty - def loop(node: treeOps.Node): Stream[LeafLabel[T, A]] = - treeOps.foldNode(node)({ case (lc, rc, bl) => - val (k, p, a) = bl + // recurse into branch nodes, going left, right, or both, + // depending on what our predicate says. + lazy val branchF: (Node, Node, BranchLabel[K, V, A]) => Stream[LeafLabel[T, A]] = + { case (lc, rc, (k, p, _)) => p(row.get(k)) match { case Some(true) => - loop(lc) + recurse(lc) case Some(false) => - loop(rc) + recurse(rc) case None => - val cs = reorder(rng, lc :: rc :: Nil, getAnnotation).toStream - cs.flatMap(loop) + val (c1, c2) = reorder(lc, rc, getAnnotation) + recurse(c1) #::: recurse(c2) } - }, ll => ll #:: Empty) - loop(start) + } + + // recursively handle each node. the foldNode method decides + // whether to handle it as a branch or a leaf. + def recurse(node: Node): Stream[LeafLabel[T, A]] = + foldNode(node)(branchF, leafF) + + // do it! + recurse(start) } } From 9c73f444313d017326a337f222197e57359cab4e Mon Sep 17 00:00:00 2001 From: Erik Osheim Date: Wed, 13 Jan 2016 16:51:14 -0500 Subject: [PATCH 06/11] Stabilizing reordering during traversals. This commit ensures that parallel traversals will not change the order nodes are visited when using an explicit seed. It includes a test which verifies that this is currently working using parallel collections. It also stabilizes the bonsai dependency at 0.1.3. --- .../com/stripe/brushfire/TreeTraversal.scala | 80 ++++++++++++------- .../stripe/brushfire/TreeTraversalSpec.scala | 30 +++++-- project/Deps.scala | 2 +- 3 files changed, 75 insertions(+), 37 deletions(-) diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala index 43e9cf9..8051ef0 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala @@ -69,7 +69,7 @@ trait TreeTraversal[Tree, K, V, T, A] { * should be used with each traversal. */ trait Reorder[A] { - def setSeed(seed: Option[String]): Unit + def setSeed(seed: Option[String]): Reorder[A] def apply[N](n1: N, n2: N, f: N => A): (N, N) } @@ -77,30 +77,16 @@ object Reorder { // Traverse into the left node first. def unchanged[A]: Reorder[A] = - new Reorder[A] { - def setSeed(seed: Option[String]): Unit = () - def apply[N](n1: N, n2: N, f: N => A): (N, N) = - (n1, n2) - } + new UnchangedReorder() // Traverse into a random node first. Each node has equal // probability of being selected. def shuffled[A]: Reorder[A] = - new Reorder[A] { - val r = new Random() - def setSeed(seed: Option[String]): Unit = - seed.foreach(s => r.setSeed(MurmurHash3.stringHash(s))) - def apply[N](n1: N, n2: N, f: N => A): (N, N) = - if (r.nextBoolean) (n1, n2) else (n2, n1) - } + new ShuffledReorder(new Random) // Traverse into the node with the higher weight first. def weightedDepthFirst[A](implicit ev: Ordering[A]): Reorder[A] = - new Reorder[A] { - def setSeed(seed: Option[String]): Unit = () - def apply[N](n1: N, n2: N, f: N => A): (N, N) = - if (ev.compare(f(n1), f(n2)) >= 0) (n1, n2) else (n2, n1) - } + new WeightedReorder() // Traverse into a random node, but choose the random node based on // the ratio between its weight and the total weight of both nodes. @@ -108,17 +94,49 @@ object Reorder { // If the left node's weight was 10, and the right node's weight was // 5, the left node would be picked 2/3 of the time. def probabilisticWeightedDepthFirst[A](conversion: A => Double): Reorder[A] = - new Reorder[A] { - val r = new Random() - def setSeed(seed: Option[String]): Unit = - seed.foreach(s => r.setSeed(MurmurHash3.stringHash(s))) - def apply[N](n1: N, n2: N, f: N => A): (N, N) = { - val w1 = conversion(f(n1)) - val w2 = conversion(f(n2)) - val sum = w1 + w2 - if (r.nextDouble * sum < w1) (n1, n2) else (n2, n1) + new ProbabilisticWeighted(new Random, conversion) + + class UnchangedReorder[A] extends Reorder[A] { + def setSeed(seed: Option[String]): Reorder[A] = + this + def apply[N](n1: N, n2: N, f: N => A): (N, N) = + (n1, n2) + } + + class ShuffledReorder[A](r: Random) extends Reorder[A] { + def setSeed(seed: Option[String]): Reorder[A] = + seed match { + case Some(s) => + new ShuffledReorder(new Random(MurmurHash3.stringHash(s))) + case None => + this } + def apply[N](n1: N, n2: N, f: N => A): (N, N) = + if (r.nextBoolean) (n1, n2) else (n2, n1) + } + + class WeightedReorder[A](implicit ev: Ordering[A]) extends Reorder[A] { + def setSeed(seed: Option[String]): Reorder[A] = + this + def apply[N](n1: N, n2: N, f: N => A): (N, N) = + if (ev.compare(f(n1), f(n2)) >= 0) (n1, n2) else (n2, n1) + } + + class ProbabilisticWeighted[A](r: Random, conversion: A => Double) extends Reorder[A] { + def setSeed(seed: Option[String]): Reorder[A] = + seed match { + case Some(s) => + new ProbabilisticWeighted(new Random(MurmurHash3.stringHash(s)), conversion) + case None => + this + } + def apply[N](n1: N, n2: N, f: N => A): (N, N) = { + val w1 = conversion(f(n1)) + val w2 = conversion(f(n2)) + val sum = w1 + w2 + if (r.nextDouble * sum < w1) (n1, n2) else (n2, n1) } + } } object TreeTraversal { @@ -174,8 +192,10 @@ case class DepthFirstTreeTraversal[Tree, K, V, T, A](reorder: Reorder[A])(implic def searchNode(start: Node, row: Map[K, V], id: Option[String]): Stream[LeafLabel[T, A]] = { // this will be a noop unless we have an id and our reorder - // instance requires randomness. - reorder.setSeed(id) + // instance requires randomness. it ensures that each searchNode + // call has its own independent RNG (in cases where we care about + // repeatability, i.e. when `id` is not None). + val r = reorder.setSeed(id) // pull the A value out of a branch or leaf. val getAnnotation: Node => A = @@ -196,7 +216,7 @@ case class DepthFirstTreeTraversal[Tree, K, V, T, A](reorder: Reorder[A])(implic case Some(false) => recurse(rc) case None => - val (c1, c2) = reorder(lc, rc, getAnnotation) + val (c1, c2) = r(lc, rc, getAnnotation) recurse(c1) #::: recurse(c2) } } diff --git a/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala b/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala index 42414f9..d5d1aee 100644 --- a/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala +++ b/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala @@ -43,9 +43,10 @@ class TreeTraversalSpec extends WordSpec with Matchers with Checkers { "traverse in order" in { check { (tree: TreeSDM[Unit]) => - val leaves: Stream[(Int, Map[String, Long], Unit)] = + val traversal = TreeTraversal.depthFirst[TreeSDM[Unit], String, Double, Map[String, Long], Unit] - .search(tree, Map.empty[String, Double], None) + val leaves: Stream[(Int, Map[String, Long], Unit)] = + traversal.search(tree, Map.empty[String, Double], None) leaves .map { case (index, _, _) => index } @@ -103,7 +104,7 @@ class TreeTraversalSpec extends WordSpec with Matchers with Checkers { val traversal = TreeTraversal.probabilisticWeightedDepthFirst[TreeSDD[Int], String, Double, Double, Int] - "choose predictable weighted probabilistic node in a split" in { + "choose a predictable node from a split" in { val split1 = split("f1", LessThan(0D), LeafNode(0, -1D, 2703), LeafNode(1, 1D, 10000 - 2703)) traversal.search(AnnotatedTree(split1), Map.empty[String, Double], Some("b")) shouldBe Seq((1, 1D, 10000 - 2703), (0, -1D, 2703)) @@ -121,7 +122,7 @@ class TreeTraversalSpec extends WordSpec with Matchers with Checkers { Seq((0, -1D, 3983), (1, 1D, 10000 - 3983)) } - "choose predictable weighted probabilistic path in a tree" in { + "choose a predictable path through a tree" in { val tree = AnnotatedTree( split("f1", LessThan(0D), split("f2", LessThan(0D), // 30 @@ -131,10 +132,27 @@ class TreeTraversalSpec extends WordSpec with Matchers with Checkers { LeafNode(2, 3D, 24), LeafNode(3, 4D, 46)))) - traversal.search(tree, Map.empty[String, Double], Some("c")).headOption shouldBe Some((1, 2D, 21)) - traversal.search(tree, Map.empty[String, Double], Some("z")).headOption shouldBe Some((2, 3D, 24)) + traversal.search(tree, Map.empty, Some("c")).headOption shouldBe Some((1, 2D, 21)) + traversal.search(tree, Map.empty, Some("z")).headOption shouldBe Some((2, 3D, 24)) traversal.search(tree, Map("f1" -> 1D), Some("z")).headOption shouldBe Some((3, 4D, 46)) traversal.search(tree, Map("f1" -> -1D), Some("b")).headOption shouldBe Some((0, 1D, 9)) + + def stableSearch(t: TreeSDD[Int]): (Int, Double, Int) = + traversal.search(tree, Map.empty, Some("x")).head + + def unstableSearch(t: TreeSDD[Int]): (Int, Double, Int) = + traversal.search(tree, Map.empty, None).head + + // ensure that parallel traversal have the same results as + // sequential traversals. + val trees = (1 to 16).map(_ => tree) + + // should never fail + trees.par.map(stableSearch) shouldBe trees.map(stableSearch) + + // may fail very occasionally -- but this verifies that without + // using a stable seed, we get faster non-stable behavior. + trees.par.map(unstableSearch) should not be trees.map(unstableSearch) } } } diff --git a/project/Deps.scala b/project/Deps.scala index 62c013a..430d33f 100644 --- a/project/Deps.scala +++ b/project/Deps.scala @@ -9,7 +9,7 @@ object Deps { val algebird = "0.9.0" val jackson = "1.9.13" val bijection = "0.7.0" - val bonsai = "0.1.3-SNAPSHOT" + val bonsai = "0.1.3" val tDigest = "3.1" val hadoopClient = "2.5.2" From 4d930ec67122cf0cd44fa46f775eb0519b2b3dab Mon Sep 17 00:00:00 2001 From: Erik Osheim Date: Wed, 13 Jan 2016 17:15:42 -0500 Subject: [PATCH 07/11] General clean up. Add warning for failed inlining, remove unnecessary Types object, and so on. --- .../com/stripe/brushfire/AnnotatedTree.scala | 2 -- .../scala/com/stripe/brushfire/Tree.scala | 2 -- .../com/stripe/brushfire/TreeTraversal.scala | 2 -- .../scala/com/stripe/brushfire/Types.scala | 19 ------------------- .../scala/com/stripe/brushfire/Voter.scala | 1 - .../com/stripe/brushfire/local/Trainer.scala | 1 - .../scala/com/stripe/brushfire/package.scala | 2 ++ .../com/stripe/brushfire/VoterSpec.scala | 14 +++++++------- build.sbt | 1 + 9 files changed, 10 insertions(+), 34 deletions(-) delete mode 100644 brushfire-core/src/main/scala/com/stripe/brushfire/Types.scala diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala index 9b4a4cc..6beb22a 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala @@ -5,8 +5,6 @@ import com.stripe.bonsai.{ FullBinaryTree, FullBinaryTreeOps } import java.lang.Math.{abs, max} -import Types._ - sealed abstract class Node[K, V, T, A] { def annotation: A diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala index 65bffe6..8148759 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala @@ -3,8 +3,6 @@ package com.stripe.brushfire import com.stripe.bonsai.FullBinaryTreeOps import com.twitter.algebird._ -// type Tree[K, V, T] = AnnotatedTree[K, V, T, Unit] - object Tree { def apply[K, V, T](node: Node[K, V, T, Unit]): Tree[K, V, T] = AnnotatedTree(node) diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala index 8051ef0..c407e04 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala @@ -9,8 +9,6 @@ import scala.util.hashing.MurmurHash3 import com.twitter.algebird._ import com.stripe.bonsai._ -import Types._ - /** * A `TreeTraversal` provides a way to find all of the leaves in a tree that * some row can evaluate to. Specifically, there may be cases where multiple diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Types.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Types.scala deleted file mode 100644 index c021ce5..0000000 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Types.scala +++ /dev/null @@ -1,19 +0,0 @@ -package com.stripe.brushfire - -import scala.annotation.tailrec -import scala.collection.SortedMap -import scala.math.Ordering -import scala.util.Random -import scala.util.hashing.MurmurHash3 - -import com.twitter.algebird._ -import com.stripe.bonsai.FullBinaryTreeOps - -object Types { - type BranchLabel[K, V, A] = (K, Predicate[V], A) - type LeafLabel[T, A] = (Int, T, A) - - //type Ops[Tree] = FullBinaryTreeOps[Tree] - //type Aux[Tree, K, V, T, A] = FullBinaryTreeOps.Aux[Tree, BranchLabel[K, V, A], LeafLabel[T, A]] - //type AnnotatedTreeTraversal[K, V, T, A] = TreeTraversal[AnnotatedTree[K, V, T, A], K, V, T, A] -} diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Voter.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Voter.scala index 36e3453..4fbfd68 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Voter.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Voter.scala @@ -2,7 +2,6 @@ package com.stripe.brushfire import com.twitter.algebird._ -import Types._ import AnnotatedTree.AnnotatedTreeTraversal /** Combines multiple targets into a single prediction **/ diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/local/Trainer.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/local/Trainer.scala index 41e3033..b57a3d4 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/local/Trainer.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/local/Trainer.scala @@ -4,7 +4,6 @@ package local import com.stripe.brushfire._ import com.twitter.algebird._ -import Types._ import AnnotatedTree.AnnotatedTreeTraversal case class Trainer[K: Ordering, V, T: Monoid]( diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/package.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/package.scala index 35375fa..e411c81 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/package.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/package.scala @@ -2,4 +2,6 @@ package com.stripe package object brushfire { type Tree[K, V, T] = AnnotatedTree[K, V, T, Unit] + type BranchLabel[K, V, A] = (K, Predicate[V], A) + type LeafLabel[T, A] = (Int, T, A) } diff --git a/brushfire-core/src/test/scala/com/stripe/brushfire/VoterSpec.scala b/brushfire-core/src/test/scala/com/stripe/brushfire/VoterSpec.scala index f93edfd..bcdd2fb 100644 --- a/brushfire-core/src/test/scala/com/stripe/brushfire/VoterSpec.scala +++ b/brushfire-core/src/test/scala/com/stripe/brushfire/VoterSpec.scala @@ -18,12 +18,12 @@ class VoterSpec extends WordSpec with Matchers with Checkers { "return mean of counts" in { val targets = Seq( - Map("a" -> 1, "b" -> 2, "c" -> 3), - Map("a" -> 2, "c" -> 5)) + Map(("a", 1), ("b", 2), ("c", 3)), + Map(("a", 2), ("c", 5))) val expected = Map( - "a" -> ((1D / 6D + 2D / 7D) / 2D), - "b" -> ((2D / 6D) / 2D), - "c" -> ((3D / 6D + 5D / 7D) / 2D)) + ("a", ((1D / 6D + 2D / 7D) / 2D)), + ("b", ((2D / 6D) / 2D)), + ("c", ((3D / 6D + 5D / 7D) / 2D))) Voter.soft[String, Int].combine(targets) shouldBe expected } } @@ -39,8 +39,8 @@ class VoterSpec extends WordSpec with Matchers with Checkers { Map("a" -> 2, "c" -> 5), // c Map("b" -> 3, "c" -> 4)) // c val expected = Map( - "b" -> 1D / 3D, - "c" -> 2D / 3D) + ("b", 1D / 3D), + ("c", 2D / 3D)) Voter.mode[String, Int].combine(targets) shouldBe expected } } diff --git a/build.sbt b/build.sbt index 09081bf..d795fc3 100644 --- a/build.sbt +++ b/build.sbt @@ -5,6 +5,7 @@ scalaVersion in ThisBuild := "2.11.5" crossScalaVersions in ThisBuild := Seq("2.10.4", "2.11.5") scalacOptions in ThisBuild ++= Seq( + "-Yinline-warnings", "-deprecation", "-feature", "-unchecked", From 36076b2316cdb562638536cfa31beab7e0b21e12 Mon Sep 17 00:00:00 2001 From: Erik Osheim Date: Tue, 19 Jan 2016 11:37:06 -0500 Subject: [PATCH 08/11] Respond to review comments. 1. Move Reorder into its own file with better Scaladoc comments. 2. Make Predicate's injections public, since predicate is public. 3. Improve some names and comments to be clearer. 4. A bit of other clean up. --- .../com/stripe/brushfire/Brushfire.scala | 10 +- .../com/stripe/brushfire/Injections.scala | 77 +++++------ .../scala/com/stripe/brushfire/Reorder.scala | 124 ++++++++++++++++++ .../scala/com/stripe/brushfire/Tree.scala | 6 +- .../com/stripe/brushfire/TreeTraversal.scala | 103 ++------------- 5 files changed, 183 insertions(+), 137 deletions(-) create mode 100644 brushfire-core/src/main/scala/com/stripe/brushfire/Reorder.scala diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala index ac27458..95f010e 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala @@ -57,7 +57,15 @@ case class Split[V, T](predicate: Predicate[V], leftDistribution: T, rightDistri /** Evaluates the goodness of a candidate split */ trait Evaluator[V, T] { - /** returns a (possibly transformed) version of the input split, and a numeric goodness score */ + + /** + * Evaluate the fitness of a candidate split. + * + * This method may transform the split, in which case the score + * applies to the split that is returned. The result is optional to + * handle cases where a split is not valid (for example, if one side + * of a split is "empty" in some sense). + */ def evaluate(split: Split[V, T]): Option[(Split[V, T], Double)] } diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Injections.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Injections.scala index 82fbf79..75778b8 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Injections.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Injections.scala @@ -71,6 +71,44 @@ object JsonInjections { } } + implicit def predicateJsonInjection[V](implicit vInj: JsonNodeInjection[V], ord: Ordering[V] = null): JsonNodeInjection[Predicate[V]] = { + new AbstractJsonNodeInjection[Predicate[V]] { + def apply(pred: Predicate[V]) = { + val obj = JsonNodeFactory.instance.objectNode + pred match { + case IsPresent(None) => obj.put("exists", JsonNodeFactory.instance.nullNode) + case IsPresent(Some(pred)) => obj.put("exists", toJsonNode(pred)(predicateJsonInjection)) + case EqualTo(v) => obj.put("eq", toJsonNode(v)) + case LessThan(v) => obj.put("lt", toJsonNode(v)) + case Not(pred) => obj.put("not", toJsonNode(pred)(predicateJsonInjection)) + case AnyOf(preds) => + val ary = JsonNodeFactory.instance.arrayNode + preds.foreach { pred => ary.add(toJsonNode(pred)(predicateJsonInjection)) } + obj.put("or", ary) + } + obj + } + + override def invert(n: JsonNode) = { + n.getFieldNames.asScala.toList.headOption match { + case Some("eq") => fromJsonNode[V](n.get("eq")).map { EqualTo(_) } + case Some("lt") => + if (ord == null) + sys.error("No Ordering[V] supplied but less than used") + else + fromJsonNode[V](n.get("lt")).map { LessThan(_) } + case Some("not") => fromJsonNode[Predicate[V]](n.get("not")).map { Not(_) } + case Some("or") => fromJsonNode[List[Predicate[V]]](n.get("or")).map { AnyOf(_) } + case Some("exists") => + val predNode = n.get("exists") + if (predNode.isNull) Success(IsPresent[V](None)) + else fromJsonNode[Predicate[V]](predNode).map(p => IsPresent(Some(p))) + case _ => sys.error("Not a predicate node: " + n) + } + } + } + } + implicit def treeJsonInjection[K, V, T]( implicit kInj: JsonNodeInjection[K], pInj: JsonNodeInjection[T], @@ -78,43 +116,6 @@ object JsonInjections { mon: Monoid[T], ord: Ordering[V] = null): JsonNodeInjection[Tree[K, V, T]] = { - implicit def predicateJsonNodeInjection: JsonNodeInjection[Predicate[V]] = - new AbstractJsonNodeInjection[Predicate[V]] { - def apply(pred: Predicate[V]) = { - val obj = JsonNodeFactory.instance.objectNode - pred match { - case IsPresent(None) => obj.put("exists", JsonNodeFactory.instance.nullNode) - case IsPresent(Some(pred)) => obj.put("exists", toJsonNode(pred)(predicateJsonNodeInjection)) - case EqualTo(v) => obj.put("eq", toJsonNode(v)) - case LessThan(v) => obj.put("lt", toJsonNode(v)) - case Not(pred) => obj.put("not", toJsonNode(pred)(predicateJsonNodeInjection)) - case AnyOf(preds) => - val ary = JsonNodeFactory.instance.arrayNode - preds.foreach { pred => ary.add(toJsonNode(pred)(predicateJsonNodeInjection)) } - obj.put("or", ary) - } - obj - } - - override def invert(n: JsonNode) = { - n.getFieldNames.asScala.toList.headOption match { - case Some("eq") => fromJsonNode[V](n.get("eq")).map { EqualTo(_) } - case Some("lt") => - if (ord == null) - sys.error("No Ordering[V] supplied but less than used") - else - fromJsonNode[V](n.get("lt")).map { LessThan(_) } - case Some("not") => fromJsonNode[Predicate[V]](n.get("not")).map { Not(_) } - case Some("or") => fromJsonNode[List[Predicate[V]]](n.get("or")).map { AnyOf(_) } - case Some("exists") => - val predNode = n.get("exists") - if (predNode.isNull) Success(IsPresent[V](None)) - else fromJsonNode[Predicate[V]](predNode).map(p => IsPresent(Some(p))) - case _ => sys.error("Not a predicate node: " + n) - } - } - } - implicit def nodeJsonNodeInjection: JsonNodeInjection[Node[K, V, T, Unit]] = new AbstractJsonNodeInjection[Node[K, V, T, Unit]] { def apply(node: Node[K, V, T, Unit]) = node match { @@ -127,7 +128,7 @@ object JsonInjections { case SplitNode(k, p, lc, rc, _) => val obj = JsonNodeFactory.instance.objectNode obj.put("key", toJsonNode(k)) - obj.put("predicate", toJsonNode(p)(predicateJsonNodeInjection)) + obj.put("predicate", toJsonNode(p)(predicateJsonInjection)) obj.put("left", toJsonNode(lc)(nodeJsonNodeInjection)) obj.put("right", toJsonNode(rc)(nodeJsonNodeInjection)) obj diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Reorder.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Reorder.scala new file mode 100644 index 0000000..5c5feb1 --- /dev/null +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Reorder.scala @@ -0,0 +1,124 @@ +package com.stripe.brushfire + +import scala.math.Ordering +import scala.util.Random +import scala.util.hashing.MurmurHash3 + +/** + * Simple data type that provides rules to order nodes during + * traversal. + * + * In some cases subtypes of Reorder will also wraps RNG state, for + * instances that need to randomly select instances. Thus, Reorder is + * not guaranteed to be referentially-transparent. Fresh instances + * should be used with each traversal. + * + * Reorder will also continue to recurse into a given structure using + * a provided callback method. + */ +trait Reorder[A] { + + /** + * Seed a reorder instance with a stable identifier. + * + * This ensures that reorders which are non-deterministic in general + * (e.g. shuffled) will produce the same reordering for the same + * seed across many traversals. + * + * On deterministic reorders this method is a noop. + */ + def setSeed(seed: Option[String]): Reorder[A] + + /** + * Perform a reordering. + * + * This method takes two nodes (`n1` and `n2`), as well as two + * functions: + * + * - `f`: function from node to identifying annotation + * - `g`: function from two nodes to a combined result + * + * The `f` function is used in cases where sorting or weighting is + * necessary (in those cases A will be a weight or similar). The `g` + * function is used to recurse on the result -- i.e. the possibly + * reordered nodes `n1` and `n2` will be passed to `g` after the + * reordering occurs. + */ + def apply[N, S](n1: N, n2: N, f: N => A, g: (N, N) => S): S +} + +object Reorder { + + /** + * Reorder instance that traverses into the left node first. + */ + def unchanged[A]: Reorder[A] = + new UnchangedReorder() + + /** + * Reorder instance that traverses into the node with the higher + * weight first. + */ + def weightedDepthFirst[A](implicit ev: Ordering[A]): Reorder[A] = + new WeightedReorder() + + /** + * Reorder instance that traverses into a random node first. Each + * node has equal probability of being selected. + */ + def shuffled[A]: Reorder[A] = + new ShuffledReorder(new Random) + + /** + * Reorder instance that traverses into a random node, but choose + * the random node based on the ratio between its weight and the + * total weight of both nodes. + * + * If the left node's weight was 10, and the right node's weight was + * 5, the left node would be picked 2/3 of the time. + */ + def probabilisticWeightedDepthFirst[A](conversion: A => Double): Reorder[A] = + new ProbabilisticWeighted(new Random, conversion) + + class UnchangedReorder[A] extends Reorder[A] { + def setSeed(seed: Option[String]): Reorder[A] = + this + def apply[N, S](n1: N, n2: N, f: N => A, g: (N, N) => S): S = + g(n1, n2) + } + + class WeightedReorder[A](implicit ev: Ordering[A]) extends Reorder[A] { + def setSeed(seed: Option[String]): Reorder[A] = + this + def apply[N, S](n1: N, n2: N, f: N => A, g: (N, N) => S): S = + if (ev.compare(f(n1), f(n2)) >= 0) g(n1, n2) else g(n2, n1) + } + + class ShuffledReorder[A](r: Random) extends Reorder[A] { + def setSeed(seed: Option[String]): Reorder[A] = + seed match { + case Some(s) => + new ShuffledReorder(new Random(MurmurHash3.stringHash(s))) + case None => + this + } + def apply[N, S](n1: N, n2: N, f: N => A, g: (N, N) => S): S = + if (r.nextBoolean) g(n1, n2) else g(n2, n1) + } + + class ProbabilisticWeighted[A](r: Random, conversion: A => Double) extends Reorder[A] { + def setSeed(seed: Option[String]): Reorder[A] = + seed match { + case Some(s) => + new ProbabilisticWeighted(new Random(MurmurHash3.stringHash(s)), conversion) + case None => + this + } + def apply[N, S](n1: N, n2: N, f: N => A, g: (N, N) => S): S = { + val w1 = conversion(f(n1)) + val w2 = conversion(f(n2)) + val sum = w1 + w2 + if (r.nextDouble * sum < w1) g(n1, n2) else g(n2, n1) + } + } +} diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala index 8148759..7806499 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Tree.scala @@ -31,12 +31,12 @@ object Tree { if (splits.isEmpty) None else { val (splitFeature, (Split(pred, left, right), _)) = splits.maxBy { case (f, (x, s)) => s } - def ex(dist: T): Node[K, V, T, Unit] = { - val newInstances = instances.filter { inst => pred.run(inst.features.get(splitFeature)) } + def expandChild(dist: T): Node[K, V, T, Unit] = { + val newInstances = instances.filter(inst => pred.run(inst.features.get(splitFeature))) val target = Monoid.sum(newInstances.map(_.target)) expand(times - 1, treeIndex, LeafNode(0, target), splitter, evaluator, stopper, sampler, newInstances) } - Some(SplitNode(splitFeature, pred, ex(left), ex(right))) + Some(SplitNode(splitFeature, pred, expandChild(left), expandChild(right))) } }.getOrElse(leaf) } else { diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala index c407e04..45a92d3 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala @@ -1,13 +1,8 @@ package com.stripe.brushfire -import scala.annotation.tailrec -import scala.collection.SortedMap import scala.math.Ordering -import scala.util.Random -import scala.util.hashing.MurmurHash3 -import com.twitter.algebird._ -import com.stripe.bonsai._ +import com.stripe.bonsai.FullBinaryTreeOps /** * A `TreeTraversal` provides a way to find all of the leaves in a tree that @@ -57,86 +52,6 @@ trait TreeTraversal[Tree, K, V, T, A] { def searchNode(node: treeOps.Node, row: Map[K, V], id: Option[String]): Stream[LeafLabel[T, A]] } -/** - * Simple data type that provides rules to order nodes during - * traversal. - * - * In some cases subtypes of Reorder will also wraps RNG state, for - * instances that need to randomly select instances. Thus, Reorder is - * not guaranteed to be referentially-transparent. Fresh instances - * should be used with each traversal. - */ -trait Reorder[A] { - def setSeed(seed: Option[String]): Reorder[A] - def apply[N](n1: N, n2: N, f: N => A): (N, N) -} - -object Reorder { - - // Traverse into the left node first. - def unchanged[A]: Reorder[A] = - new UnchangedReorder() - - // Traverse into a random node first. Each node has equal - // probability of being selected. - def shuffled[A]: Reorder[A] = - new ShuffledReorder(new Random) - - // Traverse into the node with the higher weight first. - def weightedDepthFirst[A](implicit ev: Ordering[A]): Reorder[A] = - new WeightedReorder() - - // Traverse into a random node, but choose the random node based on - // the ratio between its weight and the total weight of both nodes. - // - // If the left node's weight was 10, and the right node's weight was - // 5, the left node would be picked 2/3 of the time. - def probabilisticWeightedDepthFirst[A](conversion: A => Double): Reorder[A] = - new ProbabilisticWeighted(new Random, conversion) - - class UnchangedReorder[A] extends Reorder[A] { - def setSeed(seed: Option[String]): Reorder[A] = - this - def apply[N](n1: N, n2: N, f: N => A): (N, N) = - (n1, n2) - } - - class ShuffledReorder[A](r: Random) extends Reorder[A] { - def setSeed(seed: Option[String]): Reorder[A] = - seed match { - case Some(s) => - new ShuffledReorder(new Random(MurmurHash3.stringHash(s))) - case None => - this - } - def apply[N](n1: N, n2: N, f: N => A): (N, N) = - if (r.nextBoolean) (n1, n2) else (n2, n1) - } - - class WeightedReorder[A](implicit ev: Ordering[A]) extends Reorder[A] { - def setSeed(seed: Option[String]): Reorder[A] = - this - def apply[N](n1: N, n2: N, f: N => A): (N, N) = - if (ev.compare(f(n1), f(n2)) >= 0) (n1, n2) else (n2, n1) - } - - class ProbabilisticWeighted[A](r: Random, conversion: A => Double) extends Reorder[A] { - def setSeed(seed: Option[String]): Reorder[A] = - seed match { - case Some(s) => - new ProbabilisticWeighted(new Random(MurmurHash3.stringHash(s)), conversion) - case None => - this - } - def apply[N](n1: N, n2: N, f: N => A): (N, N) = { - val w1 = conversion(f(n1)) - val w2 = conversion(f(n2)) - val sum = w1 + w2 - if (r.nextDouble * sum < w1) (n1, n2) else (n2, n1) - } - } -} - object TreeTraversal { def search[Tree, K, V, T, A](tree: Tree, row: Map[K, V], id: Option[String] = None)(implicit ev: TreeTraversal[Tree, K, V, T, A]): Stream[LeafLabel[T, A]] = @@ -197,25 +112,23 @@ case class DepthFirstTreeTraversal[Tree, K, V, T, A](reorder: Reorder[A])(implic // pull the A value out of a branch or leaf. val getAnnotation: Node => A = - n => foldNode(n)((_, _, bl) => bl._3, ll => ll._3) + node => foldNode(node)((_, _, bl) => bl._3, ll => ll._3) // construct a singleton stream from a leaf - //val Empty: Stream[LeafLabel[T, A]] = Stream.empty val leafF: LeafLabel[T, A] => Stream[LeafLabel[T, A]] = _ #:: Stream.empty + lazy val reorderF: (Node, Node) => Stream[LeafLabel[T, A]] = + (n1, n2) => recurse(n1) #::: recurse(n2) + // recurse into branch nodes, going left, right, or both, // depending on what our predicate says. lazy val branchF: (Node, Node, BranchLabel[K, V, A]) => Stream[LeafLabel[T, A]] = { case (lc, rc, (k, p, _)) => p(row.get(k)) match { - case Some(true) => - recurse(lc) - case Some(false) => - recurse(rc) - case None => - val (c1, c2) = r(lc, rc, getAnnotation) - recurse(c1) #::: recurse(c2) + case Some(true) => recurse(lc) + case Some(false) => recurse(rc) + case None => r(lc, rc, getAnnotation, reorderF) } } From cd32e2b2e2b870f5d14baaf7ba4d504270da88b2 Mon Sep 17 00:00:00 2001 From: Erik Osheim Date: Tue, 19 Jan 2016 17:39:04 -0500 Subject: [PATCH 09/11] Explicitly pass RNG seeds in more cases. This change makes the RNG seed handling more explicit, and makes it clear how it is initialized in the "defualt" case. It is still not secure (it's initialized from the lower 32-bits of System.nanoTime) but at least it's clearer what is going on. --- .../scala/com/stripe/brushfire/Reorder.scala | 8 +++--- .../com/stripe/brushfire/TreeTraversal.scala | 26 ++++++++++--------- .../stripe/brushfire/TreeTraversalSpec.scala | 2 +- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Reorder.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Reorder.scala index 5c5feb1..3fdd50d 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Reorder.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Reorder.scala @@ -66,8 +66,8 @@ object Reorder { * Reorder instance that traverses into a random node first. Each * node has equal probability of being selected. */ - def shuffled[A]: Reorder[A] = - new ShuffledReorder(new Random) + def shuffled[A](seed: Int): Reorder[A] = + new ShuffledReorder(new Random(seed)) /** * Reorder instance that traverses into a random node, but choose @@ -77,8 +77,8 @@ object Reorder { * If the left node's weight was 10, and the right node's weight was * 5, the left node would be picked 2/3 of the time. */ - def probabilisticWeightedDepthFirst[A](conversion: A => Double): Reorder[A] = - new ProbabilisticWeighted(new Random, conversion) + def probabilisticWeightedDepthFirst[A](seed: Int, conversion: A => Double): Reorder[A] = + new ProbabilisticWeighted(new Random(seed), conversion) class UnchangedReorder[A] extends Reorder[A] { def setSeed(seed: Option[String]): Reorder[A] = diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala index 45a92d3..b3ebb76 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/TreeTraversal.scala @@ -64,16 +64,6 @@ object TreeTraversal { implicit def depthFirst[Tree, K, V, T, A](implicit treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]]): TreeTraversal[Tree, K, V, T, A] = DepthFirstTreeTraversal(Reorder.unchanged) - /** - * A depth first search for matching leaves, randomly choosing the order of - * child candidate nodes to traverse at each step. Since it is depth-first, - * after a node is chosen to be traversed, all of the matching leafs that - * descend from that node are traversed before moving onto the node's - * sibling. - */ - def randomDepthFirst[Tree, K, V, T, A](implicit treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]]): TreeTraversal[Tree, K, V, T, A] = - DepthFirstTreeTraversal(Reorder.shuffled) - /** * A depth-first search for matching leaves, where the candidate child nodes * for a given parent node are traversed in reverse order of their @@ -83,6 +73,16 @@ object TreeTraversal { def weightedDepthFirst[Tree, K, V, T, A: Ordering](implicit treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]]): TreeTraversal[Tree, K, V, T, A] = DepthFirstTreeTraversal(Reorder.weightedDepthFirst) + /** + * A depth first search for matching leaves, randomly choosing the order of + * child candidate nodes to traverse at each step. Since it is depth-first, + * after a node is chosen to be traversed, all of the matching leafs that + * descend from that node are traversed before moving onto the node's + * sibling. + */ + def randomDepthFirst[Tree, K, V, T, A](seed: Option[Int] = None)(implicit treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]]): TreeTraversal[Tree, K, V, T, A] = + DepthFirstTreeTraversal(Reorder.shuffled(seed.getOrElse(System.nanoTime.toInt))) + /** * A depth-first search for matching leaves, where the candidate child leaves * of a parent node are randomly shuffled, but with nodes with higher weight @@ -94,8 +94,10 @@ object TreeTraversal { * proportional to its probability of being sampled, relative to all the * other elements still in the set. */ - def probabilisticWeightedDepthFirst[Tree, K, V, T, A](implicit treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]], conversion: A => Double): TreeTraversal[Tree, K, V, T, A] = - DepthFirstTreeTraversal(Reorder.probabilisticWeightedDepthFirst(conversion)) + def probabilisticWeightedDepthFirst[Tree, K, V, T, A](seed: Option[Int] = None)(implicit treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]], conversion: A => Double): TreeTraversal[Tree, K, V, T, A] = { + val n = seed.getOrElse(System.nanoTime.toInt) + DepthFirstTreeTraversal(Reorder.probabilisticWeightedDepthFirst(n, conversion)) + } } case class DepthFirstTreeTraversal[Tree, K, V, T, A](reorder: Reorder[A])(implicit val treeOps: FullBinaryTreeOps[Tree, BranchLabel[K, V, A], LeafLabel[T, A]]) extends TreeTraversal[Tree, K, V, T, A] { diff --git a/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala b/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala index d5d1aee..7772ebf 100644 --- a/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala +++ b/brushfire-core/src/test/scala/com/stripe/brushfire/TreeTraversalSpec.scala @@ -102,7 +102,7 @@ class TreeTraversalSpec extends WordSpec with Matchers with Checkers { // seed = "c", 0.28329487121994257, 0.6209168529111834, 0.896329641267321, 0.4295933416724764, 0.28764827423012496 // seed = "z", 0.3982796920464978, 0.09452393725015651, 0.2645674766831084, 0.45670292528849277, 0.4919659310853626 - val traversal = TreeTraversal.probabilisticWeightedDepthFirst[TreeSDD[Int], String, Double, Double, Int] + val traversal = TreeTraversal.probabilisticWeightedDepthFirst[TreeSDD[Int], String, Double, Double, Int]() "choose a predictable node from a split" in { val split1 = split("f1", LessThan(0D), LeafNode(0, -1D, 2703), LeafNode(1, 1D, 10000 - 2703)) From 54e6f711e9016720a4765c122f18d97be7f07a6d Mon Sep 17 00:00:00 2001 From: Erik Osheim Date: Wed, 20 Jan 2016 14:38:04 -0500 Subject: [PATCH 10/11] Add comment explaining Reorder#apply. The API is a bit fiddly, and this comment helps explain why. --- .../src/main/scala/com/stripe/brushfire/Reorder.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/brushfire-core/src/main/scala/com/stripe/brushfire/Reorder.scala b/brushfire-core/src/main/scala/com/stripe/brushfire/Reorder.scala index 3fdd50d..a778df8 100644 --- a/brushfire-core/src/main/scala/com/stripe/brushfire/Reorder.scala +++ b/brushfire-core/src/main/scala/com/stripe/brushfire/Reorder.scala @@ -15,6 +15,13 @@ import scala.util.hashing.MurmurHash3 * * Reorder will also continue to recurse into a given structure using * a provided callback method. + * + * The reason that the node type is provided to `apply` (instead of + * `Reorder`) has to do with how generic trees are specified. Since + * TreeOps uses path-dependent types to specify node types, the + * current design is a bit of a kludge that makes it easy to get the + * types right when handling a Reorder instance to a TreeTraversal + * instance (which is parameterized on a generic tree type). */ trait Reorder[A] { From 58eccd4471c60875be18eecd9e41a1e2da0212a0 Mon Sep 17 00:00:00 2001 From: Erik Osheim Date: Wed, 27 Jan 2016 16:06:40 -0500 Subject: [PATCH 11/11] Bump version to 0.7.0. There have been a bunch of big brushfire changes, so it seems appropriate to update the version to 0.7.0. --- version.sbt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.sbt b/version.sbt index 981e93d..a3ddb27 100644 --- a/version.sbt +++ b/version.sbt @@ -1 +1 @@ -version in ThisBuild := "0.6.4-SNAPSHOT" \ No newline at end of file +version in ThisBuild := "0.7.0-SNAPSHOT"