diff --git a/scalding-base/src/main/scala/com/twitter/scalding/typed/OptimizationRules.scala b/scalding-base/src/main/scala/com/twitter/scalding/typed/OptimizationRules.scala
index 1827df769..31d43cf54 100644
--- a/scalding-base/src/main/scala/com/twitter/scalding/typed/OptimizationRules.scala
+++ b/scalding-base/src/main/scala/com/twitter/scalding/typed/OptimizationRules.scala
@@ -1034,7 +1034,7 @@ object OptimizationRules {
       case MergedTypedPipe(a, EmptyTypedPipe) => a
       case ReduceStepPipe(rs: ReduceStep[_, _, _]) if rs.mapped == EmptyTypedPipe => EmptyTypedPipe
       case SumByLocalKeys(EmptyTypedPipe, _) => EmptyTypedPipe
-      case TrappedPipe(EmptyTypedPipe, _) => EmptyTypedPipe
+      case TrappedPipe(EmptyTypedPipe, _) => EmptyTypedPipe
       case CoGroupedPipe(cgp) if emptyCogroup(cgp) => EmptyTypedPipe
       case WithOnComplete(EmptyTypedPipe, _) =>
         EmptyTypedPipe // there is nothing to do, so we never have workers complete
diff --git a/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamBackend.scala b/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamBackend.scala
index f71a80a6a..0d290fd6e 100644
--- a/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamBackend.scala
+++ b/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamBackend.scala
@@ -1,12 +1,14 @@
 package com.twitter.scalding.beam_backend
 
-import com.twitter.scalding.dagon.{FunctionK, Memoize, Rule}
 import com.twitter.chill.KryoInstantiator
 import com.twitter.chill.config.ScalaMapConfig
 import com.twitter.scalding.Config
 import com.twitter.scalding.beam_backend.BeamOp.{CoGroupedOp, MergedBeamOp}
+import com.twitter.scalding.dagon.{FunctionK, Memoize, Rule}
 import com.twitter.scalding.serialization.KryoHadoop
+import com.twitter.scalding.typed.OptimizationRules._
 import com.twitter.scalding.typed._
+import com.twitter.scalding.typed.cascading_backend.CascadingExtensions.ConfigCascadingExtensions
 import com.twitter.scalding.typed.functions.{
   FilterKeysToFilter,
   FlatMapValuesToFlatMap,
@@ -14,8 +16,6 @@ import com.twitter.scalding.typed.functions.{
   ScaldingPriorityQueueMonoid
 }
 
-import com.twitter.scalding.typed.cascading_backend.CascadingExtensions.ConfigCascadingExtensions
-
 object BeamPlanner {
   def plan(
       config: Config,
@@ -61,8 +61,15 @@ object BeamPlanner {
           BeamOp.Source(config, src, srcs(src))
         case (IterablePipe(iterable), _) =>
           BeamOp.FromIterable(iterable, kryoCoder)
-        case (wd: WithDescriptionTypedPipe[a], rec) =>
-          rec[a](wd.input)
+        case (wd: WithDescriptionTypedPipe[_], rec) => {
+          val op = rec(wd.input)
+          wd.descriptions match {
+            case head :: _ =>
+              op.withName(head._1)
+            case Nil =>
+              op
+          }
+        }
         case (SumByLocalKeys(pipe, sg), rec) =>
           val op = rec(pipe)
           config.getMapSideAggregationThreshold match {
@@ -97,7 +104,10 @@ object BeamPlanner {
             uir.evidence.subst[BeamOpT](sortedOp)
           }
           go(ivsr)
-        case (ReduceStepPipe(ValueSortedReduce(keyOrdering, pipe, valueSort, reduceFn, _, _)), rec) =>
+        case (
+              ReduceStepPipe(ValueSortedReduce(keyOrdering, pipe, valueSort, reduceFn, _, _)),
+              rec
+            ) =>
           val op = rec(pipe)
           op.sortedMapGroup(reduceFn)(keyOrdering, valueSort, kryoCoder)
         case (ReduceStepPipe(IteratorMappedReduce(keyOrdering, pipe, reduceFn, _, _)), rec) =>
@@ -116,7 +126,7 @@ object BeamPlanner {
             val ops: Seq[BeamOp[(K, Any)]] = cg.inputs.map(tp => rec(tp))
             CoGroupedOp(cg, ops)
           }
-          go(cg)
+          if (cg.descriptions.isEmpty) go(cg) else go(cg).withName(cg.descriptions.last)
         case (Fork(input), rec) =>
           rec(input)
         case (m @ MergedTypedPipe(_, _), rec) =>
@@ -137,7 +147,21 @@ object BeamPlanner {
 
   def defaultOptimizationRules(config: Config): Seq[Rule[TypedPipe]] = {
     def std(forceHash: Rule[TypedPipe]) =
-      OptimizationRules.standardMapReduceRules :::
+      List(
+        // phase 0, add explicit forks to not duplicate pipes on fanout below
+        AddExplicitForks,
+        RemoveUselessFork,
+        // phase 1, compose flatMap/map, move descriptions down, defer merge, filter pushup etc...
+        IgnoreNoOpGroup.orElse(composeSame).orElse(FilterKeysEarly).orElse(DeferMerge),
+        // phase 2, combine different kinds of mapping operations into flatMaps, including redundant merges
+        composeIntoFlatMap
+          .orElse(simplifyEmpty)
+          .orElse(DiamondToFlatMap)
+          .orElse(ComposeDescriptions)
+          .orElse(MapValuesInReducers),
+        // phase 3, remove duplicate forces/forks (e.g. .fork.fork or .forceToDisk.fork, ....)
+        RemoveDuplicateForceFork
+      ) :::
       List(
         OptimizationRules.FilterLocally, // after filtering, we may have filtered to nothing, lets see
         OptimizationRules.simplifyEmpty,
diff --git a/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamOp.scala b/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamOp.scala
index 32cc6fad3..968175daf 100644
--- a/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamOp.scala
+++ b/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamOp.scala
@@ -1,14 +1,14 @@
 package com.twitter.scalding.beam_backend
 
-import com.twitter.scalding.dagon.Memoize
 import com.twitter.algebird.Semigroup
 import com.twitter.scalding.Config
 import com.twitter.scalding.beam_backend.BeamFunctions._
 import com.twitter.scalding.beam_backend.BeamJoiner.MultiJoinFunction
+import com.twitter.scalding.dagon.Memoize
 import com.twitter.scalding.serialization.Externalizer
+import com.twitter.scalding.typed.{CoGrouped, Input}
 import com.twitter.scalding.typed.functions.ComposedFunctions.ComposedMapGroup
 import com.twitter.scalding.typed.functions.{EmptyGuard, MapValueStream, ScaldingPriorityQueueMonoid, SumAll}
-import com.twitter.scalding.typed.{CoGrouped, Input}
 import java.util.{Comparator, PriorityQueue}
 import org.apache.beam.sdk.Pipeline
 import org.apache.beam.sdk.coders.{Coder, IterableCoder, KvCoder}
@@ -16,9 +16,7 @@
 import org.apache.beam.sdk.transforms.DoFn.ProcessElement
 import org.apache.beam.sdk.transforms.Top.TopCombineFn
 import org.apache.beam.sdk.transforms._
 import org.apache.beam.sdk.transforms.join.{CoGbkResult, CoGroupByKey, KeyedPCollectionTuple}
-import org.apache.beam.sdk.values.PCollectionList
-import org.apache.beam.sdk.values.PCollectionTuple
-import org.apache.beam.sdk.values.{KV, PCollection, TupleTag}
+import org.apache.beam.sdk.values.{KV, PCollection, PCollectionList, PCollectionTuple, TupleTag}
 import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
@@ -55,6 +53,8 @@ sealed abstract class BeamOp[+A] {
 
   def flatMap[B](f: A => TraversableOnce[B])(implicit kryoCoder: KryoCoder): BeamOp[B] =
     parDo(FlatMapFn(f), "flatMap")
+
+  def withName(name: String): BeamOp[A]
 }
 
 private final case class SerializableComparator[T](comp: Comparator[T]) extends Comparator[T] {
@@ -136,11 +136,16 @@ object BeamOp extends Serializable {
         )
         case Some(src) => src.read(pipeline, conf)
       }
+
+    override def withName(name: String): BeamOp[A] = this
   }
 
-  final case class FromIterable[A](iterable: Iterable[A], kryoCoder: KryoCoder) extends BeamOp[A] {
+  final case class FromIterable[A](iterable: Iterable[A], kryoCoder: KryoCoder, name: Option[String] = None)
+      extends BeamOp[A] {
     override def runNoCache(pipeline: Pipeline): PCollection[_ <: A] =
-      pipeline.apply(Create.of(iterable.asJava).withCoder(kryoCoder))
+      pipeline.apply(name.getOrElse("Iterable source"), Create.of(iterable.asJava).withCoder(kryoCoder))
+
+    override def withName(name: String): BeamOp[A] = FromIterable(iterable, kryoCoder, Some(name))
   }
 
   final case class TransformBeamOp[A, B](
@@ -153,6 +158,8 @@ object BeamOp extends Serializable {
       val pCollection: PCollection[A] = widenPCollection(source.run(pipeline))
       pCollection.apply(name, f).setCoder(kryoCoder)
     }
+
+    override def withName(desc: String): BeamOp[B] = TransformBeamOp(source, f, kryoCoder, desc)
   }
 
   final case class HashJoinTransform[K, V, U, W](
@@ -184,7 +191,8 @@ object BeamOp extends Serializable {
   final case class HashJoinOp[K, V, U, W](
       left: BeamOp[(K, V)],
       right: BeamOp[(K, U)],
-      joiner: (K, V, Iterable[U]) => Iterator[W]
+      joiner: (K, V, Iterable[U]) => Iterator[W],
+      name: Option[String] = None
   )(implicit kryoCoder: KryoCoder, ordK: Ordering[K])
       extends BeamOp[(K, W)] {
     override def runNoCache(pipeline: Pipeline): PCollection[_ <: (K, W)] = {
@@ -199,20 +207,28 @@ object BeamOp extends Serializable {
         widenPCollection(rightPCollection): PCollection[(K, _)]
       )
 
-      tuple.apply(HashJoinTransform(keyCoder, joiner))
+      tuple.apply(name.getOrElse("HashJoin"), HashJoinTransform(keyCoder, joiner))
     }
+
+    override def withName(name: String): BeamOp[(K, W)] = HashJoinOp(left, right, joiner, Some(name))
   }
 
-  final case class MergedBeamOp[A](first: BeamOp[A], second: BeamOp[A], tail: Seq[BeamOp[A]])
-      extends BeamOp[A] {
+  final case class MergedBeamOp[A](
+      first: BeamOp[A],
+      second: BeamOp[A],
+      tail: Seq[BeamOp[A]],
+      name: Option[String] = None
+  ) extends BeamOp[A] {
     override def runNoCache(pipeline: Pipeline): PCollection[_ <: A] = {
       val collections = PCollectionList
         .of(widenPCollection(first.run(pipeline)): PCollection[A])
         .and(widenPCollection(second.run(pipeline)): PCollection[A])
         .and(tail.map(op => widenPCollection(op.run(pipeline)): PCollection[A]).asJava)
-      collections.apply(Flatten.pCollections[A]())
+      collections.apply(name.getOrElse("Merge"), Flatten.pCollections[A]())
     }
+
+    override def withName(name: String): BeamOp[A] = MergedBeamOp(first, second, tail, Some(name))
   }
 
   final case class CoGroupedTransform[K, V](
@@ -241,7 +257,8 @@ object BeamOp extends Serializable {
 
   final case class CoGroupedOp[K, V](
       cg: CoGrouped[K, V],
-      inputOps: Seq[BeamOp[(K, Any)]]
+      inputOps: Seq[BeamOp[(K, Any)]],
+      name: Option[String] = None
   )(implicit kryoCoder: KryoCoder)
       extends BeamOp[(K, V)] {
     override def runNoCache(pipeline: Pipeline): PCollection[_ <: (K, V)] = {
@@ -256,8 +273,10 @@ object BeamOp extends Serializable {
 
       PCollectionList
         .of(pcols.asJava)
-        .apply(CoGroupedTransform(joinFunction, tupleTags, keyCoder))
+        .apply(name.getOrElse("CoGrouped"), CoGroupedTransform(joinFunction, tupleTags, keyCoder))
     }
+
+    override def withName(name: String): BeamOp[(K, V)] = CoGroupedOp(cg, inputOps, Some(name))
   }
 
   final case class CoGroupDoFn[K, V](
diff --git a/scalding-beam/src/test/scala/com/twitter/scalding/beam_backend/BeamBackendTests.scala b/scalding-beam/src/test/scala/com/twitter/scalding/beam_backend/BeamBackendTests.scala
index c7df08104..8f8981c99 100644
--- a/scalding-beam/src/test/scala/com/twitter/scalding/beam_backend/BeamBackendTests.scala
+++ b/scalding-beam/src/test/scala/com/twitter/scalding/beam_backend/BeamBackendTests.scala
@@ -1,13 +1,23 @@
 package com.twitter.scalding.beam_backend
 
+import com.twitter.scalding.dagon.Rule
 import com.twitter.algebird.{AveragedValue, Semigroup}
+import com.twitter.scalding.Execution.ToWrite
+import com.twitter.scalding.Execution.ToWrite.SimpleWrite
+import com.twitter.scalding.TypedTsv
 import com.twitter.scalding.beam_backend.BeamOp.{CoGroupedOp, FromIterable, HashJoinOp, MergedBeamOp}
 import com.twitter.scalding.{Config, Execution, TextLine, TypedPipe}
 import java.io.File
 import java.nio.file.Paths
 import org.apache.beam.sdk.Pipeline
+import org.apache.beam.sdk.Pipeline.PipelineVisitor
+import org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior
 import org.apache.beam.sdk.options.{PipelineOptions, PipelineOptionsFactory}
+import org.apache.beam.sdk.runners.TransformHierarchy
+import org.apache.beam.sdk.values.PValue
 import org.scalatest.{BeforeAndAfter, FunSuite}
+import scala.collection.immutable
+import scala.collection.mutable
 import scala.io.Source
 
 class BeamBackendTests extends FunSuite with BeforeAndAfter {
@@ -473,6 +483,42 @@ class BeamBackendTests extends FunSuite with BeforeAndAfter {
     assert(output.toSet == Seq((5, 3), (10, 3)).toSet)
   }
 
+  test("BeamOp naming: named PTransforms") {
+    class TransformNameVisitor extends PipelineVisitor.Defaults {
+      private var transformNames: mutable.Set[String] = mutable.Set[String]()
+
+      override def visitPrimitiveTransform(node: TransformHierarchy#Node): Unit =
+        transformNames.add(node.getFullName)
+
+      def getTransformNames(): mutable.Set[String] = transformNames
+    }
+
+    case class WordCount(a: String, b: Long)
+
+    val t1 = TypedPipe.from(Seq(("a", 1L), ("b", 2L)))
+    val pipe = TypedPipe
+      .from(Seq("the quick brown fox jumps"))
+      .withDescription("Read data")
+      .flatMap(s => s.split(" "))
+      .withDescription("Convert to words")
+      .map(tag => (tag, 1L))
+      .sumByKey
+      .withDescription("Count words")
+      .join(t1)
+      .withDescription("Join with t1")
+      .map(keyval => WordCount(keyval._1, keyval._2._1)) ++ t1
+    val (pipeline, op) = beamUnoptimizedPlan(pipe)
+    op.run(pipeline)
+    val visitor = new TransformNameVisitor()
+    pipeline.traverseTopologically(visitor)
+    val names = visitor.getTransformNames()
+    assert(
+      names.exists(_.contains("Read data")) && names.exists(_.contains("Convert to words")) && names.exists(
+        _.contains("Join with t1")
+      )
+    )
+  }
+
   private def getContents(path: String, prefix: String): List[String] =
     new File(path).listFiles.flatMap { file =>
       if (file.getPath.startsWith(prefix)) {
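
Reviewer note on the mechanism: each BeamOp.withName override in this patch threads its name into Beam's two-argument apply(name, transform) overload (on Pipeline, PCollection, PCollectionList), and that name is what surfaces as node.getFullName in the test's PipelineVisitor. Below is a minimal standalone sketch of just that Beam SDK behavior, independent of the Scalding planner. The object name NamedTransformSketch, the sample string, and the transform names are illustrative only; it assumes beam-sdks-java-core plus a runner such as the direct runner on the classpath.

import java.util.Arrays

import org.apache.beam.sdk.Pipeline
import org.apache.beam.sdk.options.PipelineOptionsFactory
import org.apache.beam.sdk.transforms.{Create, MapElements, SimpleFunction}

object NamedTransformSketch {
  def main(args: Array[String]): Unit = {
    val pipeline = Pipeline.create(PipelineOptionsFactory.create())

    // apply(name, transform): the name becomes part of the PTransform's
    // full name in the TransformHierarchy, which is exactly what the
    // withName overrides feed and what TransformNameVisitor reads back
    // via node.getFullName.
    pipeline
      .apply("Read data", Create.of(Arrays.asList("the quick brown fox")))
      .apply(
        "Uppercase lines",
        MapElements.via(new SimpleFunction[String, String] {
          override def apply(line: String): String = line.toUpperCase
        })
      )

    // Requires a runner (e.g. beam-runners-direct-java) on the classpath.
    pipeline.run().waitUntilFinish()
  }
}

Modeling the name as Option[String] = None on each case class keeps the new withName method total across the sealed BeamOp hierarchy without breaking existing constructor call sites, while unnamed ops fall back to the fixed defaults ("Iterable source", "HashJoin", "Merge", "CoGrouped").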