Skip to content
This repository was archived by the owner on Apr 8, 2021. It is now read-only.
1 change: 1 addition & 0 deletions brushfire-core/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ libraryDependencies ++= {
Seq(
algebirdCore,
bijectionJson,
bonsai,
chillBijection,
jacksonMapper,
jacksonXC,
Expand Down
310 changes: 173 additions & 137 deletions brushfire-core/src/main/scala/com/stripe/brushfire/AnnotatedTree.scala

Large diffs are not rendered by default.

25 changes: 21 additions & 4 deletions brushfire-core/src/main/scala/com/stripe/brushfire/Brushfire.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,31 @@ trait Splitter[V, T] {
}

/** Candidate split for a tree node */
trait Split[V, T] {
def predicates: Iterable[(Predicate[V], T)]
/**
 * A candidate binary split: a predicate together with the target
 * distribution attached to each of the two sides.
 *
 * @param predicate the test stored in the split node built by `createSplitNode`
 * @param leftDistribution target distribution placed in the left leaf
 * @param rightDistribution target distribution placed in the right leaf
 *
 * NOTE(review): which side receives rows for which the predicate holds is
 * not visible here -- confirm against the tree traversal code.
 */
case class Split[V, T](predicate: Predicate[V], leftDistribution: T, rightDistribution: T) {

/**
* Given a feature key, create a SplitNode from this Split.
*
* The left leaf is given index 0 and the right leaf index 1.
*
* Note that the leaves of this node will likely need to be
* renumbered if this node is put into a larger tree.
*
* @param feature the feature key stored in the resulting node
* @return a split node whose two children are leaves holding
*         `leftDistribution` and `rightDistribution`
*/
def createSplitNode[K](feature: K): SplitNode[K, V, T, Unit] =
SplitNode(feature, predicate, LeafNode(0, leftDistribution), LeafNode(1, rightDistribution))
}


/** Evaluates the goodness of a candidate split */
trait Evaluator[V, T] {
/** returns a (possibly transformed) version of the input split, and a numeric goodness score */
def evaluate(split: Split[V, T]): (Split[V, T], Double)

/**
* Evaluate the fitness of a candidate split.
*
* This method may transform the split, in which case the score
* applies to the split that is returned. The result is optional to
* handle cases where a split is not valid (for example, if one side
* of a split is "empty" in some sense).
*
* @param split the candidate split to score
* @return `Some((split, score))` with the (possibly transformed) split
*         and its numeric score, or `None` when the split is not valid
*/
def evaluate(split: Split[V, T]): Option[(Split[V, T], Double)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be a question for @avibryant too, but what does None mean and could we put that in the docs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll try to document this, and @avibryant can review it for accuracy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this help?

  /**
   * Evaluate the fitness of a candidate split.
   *
   * This method may transform the split, in which case the score
   * applies to the split that is returned. The result is optional to
   * handle cases where a split is not valid (for example, if one side
   * of a split is "empty" in some sense).
   */

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

}

/** Provides stopping conditions which guide when splits will be attempted */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,6 @@ object Dispatched {
def continuous[C](c: C) = Continuous(c)
def sparse[D](d: D) = Sparse(d)

def wrapSplits[X, T, A: Ordering, B, C: Ordering, D](splits: Iterable[Split[X, T]])(fn: X => Dispatched[A, B, C, D]) = {
splits.map { split =>
new Split[Dispatched[A, B, C, D], T] {
def predicates = split.predicates.map {
case (pred, p) => (pred.map(fn), p)
}
}
}
}
/**
 * Re-express each split over the `Dispatched` value type by mapping its
 * predicate through `fn`; both distributions are carried over untouched.
 */
def wrapSplits[X, T, A: Ordering, B, C: Ordering, D](splits: Iterable[Split[X, T]])(fn: X => Dispatched[A, B, C, D]) =
  for (split <- splits)
    yield Split(split.predicate.map(fn), split.leftDistribution, split.rightDistribution)
}
50 changes: 18 additions & 32 deletions brushfire-core/src/main/scala/com/stripe/brushfire/Evaluators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import com.twitter.algebird._

case class ChiSquaredEvaluator[V, L, W](implicit weightMonoid: Monoid[W], weightDouble: W => Double)
extends Evaluator[V, Map[L, W]] {
def evaluate(split: Split[V, Map[L, W]]) = {
val rows = split.predicates.map { _._2 }.filter { _.nonEmpty }
def evaluate(split: Split[V, Map[L, W]]): Option[(Split[V, Map[L, W]], Double)] = {
val Split(_, left, right) = split
val rows = (left :: right :: Nil).filter(_.nonEmpty)
if (rows.size > 1) {
val n = weightMonoid.sum(rows.flatMap { _.values })
val rowTotals = rows.map { row => weightMonoid.sum(row.values) }.toList
Expand All @@ -20,43 +21,28 @@ case class ChiSquaredEvaluator[V, L, W](implicit weightMonoid: Monoid[W], weight
val delta = observed - expected
(delta * delta) / expected
}).sum
(split, testStatistic)
} else
(EmptySplit[V, Map[L, W]](), Double.NegativeInfinity)
Some((split, testStatistic))
} else {
None
}
}
}

case class MinWeightEvaluator[V, L, W: Monoid](minWeight: W => Boolean, wrapped: Evaluator[V, Map[L, W]])
extends Evaluator[V, Map[L, W]] {
def evaluate(split: Split[V, Map[L, W]]) = {
val (baseSplit, baseScore) = wrapped.evaluate(split)
if (baseSplit.predicates.forall {
case (pred, freq) =>
val totalWeight = Monoid.sum(freq.values)
minWeight(totalWeight)
})
(baseSplit, baseScore)
else
(EmptySplit[V, Map[L, W]](), Double.NegativeInfinity)
}
}

case class EmptySplit[V, P]() extends Split[V, P] {
val predicates = Nil
}
private[this] def test(dist: Map[L, W]): Boolean = minWeight(Monoid.sum(dist.values))

case class ErrorEvaluator[V, T, P, E](error: Error[T, P, E], voter: Voter[T, P])(fn: E => Double)
extends Evaluator[V, T] {
def evaluate(split: Split[V, T]) = {
val totalErrorOption =
error.semigroup.sumOption(
split
.predicates
.map { case (_, target) => error.create(target, voter.combine(Some(target))) })

totalErrorOption match {
case Some(totalError) => (split, -fn(totalError))
case None => (EmptySplit[V, T](), Double.NegativeInfinity)
/**
 * Delegate scoring to `wrapped`, then discard the result unless both
 * sides of the returned split pass the minimum-weight `test`.
 */
def evaluate(split: Split[V, Map[L, W]]): Option[(Split[V, Map[L, W]], Double)] =
  wrapped.evaluate(split).filter { case (candidate, _) =>
    test(candidate.leftDistribution) && test(candidate.rightDistribution)
  }
}

/**
 * Scores a split by the negated combined error of its two sides.
 *
 * For each side, the error is created from that side's target and the
 * prediction the voter produces from that same target; the two errors
 * are combined with the error semigroup and converted to a number by `fn`.
 */
case class ErrorEvaluator[V, T, P, E](error: Error[T, P, E], voter: Voter[T, P])(fn: E => Double) extends Evaluator[V, T] {

  /** Always returns the input split unchanged, paired with `-fn(total error)`. */
  def evaluate(split: Split[V, T]): Option[(Split[V, T], Double)] = {
    def errorOf(target: T): E = error.create(target, voter.combine(Some(target)))
    val leftError = errorOf(split.leftDistribution)
    val rightError = errorOf(split.rightDistribution)
    val total = error.semigroup.plus(leftError, rightError)
    Some((split, -fn(total)))
  }
}
150 changes: 71 additions & 79 deletions brushfire-core/src/main/scala/com/stripe/brushfire/Injections.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,50 +71,51 @@ object JsonInjections {
}
}

/**
 * JSON codec for [[Predicate]] values.
 *
 * Every predicate is written as an object with a single field whose name
 * identifies the predicate kind: `eq`, `lt`, `not`, `or` or `exists`
 * (`exists` holds JSON null when `IsPresent` wraps no inner predicate).
 *
 * `invert` dispatches on the first field name of the node. Problems it
 * detects itself (an unrecognised field name, or `lt` without an
 * `Ordering[V]`) are returned as a `Failure` instead of being thrown, so
 * callers working with the resulting `Try` can handle them.
 *
 * @param vInj codec for the values held by `EqualTo` and `LessThan`
 * @param ord  ordering needed to rebuild `LessThan`; may be `null`, in which
 *             case decoding an `lt` node fails
 */
implicit def predicateJsonInjection[V](implicit vInj: JsonNodeInjection[V], ord: Ordering[V] = null): JsonNodeInjection[Predicate[V]] = {
  new AbstractJsonNodeInjection[Predicate[V]] {
    def apply(pred: Predicate[V]) = {
      val obj = JsonNodeFactory.instance.objectNode
      pred match {
        case IsPresent(None) => obj.put("exists", JsonNodeFactory.instance.nullNode)
        case IsPresent(Some(inner)) => obj.put("exists", toJsonNode(inner)(predicateJsonInjection))
        case EqualTo(v) => obj.put("eq", toJsonNode(v))
        case LessThan(v) => obj.put("lt", toJsonNode(v))
        case Not(inner) => obj.put("not", toJsonNode(inner)(predicateJsonInjection))
        case AnyOf(preds) =>
          val ary = JsonNodeFactory.instance.arrayNode
          preds.foreach { p => ary.add(toJsonNode(p)(predicateJsonInjection)) }
          obj.put("or", ary)
      }
      obj
    }

    override def invert(n: JsonNode): Try[Predicate[V]] = {
      n.getFieldNames.asScala.toList.headOption match {
        case Some("eq") => fromJsonNode[V](n.get("eq")).map { EqualTo(_) }
        case Some("lt") =>
          // LessThan cannot be constructed without an ordering; report it
          // through the Try rather than throwing.
          if (ord == null)
            Failure(new RuntimeException("No Ordering[V] supplied but less than used"))
          else
            fromJsonNode[V](n.get("lt")).map { LessThan(_) }
        case Some("not") => fromJsonNode[Predicate[V]](n.get("not")).map { Not(_) }
        case Some("or") => fromJsonNode[List[Predicate[V]]](n.get("or")).map { AnyOf(_) }
        case Some("exists") =>
          val predNode = n.get("exists")
          if (predNode.isNull) Success(IsPresent[V](None))
          else fromJsonNode[Predicate[V]](predNode).map(p => IsPresent(Some(p)))
        case _ => Failure(new RuntimeException("Not a predicate node: " + n))
      }
    }
  }
}

implicit def treeJsonInjection[K, V, T](
implicit kInj: JsonNodeInjection[K],
pInj: JsonNodeInjection[T],
vInj: JsonNodeInjection[V],
mon: Monoid[T],
ord: Ordering[V] = null): JsonNodeInjection[Tree[K, V, T]] = {

/**
 * JSON codec for [[Predicate]] values, using the value codec and the
 * (possibly `null`) ordering of the enclosing definition.
 *
 * Every predicate is written as an object with a single field whose name
 * identifies the predicate kind: `eq`, `lt`, `not`, `or` or `exists`.
 * `invert` dispatches on the first field name; an unrecognised name, or
 * `lt` without an ordering, is returned as a `Failure` instead of thrown.
 */
implicit def predicateJsonNodeInjection: JsonNodeInjection[Predicate[V]] =
  new AbstractJsonNodeInjection[Predicate[V]] {
    def apply(pred: Predicate[V]) = {
      val obj = JsonNodeFactory.instance.objectNode
      pred match {
        case IsPresent(None) => obj.put("exists", JsonNodeFactory.instance.nullNode)
        case IsPresent(Some(inner)) => obj.put("exists", toJsonNode(inner)(predicateJsonNodeInjection))
        case EqualTo(v) => obj.put("eq", toJsonNode(v))
        case LessThan(v) => obj.put("lt", toJsonNode(v))
        case Not(inner) => obj.put("not", toJsonNode(inner)(predicateJsonNodeInjection))
        case AnyOf(preds) =>
          val ary = JsonNodeFactory.instance.arrayNode
          preds.foreach { p => ary.add(toJsonNode(p)(predicateJsonNodeInjection)) }
          obj.put("or", ary)
      }
      obj
    }

    override def invert(n: JsonNode): Try[Predicate[V]] = {
      n.getFieldNames.asScala.toList.headOption match {
        case Some("eq") => fromJsonNode[V](n.get("eq")).map { EqualTo(_) }
        case Some("lt") =>
          // LessThan cannot be constructed without an ordering; report it
          // through the Try rather than throwing.
          if (ord == null)
            Failure(new RuntimeException("No Ordering[V] supplied but less than used"))
          else
            fromJsonNode[V](n.get("lt")).map { LessThan(_) }
        case Some("not") => fromJsonNode[Predicate[V]](n.get("not")).map { Not(_) }
        case Some("or") => fromJsonNode[List[Predicate[V]]](n.get("or")).map { AnyOf(_) }
        case Some("exists") =>
          val predNode = n.get("exists")
          if (predNode.isNull) Success(IsPresent[V](None))
          else fromJsonNode[Predicate[V]](predNode).map(p => IsPresent(Some(p)))
        case _ => Failure(new RuntimeException("Not a predicate node"))
      }
    }
  }

implicit def nodeJsonNodeInjection: JsonNodeInjection[Node[K, V, T, Unit]] =
new AbstractJsonNodeInjection[Node[K, V, T, Unit]] {
def apply(node: Node[K, V, T, Unit]) = node match {
Expand All @@ -124,57 +125,48 @@ object JsonInjections {
obj.put("distribution", toJsonNode(target))
obj

case SplitNode(children) =>
val ary = JsonNodeFactory.instance.arrayNode
children.foreach {
case (feature, predicate, child) =>
val obj = JsonNodeFactory.instance.objectNode
obj.put("feature", toJsonNode(feature))
obj.put("predicate", toJsonNode(predicate))
obj.put("display", toJsonNode(Predicate.display(predicate)))
obj.put("children", toJsonNode(child)(nodeJsonNodeInjection))
ary.add(obj)
}
ary
case SplitNode(k, p, lc, rc, _) =>
val obj = JsonNodeFactory.instance.objectNode
obj.put("key", toJsonNode(k))
obj.put("predicate", toJsonNode(p)(predicateJsonInjection))
obj.put("left", toJsonNode(lc)(nodeJsonNodeInjection))
obj.put("right", toJsonNode(rc)(nodeJsonNodeInjection))
obj
}

def tryChild(node: JsonNode, property: String) = Try {
def tryChild(node: JsonNode, property: String): Try[JsonNode] = {
val child = node.get(property)
assert(child != null, property + " != null")
child
if (child == null) Failure(new IllegalArgumentException(property + " != null"))
else Success(child)
}

override def invert(n: JsonNode) = {
Option(n.get("leaf")) match {
case Some(indexNode) => fromJsonNode[Int](indexNode).flatMap { index =>
fromJsonNode[T](n.get("distribution")).map { target =>
LeafNode(index, target)
}
}

case None =>
val children = n.getElements.asScala.map { c =>
for (
featureNode <- tryChild(c, "feature");
feature <- fromJsonNode[K](featureNode);
predicateNode <- tryChild(c, "predicate");
predicate <- fromJsonNode[Predicate[V]](predicateNode);
childNode <- tryChild(c, "children");
child <- fromJsonNode[Node[K, V, T, Unit]](childNode)
) yield (feature, predicate, child)
}.toList

children.find { _.isFailure } match {
case Some(Failure(e)) => Failure(InversionFailure(n, e))
case _ => Success(SplitNode[K, V, T, Unit](children.map { _.get }))
}
}
/**
 * Decode the field `property` of `node` as a `T`.
 *
 * Fails with an `IllegalArgumentException` when the field is absent;
 * otherwise returns whatever decoding the field yields.
 */
def tryLoad[T: JsonNodeInjection](node: JsonNode, property: String): Try[T] =
  Option(node.get(property)) match {
    case Some(child) => fromJsonNode[T](child)
    case None => Failure(new IllegalArgumentException(property + " != null"))
  }

/**
 * Rebuild a node from JSON. An object with a "leaf" field is read as a
 * leaf ("leaf" + "distribution"); anything else is read as a split
 * ("key", "predicate", "left", "right"), decoding both children recursively.
 */
override def invert(n: JsonNode): Try[Node[K, V, T, Unit]] = {
  def readLeaf: Try[Node[K, V, T, Unit]] =
    for {
      idx <- tryLoad[Int](n, "leaf")
      dist <- tryLoad[T](n, "distribution")
    } yield LeafNode(idx, dist)

  def readSplit: Try[Node[K, V, T, Unit]] =
    for {
      key <- tryLoad[K](n, "key")
      pred <- tryLoad[Predicate[V]](n, "predicate")
      lhs <- tryLoad[Node[K, V, T, Unit]](n, "left")(nodeJsonNodeInjection)
      rhs <- tryLoad[Node[K, V, T, Unit]](n, "right")(nodeJsonNodeInjection)
    } yield SplitNode(key, pred, lhs, rhs)

  if (n.has("leaf")) readLeaf else readSplit
}
}

new AbstractJsonNodeInjection[Tree[K, V, T]] {
def apply(tree: Tree[K, V, T]) = toJsonNode(tree.root)
override def invert(n: JsonNode) = fromJsonNode[Node[K, V, T, Unit]](n).map { root => Tree(root) }
def apply(tree: Tree[K, V, T]): JsonNode =
toJsonNode(tree.root)
override def invert(n: JsonNode): Try[Tree[K, V, T]] =
fromJsonNode[Node[K, V, T, Unit]](n).map(root => Tree(root))
}
}

Expand Down
25 changes: 19 additions & 6 deletions brushfire-core/src/main/scala/com/stripe/brushfire/Predicate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ package com.stripe.brushfire
* `true` or `false`. It is generally used within a [[Tree]] to decide which
* branch to follow while trying to classify a row/feature vector.
*/
sealed trait Predicate[V] extends (Option[V] => Boolean) {
sealed trait Predicate[V] extends (Option[V] => Option[Boolean]) {

/**
 * Apply this predicate and collapse the result to a plain `Boolean`:
 * an undecided result (`None`) counts as `true`.
 */
def run(o: Option[V]): Boolean =
  apply(o).getOrElse(true)

/**
* Map the value types of this [[Predicate]] using `f`.
Expand All @@ -24,7 +30,7 @@ sealed trait Predicate[V] extends (Option[V] => Boolean) {
* defined and is equal to `value` (according to the input's `equals` method).
*/
case class EqualTo[V](value: V) extends Predicate[V] {
def apply(v: Option[V]) = v.isEmpty || (v.get == value)
def apply(v: Option[V]): Option[Boolean] = v.map(_ == value)
}

/**
Expand All @@ -33,23 +39,26 @@ case class EqualTo[V](value: V) extends Predicate[V] {
* of type `V` to handle the comparison.
*/
case class LessThan[V](value: V)(implicit ord: Ordering[V]) extends Predicate[V] {
def apply(v: Option[V]) = v.isEmpty || ord.lt(v.get, value)
def apply(v: Option[V]): Option[Boolean] = v.map(x => ord.lt(x, value))
}

/**
* A [[Predicate]] that negates the result of `pred`: `Some(true)` becomes
* `Some(false)` and vice versa, while `None` is passed through unchanged.
*/
case class Not[V](pred: Predicate[V]) extends Predicate[V] {
def apply(v: Option[V]) = v.isEmpty || !pred(v)
def apply(v: Option[V]): Option[Boolean] = pred(v).map(!_)
}

/**
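* A [[Predicate]] that returns `Some(true)` if any of the predicates in `preds`
* returns `Some(true)`, `Some(false)` if none does but at least one returns a
* defined result, and `None` if every predicate returns `None` (or `preds` is empty).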
*/
case class AnyOf[V](preds: Seq[Predicate[V]]) extends Predicate[V] {
def apply(v: Option[V]) = preds.exists { p => p(v) }
/**
 * Evaluate the predicates lazily, ignoring those that return `None`.
 * Returns `None` when no predicate yields a defined result; otherwise
 * `Some(true)` if any defined result is `true`, else `Some(false)`.
 */
def apply(v: Option[V]): Option[Boolean] = {
  val decided = preds.iterator.flatMap(p => p(v))
  if (decided.hasNext) Some(decided.exists(identity)) else None
}
}

/**
Expand All @@ -63,7 +72,11 @@ case class AnyOf[V](preds: Seq[Predicate[V]]) extends Predicate[V] {
* value is present (not missing).
*/
case class IsPresent[V](pred: Option[Predicate[V]]) extends Predicate[V] {
def apply(v: Option[V]) = v.isDefined && pred.fold(true)(_(v))
/**
 * Always returns a defined result: `true` only when the value is present
 * and, if an inner predicate was supplied, that predicate returns
 * `Some(true)` for it.
 */
def apply(v: Option[V]): Option[Boolean] =
  Some(v.isDefined && pred.forall(inner => inner(v) == Some(true)))
}

object Predicate {
Expand Down
Loading