diff --git a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java index 99f4e0577..02418ce1d 100644 --- a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java +++ b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java @@ -1,8 +1,13 @@ package io.substrait.relation; +import io.substrait.relation.physical.BroadcastExchange; import io.substrait.relation.physical.HashJoin; import io.substrait.relation.physical.MergeJoin; +import io.substrait.relation.physical.MultiBucketExchange; import io.substrait.relation.physical.NestedLoopJoin; +import io.substrait.relation.physical.RoundRobinExchange; +import io.substrait.relation.physical.ScatterExchange; +import io.substrait.relation.physical.SingleBucketExchange; import io.substrait.util.VisitationContext; public abstract class AbstractRelVisitor @@ -138,4 +143,29 @@ public O visit(ExtensionDdl ddl, C context) throws E { public O visit(NamedUpdate update, C context) throws E { return visitFallback(update, context); } + + @Override + public O visit(ScatterExchange exchange, C context) throws E { + return visitFallback(exchange, context); + } + + @Override + public O visit(SingleBucketExchange exchange, C context) throws E { + return visitFallback(exchange, context); + } + + @Override + public O visit(MultiBucketExchange exchange, C context) throws E { + return visitFallback(exchange, context); + } + + @Override + public O visit(BroadcastExchange exchange, C context) throws E { + return visitFallback(exchange, context); + } + + @Override + public O visit(RoundRobinExchange exchange, C context) throws E { + return visitFallback(exchange, context); + } } diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index ce86f1de2..fca347f81 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ 
b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -1,6 +1,7 @@ package io.substrait.relation; import io.substrait.expression.Expression; +import io.substrait.expression.FieldReference; import io.substrait.expression.proto.ProtoExpressionConverter; import io.substrait.extension.AdvancedExtension; import io.substrait.extension.DefaultExtensionCatalog; @@ -18,6 +19,7 @@ import io.substrait.proto.ConsistentPartitionWindowRel; import io.substrait.proto.CrossRel; import io.substrait.proto.DdlRel; +import io.substrait.proto.ExchangeRel; import io.substrait.proto.ExpandRel; import io.substrait.proto.ExtensionLeafRel; import io.substrait.proto.ExtensionMultiRel; @@ -37,9 +39,22 @@ import io.substrait.relation.extensions.EmptyDetail; import io.substrait.relation.files.FileFormat; import io.substrait.relation.files.FileOrFiles; +import io.substrait.relation.physical.AbstractExchangeRel; +import io.substrait.relation.physical.BroadcastExchange; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.ImmutableBroadcastExchange; +import io.substrait.relation.physical.ImmutableExchangeTarget; +import io.substrait.relation.physical.ImmutableMultiBucketExchange; +import io.substrait.relation.physical.ImmutableRoundRobinExchange; +import io.substrait.relation.physical.ImmutableScatterExchange; +import io.substrait.relation.physical.ImmutableSingleBucketExchange; import io.substrait.relation.physical.MergeJoin; +import io.substrait.relation.physical.MultiBucketExchange; import io.substrait.relation.physical.NestedLoopJoin; +import io.substrait.relation.physical.RoundRobinExchange; +import io.substrait.relation.physical.ScatterExchange; +import io.substrait.relation.physical.SingleBucketExchange; +import io.substrait.relation.physical.TargetType; import io.substrait.type.NamedStruct; import io.substrait.type.Type; import io.substrait.type.proto.ProtoTypeConverter; @@ -163,6 +178,8 @@ public Rel from(io.substrait.proto.Rel rel) { return 
newDdl(rel.getDdl()); case UPDATE: return newUpdate(rel.getUpdate()); + case EXCHANGE: + return newExchange(rel.getExchange()); default: throw new UnsupportedOperationException("Unsupported RelTypeCase of " + relType); } @@ -977,6 +994,163 @@ protected ConsistentPartitionWindow newConsistentPartitionWindow( return builder.build(); } + protected AbstractExchangeRel newExchange(ExchangeRel rel) { + ExchangeRel.ExchangeKindCase exchangeKind = rel.getExchangeKindCase(); + switch (exchangeKind) { + case SCATTER_BY_FIELDS: + return newScatterExchange(rel); + case SINGLE_TARGET: + return newSingleBucketExchange(rel); + case MULTI_TARGET: + return newMultiBucketExchange(rel); + case BROADCAST: + return newBroadcastExchange(rel); + case ROUND_ROBIN: + return newRoundRobinExchange(rel); + default: + throw new UnsupportedOperationException("Unsupported ExchangeKindCase of " + exchangeKind); + } + } + + protected ScatterExchange newScatterExchange(ExchangeRel rel) { + Rel input = from(rel.getInput()); + List targets = + rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList()); + + ProtoExpressionConverter protoExprConverter = + new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); + List fieldReferences = + rel.getScatterByFields().getFieldsList().stream() + .map(protoExprConverter::from) + .collect(Collectors.toList()); + + ImmutableScatterExchange.Builder builder = + ScatterExchange.builder() + .input(input) + .addAllFields(fieldReferences) + .partitionCount(rel.getPartitionCount()) + .targets(targets); + + builder + .commonExtension(optionalAdvancedExtension(rel.getCommon())) + .remap(optionalRelmap(rel.getCommon())) + .hint(optionalHint(rel.getCommon())); + if (rel.hasAdvancedExtension()) { + builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); + } + return builder.build(); + } + + protected SingleBucketExchange newSingleBucketExchange(ExchangeRel rel) { + Rel input = from(rel.getInput()); 
+ List targets = + rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList()); + ProtoExpressionConverter protoExprConverter = + new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); + + ImmutableSingleBucketExchange.Builder builder = + SingleBucketExchange.builder() + .input(input) + .partitionCount(rel.getPartitionCount()) + .targets(targets) + .expression(protoExprConverter.from(rel.getSingleTarget().getExpression())); + + builder + .commonExtension(optionalAdvancedExtension(rel.getCommon())) + .remap(optionalRelmap(rel.getCommon())) + .hint(optionalHint(rel.getCommon())); + if (rel.hasAdvancedExtension()) { + builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); + } + return builder.build(); + } + + protected MultiBucketExchange newMultiBucketExchange(ExchangeRel rel) { + Rel input = from(rel.getInput()); + List targets = + rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList()); + ProtoExpressionConverter protoExprConverter = + new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); + + ImmutableMultiBucketExchange.Builder builder = + MultiBucketExchange.builder() + .input(input) + .partitionCount(rel.getPartitionCount()) + .targets(targets) + .expression(protoExprConverter.from(rel.getMultiTarget().getExpression())) + .constrainedToCount(rel.getMultiTarget().getConstrainedToCount()); + + builder + .commonExtension(optionalAdvancedExtension(rel.getCommon())) + .remap(optionalRelmap(rel.getCommon())) + .hint(optionalHint(rel.getCommon())); + if (rel.hasAdvancedExtension()) { + builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); + } + return builder.build(); + } + + protected RoundRobinExchange newRoundRobinExchange(ExchangeRel rel) { + Rel input = from(rel.getInput()); + List targets = + rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList()); + + 
ImmutableRoundRobinExchange.Builder builder = + RoundRobinExchange.builder() + .input(input) + .partitionCount(rel.getPartitionCount()) + .targets(targets) + .exact(rel.getRoundRobin().getExact()); + + builder + .commonExtension(optionalAdvancedExtension(rel.getCommon())) + .remap(optionalRelmap(rel.getCommon())) + .hint(optionalHint(rel.getCommon())); + if (rel.hasAdvancedExtension()) { + builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); + } + return builder.build(); + } + + protected BroadcastExchange newBroadcastExchange(ExchangeRel rel) { + Rel input = from(rel.getInput()); + List targets = + rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList()); + + ImmutableBroadcastExchange.Builder builder = + BroadcastExchange.builder() + .input(input) + .partitionCount(rel.getPartitionCount()) + .targets(targets); + + builder + .commonExtension(optionalAdvancedExtension(rel.getCommon())) + .remap(optionalRelmap(rel.getCommon())) + .hint(optionalHint(rel.getCommon())); + if (rel.hasAdvancedExtension()) { + builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); + } + return builder.build(); + } + + protected AbstractExchangeRel.ExchangeTarget newExchangeTarget( + ExchangeRel.ExchangeTarget target) { + ImmutableExchangeTarget.Builder builder = AbstractExchangeRel.ExchangeTarget.builder(); + builder.addAllPartitionIds(target.getPartitionIdList()); + switch (target.getTargetTypeCase()) { + case URI: + builder.type(TargetType.Uri.builder().uri(target.getUri()).build()); + break; + case EXTENDED: + builder.type(TargetType.Extended.builder().extended(target.getExtended()).build()); + break; + default: + throw new UnsupportedOperationException( + "Unsupported TargetTypeCase of " + target.getTargetTypeCase()); + } + return builder.build(); + } + protected static Optional optionalRelmap(io.substrait.proto.RelCommon relCommon) { return Optional.ofNullable( relCommon.hasEmit() ? 
Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null); diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index 14a3b9fda..b144ede1d 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -8,9 +8,14 @@ import io.substrait.expression.Expression; import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; +import io.substrait.relation.physical.BroadcastExchange; import io.substrait.relation.physical.HashJoin; import io.substrait.relation.physical.MergeJoin; +import io.substrait.relation.physical.MultiBucketExchange; import io.substrait.relation.physical.NestedLoopJoin; +import io.substrait.relation.physical.RoundRobinExchange; +import io.substrait.relation.physical.ScatterExchange; +import io.substrait.relation.physical.SingleBucketExchange; import io.substrait.util.EmptyVisitationContext; import java.util.List; import java.util.Optional; @@ -274,6 +279,91 @@ public Optional visit(NamedUpdate update, EmptyVisitationContext context) t .build()); } + @Override + public Optional visit(ScatterExchange exchange, EmptyVisitationContext context) throws E { + Optional input = exchange.getInput().accept(this, context); + Optional> fields = + transformList(exchange.getFields(), context, this::visitFieldReference); + + if (allEmpty(input, fields)) { + return Optional.empty(); + } + + return Optional.of( + ScatterExchange.builder() + .from(exchange) + .input(input.orElse(exchange.getInput())) + .fields(fields.orElse(exchange.getFields())) + .build()); + } + + @Override + public Optional visit(SingleBucketExchange exchange, EmptyVisitationContext context) + throws E { + Optional input = exchange.getInput().accept(this, context); + + Optional expression = + exchange.getExpression().accept(getExpressionCopyOnWriteVisitor(), context); + + if 
(allEmpty(input, expression)) { + return Optional.empty(); + } + + return Optional.of( + SingleBucketExchange.builder() + .from(exchange) + .input(input.orElse(exchange.getInput())) + .expression(expression.orElse(exchange.getExpression())) + .build()); + } + + /** + * Copy-on-write visit of a {@link MultiBucketExchange}. + * + * Returns {@code Optional.empty()} when {@code allEmpty(input, expression)} holds; otherwise + * rebuilds the exchange, falling back to the original input or expression for whichever visit + * result is empty. + */ + @Override + public Optional visit(MultiBucketExchange exchange, EmptyVisitationContext context) + throws E { + Optional input = exchange.getInput().accept(this, context); + Optional expression = + exchange.getExpression().accept(getExpressionCopyOnWriteVisitor(), context); + + /* Include the expression: checking only the input would discard a non-empty expression result. */ + if (allEmpty(input, expression)) { + return Optional.empty(); + } + + return Optional.of( + MultiBucketExchange.builder() + .from(exchange) + .input(input.orElse(exchange.getInput())) + .expression(expression.orElse(exchange.getExpression())) + .build()); + } + + @Override + public Optional visit(RoundRobinExchange exchange, EmptyVisitationContext context) throws E { + Optional input = exchange.getInput().accept(this, context); + if (allEmpty(input)) { + return Optional.empty(); + } + + return Optional.of( + RoundRobinExchange.builder() + .from(exchange) + .input(input.orElse(exchange.getInput())) + .build()); + } + + @Override + public Optional visit(BroadcastExchange exchange, EmptyVisitationContext context) throws E { + Optional input = exchange.getInput().accept(this, context); + if (allEmpty(input)) { + return Optional.empty(); + } + + return Optional.of( + BroadcastExchange.builder() + .from(exchange) + .input(input.orElse(exchange.getInput())) + .build()); + } + @Override public Optional visit(Sort sort, EmptyVisitationContext context) throws E { Optional input = sort.getInput().accept(this, context); diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 1b3611388..20ef1cce7 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -14,6 +14,7 @@ import 
io.substrait.proto.ConsistentPartitionWindowRel; import io.substrait.proto.CrossRel; import io.substrait.proto.DdlRel; +import io.substrait.proto.ExchangeRel; import io.substrait.proto.ExpandRel; import io.substrait.proto.ExtensionLeafRel; import io.substrait.proto.ExtensionMultiRel; @@ -43,9 +44,16 @@ import io.substrait.proto.UpdateRel; import io.substrait.proto.WriteRel; import io.substrait.relation.files.FileOrFiles; +import io.substrait.relation.physical.AbstractExchangeRel; +import io.substrait.relation.physical.BroadcastExchange; import io.substrait.relation.physical.HashJoin; import io.substrait.relation.physical.MergeJoin; +import io.substrait.relation.physical.MultiBucketExchange; import io.substrait.relation.physical.NestedLoopJoin; +import io.substrait.relation.physical.RoundRobinExchange; +import io.substrait.relation.physical.ScatterExchange; +import io.substrait.relation.physical.SingleBucketExchange; +import io.substrait.relation.physical.TargetType; import io.substrait.type.proto.TypeProtoConverter; import io.substrait.util.EmptyVisitationContext; import java.util.Collection; @@ -522,6 +530,101 @@ public Rel visit(NamedUpdate update, EmptyVisitationContext context) throws Runt return Rel.newBuilder().setUpdate(builder).build(); } + @Override + public Rel visit(ScatterExchange exchange, EmptyVisitationContext context) + throws RuntimeException { + ExchangeRel.Builder builder = + ExchangeRel.newBuilder() + .setScatterByFields( + ExchangeRel.ScatterFields.newBuilder() + .addAllFields( + exchange.getFields().stream() + .map(this::toProto) + .collect(Collectors.toList())) + .build()) + .setPartitionCount(exchange.getPartitionCount()) + .addAllTargets( + exchange.getTargets().stream().map(this::toProto).collect(Collectors.toList())) + .setCommon(common(exchange)) + .setInput(toProto(exchange.getInput())); + return Rel.newBuilder().setExchange(builder).build(); + } + + @Override + public Rel visit(SingleBucketExchange exchange, EmptyVisitationContext 
context) + throws RuntimeException { + ExchangeRel.Builder builder = + ExchangeRel.newBuilder() + .setSingleTarget( + ExchangeRel.SingleBucketExpression.newBuilder() + .setExpression(toProto(exchange.getExpression())) + .build()) + .setPartitionCount(exchange.getPartitionCount()) + .addAllTargets( + exchange.getTargets().stream().map(this::toProto).collect(Collectors.toList())) + .setCommon(common(exchange)) + .setInput(toProto(exchange.getInput())); + return Rel.newBuilder().setExchange(builder).build(); + } + + @Override + public Rel visit(MultiBucketExchange exchange, EmptyVisitationContext context) + throws RuntimeException { + ExchangeRel.Builder builder = + ExchangeRel.newBuilder() + .setMultiTarget( + ExchangeRel.MultiBucketExpression.newBuilder() + .setExpression(toProto(exchange.getExpression())) + .setConstrainedToCount(exchange.getConstrainedToCount()) + .build()) + .setPartitionCount(exchange.getPartitionCount()) + .addAllTargets( + exchange.getTargets().stream().map(this::toProto).collect(Collectors.toList())) + .setCommon(common(exchange)) + .setInput(toProto(exchange.getInput())); + return Rel.newBuilder().setExchange(builder).build(); + } + + @Override + public Rel visit(RoundRobinExchange exchange, EmptyVisitationContext context) + throws RuntimeException { + ExchangeRel.Builder builder = + ExchangeRel.newBuilder() + .setRoundRobin( + ExchangeRel.RoundRobin.newBuilder().setExact(exchange.getExact()).build()) + .setPartitionCount(exchange.getPartitionCount()) + .addAllTargets( + exchange.getTargets().stream().map(this::toProto).collect(Collectors.toList())) + .setCommon(common(exchange)) + .setInput(toProto(exchange.getInput())); + return Rel.newBuilder().setExchange(builder).build(); + } + + @Override + public Rel visit(BroadcastExchange exchange, EmptyVisitationContext context) + throws RuntimeException { + ExchangeRel.Builder builder = + ExchangeRel.newBuilder() + .setBroadcast(ExchangeRel.Broadcast.newBuilder().build()) + 
.setPartitionCount(exchange.getPartitionCount()) + .addAllTargets( + exchange.getTargets().stream().map(this::toProto).collect(Collectors.toList())) + .setCommon(common(exchange)) + .setInput(toProto(exchange.getInput())); + /* NOTE(review): ProtoRelConverter copies ExchangeRel's advanced extension into extension(...), but none of the exchange visit methods in this diff call setAdvancedExtension - confirm whether getExtension() should be serialized here. */ + return Rel.newBuilder().setExchange(builder).build(); + } + + /** + * Converts an exchange target to its protobuf form. + * + * @param target the target whose partition ids and target type are copied into the message + * @return the proto target with either {@code uri} or {@code extended} set + * @throws UnsupportedOperationException if the target type is neither {@link TargetType.Uri} + * nor {@link TargetType.Extended} + */ + private ExchangeRel.ExchangeTarget toProto(AbstractExchangeRel.ExchangeTarget target) { + ExchangeRel.ExchangeTarget.Builder builder = + ExchangeRel.ExchangeTarget.newBuilder().addAllPartitionId(target.getPartitionIds()); + TargetType type = target.getType(); + if (type instanceof TargetType.Uri) { + builder.setUri(((TargetType.Uri) type).getUri()); + } else if (type instanceof TargetType.Extended) { + builder.setExtended(((TargetType.Extended) type).getExtended()); + } else { + /* Fail here rather than emit a target with no target_type, which ProtoRelConverter.newExchangeTarget rejects on read. */ + throw new UnsupportedOperationException("Unsupported TargetType of " + type); + } + return builder.build(); + } + UpdateRel.TransformExpression toProto(AbstractUpdate.TransformExpression transformation) { return UpdateRel.TransformExpression.newBuilder() .setTransformation(toProto(transformation.getTransformation())) diff --git a/core/src/main/java/io/substrait/relation/RelVisitor.java b/core/src/main/java/io/substrait/relation/RelVisitor.java index 23ce99fea..9cc19bad3 100644 --- a/core/src/main/java/io/substrait/relation/RelVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelVisitor.java @@ -1,8 +1,13 @@ package io.substrait.relation; +import io.substrait.relation.physical.BroadcastExchange; import io.substrait.relation.physical.HashJoin; import io.substrait.relation.physical.MergeJoin; +import io.substrait.relation.physical.MultiBucketExchange; import io.substrait.relation.physical.NestedLoopJoin; +import io.substrait.relation.physical.RoundRobinExchange; +import io.substrait.relation.physical.ScatterExchange; +import io.substrait.relation.physical.SingleBucketExchange; import io.substrait.util.VisitationContext; public interface RelVisitor { @@ -57,4 +62,14 @@ public interface RelVisitor O visit(ExtensionDdl ddl, C context) throws E; O visit(NamedUpdate update, C context) 
throws E; + + O visit(ScatterExchange exchange, C context) throws E; + + O visit(SingleBucketExchange exchange, C context) throws E; + + O visit(MultiBucketExchange exchange, C context) throws E; + + O visit(RoundRobinExchange exchange, C context) throws E; + + O visit(BroadcastExchange exchange, C context) throws E; } diff --git a/core/src/main/java/io/substrait/relation/physical/AbstractExchangeRel.java b/core/src/main/java/io/substrait/relation/physical/AbstractExchangeRel.java new file mode 100644 index 000000000..10520647b --- /dev/null +++ b/core/src/main/java/io/substrait/relation/physical/AbstractExchangeRel.java @@ -0,0 +1,29 @@ +package io.substrait.relation.physical; + +import io.substrait.relation.HasExtension; +import io.substrait.relation.SingleInputRel; +import io.substrait.type.Type; +import java.util.List; +import org.immutables.value.Value; + +public abstract class AbstractExchangeRel extends SingleInputRel implements HasExtension { + public abstract Integer getPartitionCount(); + + public abstract List getTargets(); + + @Override + protected Type.Struct deriveRecordType() { + return getInput().getRecordType(); + } + + @Value.Immutable + public abstract static class ExchangeTarget { + public abstract List getPartitionIds(); + + public abstract TargetType getType(); + + public static ImmutableExchangeTarget.Builder builder() { + return ImmutableExchangeTarget.builder(); + } + } +} diff --git a/core/src/main/java/io/substrait/relation/physical/BroadcastExchange.java b/core/src/main/java/io/substrait/relation/physical/BroadcastExchange.java new file mode 100644 index 000000000..10dc1e532 --- /dev/null +++ b/core/src/main/java/io/substrait/relation/physical/BroadcastExchange.java @@ -0,0 +1,18 @@ +package io.substrait.relation.physical; + +import io.substrait.relation.RelVisitor; +import io.substrait.util.VisitationContext; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class BroadcastExchange extends AbstractExchangeRel { + 
@Override + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + + public static ImmutableBroadcastExchange.Builder builder() { + return ImmutableBroadcastExchange.builder(); + } +} diff --git a/core/src/main/java/io/substrait/relation/physical/MultiBucketExchange.java b/core/src/main/java/io/substrait/relation/physical/MultiBucketExchange.java new file mode 100644 index 000000000..9f1c2f31e --- /dev/null +++ b/core/src/main/java/io/substrait/relation/physical/MultiBucketExchange.java @@ -0,0 +1,23 @@ +package io.substrait.relation.physical; + +import io.substrait.expression.Expression; +import io.substrait.relation.RelVisitor; +import io.substrait.util.VisitationContext; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class MultiBucketExchange extends AbstractExchangeRel { + public abstract Expression getExpression(); + + public abstract boolean getConstrainedToCount(); + + @Override + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + + public static ImmutableMultiBucketExchange.Builder builder() { + return ImmutableMultiBucketExchange.builder(); + } +} diff --git a/core/src/main/java/io/substrait/relation/physical/RoundRobinExchange.java b/core/src/main/java/io/substrait/relation/physical/RoundRobinExchange.java new file mode 100644 index 000000000..3bbb3e370 --- /dev/null +++ b/core/src/main/java/io/substrait/relation/physical/RoundRobinExchange.java @@ -0,0 +1,20 @@ +package io.substrait.relation.physical; + +import io.substrait.relation.RelVisitor; +import io.substrait.util.VisitationContext; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class RoundRobinExchange extends AbstractExchangeRel { + public abstract boolean getExact(); + + @Override + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + + public static ImmutableRoundRobinExchange.Builder 
builder() { + return ImmutableRoundRobinExchange.builder(); + } +} diff --git a/core/src/main/java/io/substrait/relation/physical/ScatterExchange.java b/core/src/main/java/io/substrait/relation/physical/ScatterExchange.java new file mode 100644 index 000000000..6f8f99977 --- /dev/null +++ b/core/src/main/java/io/substrait/relation/physical/ScatterExchange.java @@ -0,0 +1,22 @@ +package io.substrait.relation.physical; + +import io.substrait.expression.FieldReference; +import io.substrait.relation.RelVisitor; +import io.substrait.util.VisitationContext; +import java.util.List; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class ScatterExchange extends AbstractExchangeRel { + public abstract List getFields(); + + @Override + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + + public static ImmutableScatterExchange.Builder builder() { + return ImmutableScatterExchange.builder(); + } +} diff --git a/core/src/main/java/io/substrait/relation/physical/SingleBucketExchange.java b/core/src/main/java/io/substrait/relation/physical/SingleBucketExchange.java new file mode 100644 index 000000000..6446d8389 --- /dev/null +++ b/core/src/main/java/io/substrait/relation/physical/SingleBucketExchange.java @@ -0,0 +1,21 @@ +package io.substrait.relation.physical; + +import io.substrait.expression.Expression; +import io.substrait.relation.RelVisitor; +import io.substrait.util.VisitationContext; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class SingleBucketExchange extends AbstractExchangeRel { + public abstract Expression getExpression(); + + @Override + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + + public static ImmutableSingleBucketExchange.Builder builder() { + return ImmutableSingleBucketExchange.builder(); + } +} diff --git a/core/src/main/java/io/substrait/relation/physical/TargetType.java 
b/core/src/main/java/io/substrait/relation/physical/TargetType.java new file mode 100644 index 000000000..38b30bcb7 --- /dev/null +++ b/core/src/main/java/io/substrait/relation/physical/TargetType.java @@ -0,0 +1,25 @@ +package io.substrait.relation.physical; + +import org.immutables.value.Value; + +@Value.Enclosing +public interface TargetType { + + @Value.Immutable + abstract class Uri implements TargetType { + public abstract String getUri(); + + public static ImmutableTargetType.Uri.Builder builder() { + return ImmutableTargetType.Uri.builder(); + } + } + + @Value.Immutable + abstract class Extended implements TargetType { + public abstract com.google.protobuf.Any getExtended(); + + public static ImmutableTargetType.Extended.Builder builder() { + return ImmutableTargetType.Extended.builder(); + } + } +} diff --git a/core/src/test/java/io/substrait/type/proto/ExchangeRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExchangeRelRoundtripTest.java new file mode 100644 index 000000000..667af2486 --- /dev/null +++ b/core/src/test/java/io/substrait/type/proto/ExchangeRelRoundtripTest.java @@ -0,0 +1,108 @@ +package io.substrait.type.proto; + +import io.substrait.TestBase; +import io.substrait.relation.Rel; +import io.substrait.relation.physical.AbstractExchangeRel; +import io.substrait.relation.physical.BroadcastExchange; +import io.substrait.relation.physical.MultiBucketExchange; +import io.substrait.relation.physical.RoundRobinExchange; +import io.substrait.relation.physical.ScatterExchange; +import io.substrait.relation.physical.SingleBucketExchange; +import io.substrait.relation.physical.TargetType; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.junit.jupiter.api.Test; + +public class ExchangeRelRoundtripTest extends TestBase { + + final Rel baseTable = + b.namedScan( + Collections.singletonList("exchange_test_table"), + Arrays.asList("id", "amount", "name", "status"), + Arrays.asList(R.I64, R.FP64, 
R.STRING, R.BOOLEAN)); + + @Test + void broadcastExchange() { + Rel exchange = BroadcastExchange.builder().input(baseTable).partitionCount(1).build(); + + verifyRoundTrip(exchange); + } + + @Test + void roundRobinExchange() { + Rel exchange = + RoundRobinExchange.builder().input(baseTable).exact(true).partitionCount(1).build(); + + verifyRoundTrip(exchange); + } + + @Test + void scatterExchange() { + Rel exchange = + ScatterExchange.builder() + .input(baseTable) + .addFields(b.fieldReference(baseTable, 0)) + .partitionCount(1) + .build(); + + verifyRoundTrip(exchange); + } + + @Test + void singleBucketExchange() { + Rel exchange = + SingleBucketExchange.builder() + .input(baseTable) + .partitionCount(1) + .expression(b.fieldReference(baseTable, 0)) + .build(); + + verifyRoundTrip(exchange); + } + + @Test + void multiBucketExchange() { + Rel exchange = + MultiBucketExchange.builder() + .input(baseTable) + .expression(b.fieldReference(baseTable, 0)) + .constrainedToCount(true) + .partitionCount(1) + .build(); + + verifyRoundTrip(exchange); + } + + @Test + void exchangeWithTargets() { + AbstractExchangeRel.ExchangeTarget target1 = + AbstractExchangeRel.ExchangeTarget.builder() + .partitionIds(Arrays.asList(0, 1)) + .type(TargetType.Uri.builder().uri("hdfs://example.com/data1").build()) + .build(); + + AbstractExchangeRel.ExchangeTarget target2 = + AbstractExchangeRel.ExchangeTarget.builder() + .partitionIds(Arrays.asList(2, 3)) + .type(TargetType.Uri.builder().uri("hdfs://example.com/data2").build()) + .build(); + + List targets = Arrays.asList(target1, target2); + + Rel exchange = + BroadcastExchange.builder().input(baseTable).targets(targets).partitionCount(1).build(); + + verifyRoundTrip(exchange); + } + + @Test + void nestedExchangeRelations() { + Rel innerExchange = BroadcastExchange.builder().input(baseTable).partitionCount(1).build(); + + Rel outerExchange = + RoundRobinExchange.builder().input(innerExchange).exact(false).partitionCount(1).build(); + + 
verifyRoundTrip(outerExchange); + } +} diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java index f3b34f6c2..a94d47fd4 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java @@ -27,9 +27,14 @@ import io.substrait.relation.Sort; import io.substrait.relation.VirtualTableScan; import io.substrait.relation.files.FileOrFiles; +import io.substrait.relation.physical.BroadcastExchange; import io.substrait.relation.physical.HashJoin; import io.substrait.relation.physical.MergeJoin; +import io.substrait.relation.physical.MultiBucketExchange; import io.substrait.relation.physical.NestedLoopJoin; +import io.substrait.relation.physical.RoundRobinExchange; +import io.substrait.relation.physical.ScatterExchange; +import io.substrait.relation.physical.SingleBucketExchange; import io.substrait.type.NamedStruct; import io.substrait.util.EmptyVisitationContext; import java.util.ArrayList; @@ -393,4 +398,39 @@ public String visit(NamedUpdate update, EmptyVisitationContext context) throws R StringBuilder sb = getIndent().append("namedUpdate:: "); return getOutdent(sb); } + + @Override + public String visit(ScatterExchange exchange, EmptyVisitationContext context) + throws RuntimeException { + StringBuilder sb = getIndent().append("scatterExchange:: "); + return getOutdent(sb); + } + + @Override + public String visit(SingleBucketExchange exchange, EmptyVisitationContext context) + throws RuntimeException { + StringBuilder sb = getIndent().append("singleBucketExchange:: "); + return getOutdent(sb); + } + + @Override + public String visit(MultiBucketExchange exchange, EmptyVisitationContext context) + throws RuntimeException { + StringBuilder sb = getIndent().append("multiBucketExchange:: "); + return 
getOutdent(sb); + } + + @Override + public String visit(RoundRobinExchange exchange, EmptyVisitationContext context) + throws RuntimeException { + StringBuilder sb = getIndent().append("roundRobinExchange:: "); + return getOutdent(sb); + } + + @Override + public String visit(BroadcastExchange exchange, EmptyVisitationContext context) + throws RuntimeException { + StringBuilder sb = getIndent().append("broadcastExchange:: "); + return getOutdent(sb); + } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlKindFromRel.java b/isthmus/src/main/java/io/substrait/isthmus/SqlKindFromRel.java index 0a1f475c7..ef2da19f0 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlKindFromRel.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlKindFromRel.java @@ -24,9 +24,14 @@ import io.substrait.relation.Set; import io.substrait.relation.Sort; import io.substrait.relation.VirtualTableScan; +import io.substrait.relation.physical.BroadcastExchange; import io.substrait.relation.physical.HashJoin; import io.substrait.relation.physical.MergeJoin; +import io.substrait.relation.physical.MultiBucketExchange; import io.substrait.relation.physical.NestedLoopJoin; +import io.substrait.relation.physical.RoundRobinExchange; +import io.substrait.relation.physical.ScatterExchange; +import io.substrait.relation.physical.SingleBucketExchange; import io.substrait.util.EmptyVisitationContext; import org.apache.calcite.sql.SqlKind; @@ -237,4 +242,34 @@ public SqlKind visit(ExtensionDdl ddl, EmptyVisitationContext context) throws Ru public SqlKind visit(NamedUpdate update, EmptyVisitationContext context) throws RuntimeException { return SqlKind.UPDATE; } + + @Override + public SqlKind visit(ScatterExchange exchange, EmptyVisitationContext context) + throws RuntimeException { + return SqlKind.OTHER_DDL; + } + + @Override + public SqlKind visit(SingleBucketExchange exchange, EmptyVisitationContext context) + throws RuntimeException { + return SqlKind.OTHER_DDL; + } + + @Override + 
public SqlKind visit(MultiBucketExchange exchange, EmptyVisitationContext context) + throws RuntimeException { + return SqlKind.OTHER_DDL; + } + + @Override + public SqlKind visit(RoundRobinExchange exchange, EmptyVisitationContext context) + throws RuntimeException { + return SqlKind.OTHER_DDL; + } + + @Override + public SqlKind visit(BroadcastExchange exchange, EmptyVisitationContext context) + throws RuntimeException { + return SqlKind.OTHER_DDL; + } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index fe53b7e46..25f128a2d 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -38,6 +38,11 @@ import io.substrait.relation.Set; import io.substrait.relation.Sort; import io.substrait.relation.VirtualTableScan; +import io.substrait.relation.physical.BroadcastExchange; +import io.substrait.relation.physical.MultiBucketExchange; +import io.substrait.relation.physical.RoundRobinExchange; +import io.substrait.relation.physical.ScatterExchange; +import io.substrait.relation.physical.SingleBucketExchange; import io.substrait.type.NamedStruct; import io.substrait.type.TypeCreator; import io.substrait.util.VisitationContext; @@ -591,6 +596,31 @@ public RelNode visit(NamedUpdate update, Context context) { false); } + @Override + public RelNode visit(ScatterExchange exchange, Context context) throws RuntimeException { + return visitFallback(exchange, context); + } + + @Override + public RelNode visit(SingleBucketExchange exchange, Context context) throws RuntimeException { + return visitFallback(exchange, context); + } + + @Override + public RelNode visit(MultiBucketExchange exchange, Context context) throws RuntimeException { + return visitFallback(exchange, context); + } + + @Override + public RelNode visit(RoundRobinExchange exchange, Context 
context) throws RuntimeException { + return visitFallback(exchange, context); + } + + @Override + public RelNode visit(BroadcastExchange exchange, Context context) throws RuntimeException { + return visitFallback(exchange, context); + } + @Override public RelNode visit(NamedDdl namedDdl, Context context) { if (namedDdl.getOperation() != AbstractDdlRel.DdlOp.CREATE diff --git a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala index cf6b7b72c..c30b8b01e 100644 --- a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala +++ b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala @@ -19,6 +19,7 @@ package io.substrait.debug import io.substrait.spark.DefaultRelVisitor import io.substrait.relation._ +import io.substrait.relation.physical.{BroadcastExchange, MultiBucketExchange, RoundRobinExchange, ScatterExchange, SingleBucketExchange} import io.substrait.util.EmptyVisitationContext import scala.collection.mutable @@ -212,6 +213,71 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { .append(localFiles.getItems) }) } + + override def visit(exchange: ScatterExchange, context: EmptyVisitationContext): String = { + withBuilder(exchange, 10)( + builder => { + builder + .append("partitionCount=") + .append(exchange.getPartitionCount) + .append("targets=") + .append(exchange.getTargets) + .append("fields=") + .append(exchange.getFields) + }) + } + + override def visit(exchange: SingleBucketExchange, context: EmptyVisitationContext): String = { + withBuilder(exchange, 10)( + builder => { + builder + .append("partitionCount=") + .append(exchange.getPartitionCount) + .append("targets=") + .append(exchange.getTargets) + .append("expression=") + .append(exchange.getExpression) + }) + } + + override def visit(exchange: MultiBucketExchange, context: EmptyVisitationContext): String = { + withBuilder(exchange, 10)( + builder => { + builder + 
.append("partitionCount=") + .append(exchange.getPartitionCount) + .append("targets=") + .append(exchange.getTargets) + .append("expression=") + .append(exchange.getExpression) + .append("constrainedToCount=") + .append(exchange.getConstrainedToCount) + }) + } + + override def visit(exchange: RoundRobinExchange, context: EmptyVisitationContext): String = { + withBuilder(exchange, 10)( + builder => { + builder + .append("partitionCount=") + .append(exchange.getPartitionCount) + .append("targets=") + .append(exchange.getTargets) + .append("exact=") + .append(exchange.getExact) + }) + } + + override def visit(exchange: BroadcastExchange, context: EmptyVisitationContext): String = { + withBuilder(exchange, 10)( + builder => { + builder + .append("partitionCount=") + .append(exchange.getPartitionCount) + .append("targets=") + .append(exchange.getTargets) + }) + } } object RelToVerboseString { diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 875eb091b..f383753f7 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -49,6 +49,7 @@ import io.substrait.relation.AbstractWriteRel.{CreateMode, WriteOp} import io.substrait.relation.Expand.{ConsistentField, SwitchingField} import io.substrait.relation.Set.SetOp import io.substrait.relation.files.FileFormat +import io.substrait.relation.physical.{BroadcastExchange, MultiBucketExchange, RoundRobinExchange, ScatterExchange, SingleBucketExchange} import io.substrait.util.EmptyVisitationContext import org.apache.hadoop.fs.Path @@ -661,4 +662,28 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) require(renamedLogicalPlan.resolved) renamedLogicalPlan } + + override def visit(exchange: ScatterExchange, context: EmptyVisitationContext): LogicalPlan = { + visitFallback(exchange, context) + } + + 
override def visit( + exchange: SingleBucketExchange, + context: EmptyVisitationContext): LogicalPlan = { + visitFallback(exchange, context) + } + + override def visit( + exchange: MultiBucketExchange, + context: EmptyVisitationContext): LogicalPlan = { + visitFallback(exchange, context) + } + + override def visit(exchange: RoundRobinExchange, context: EmptyVisitationContext): LogicalPlan = { + visitFallback(exchange, context) + } + + override def visit(exchange: BroadcastExchange, context: EmptyVisitationContext): LogicalPlan = { + visitFallback(exchange, context) + } }