Skip to content

Commit 415a597

Browse files
committed
feat: add ExchangeRel support to core
1 parent 6acb389 commit 415a597

File tree

17 files changed

+777
-0
lines changed

17 files changed

+777
-0
lines changed

core/src/main/java/io/substrait/relation/AbstractRelVisitor.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
package io.substrait.relation;
22

3+
import io.substrait.relation.physical.BroadcastExchange;
34
import io.substrait.relation.physical.HashJoin;
45
import io.substrait.relation.physical.MergeJoin;
6+
import io.substrait.relation.physical.MultiBucketExchange;
57
import io.substrait.relation.physical.NestedLoopJoin;
8+
import io.substrait.relation.physical.RoundRobinExchange;
9+
import io.substrait.relation.physical.ScatterExchange;
10+
import io.substrait.relation.physical.SingleBucketExchange;
611
import io.substrait.util.VisitationContext;
712

813
public abstract class AbstractRelVisitor<O, C extends VisitationContext, E extends Exception>
@@ -138,4 +143,29 @@ public O visit(ExtensionDdl ddl, C context) throws E {
138143
public O visit(NamedUpdate update, C context) throws E {
139144
return visitFallback(update, context);
140145
}
146+
147+
@Override
148+
public O visit(ScatterExchange exchange, C context) throws E {
149+
return visitFallback(exchange, context);
150+
}
151+
152+
@Override
153+
public O visit(SingleBucketExchange exchange, C context) throws E {
154+
return visitFallback(exchange, context);
155+
}
156+
157+
@Override
158+
public O visit(MultiBucketExchange exchange, C context) throws E {
159+
return visitFallback(exchange, context);
160+
}
161+
162+
@Override
163+
public O visit(BroadcastExchange exchange, C context) throws E {
164+
return visitFallback(exchange, context);
165+
}
166+
167+
@Override
168+
public O visit(RoundRobinExchange exchange, C context) throws E {
169+
return visitFallback(exchange, context);
170+
}
141171
}

core/src/main/java/io/substrait/relation/ProtoRelConverter.java

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.substrait.relation;
22

33
import io.substrait.expression.Expression;
4+
import io.substrait.expression.FieldReference;
45
import io.substrait.expression.proto.ProtoExpressionConverter;
56
import io.substrait.extension.AdvancedExtension;
67
import io.substrait.extension.DefaultExtensionCatalog;
@@ -18,6 +19,7 @@
1819
import io.substrait.proto.ConsistentPartitionWindowRel;
1920
import io.substrait.proto.CrossRel;
2021
import io.substrait.proto.DdlRel;
22+
import io.substrait.proto.ExchangeRel;
2123
import io.substrait.proto.ExpandRel;
2224
import io.substrait.proto.ExtensionLeafRel;
2325
import io.substrait.proto.ExtensionMultiRel;
@@ -37,9 +39,22 @@
3739
import io.substrait.relation.extensions.EmptyDetail;
3840
import io.substrait.relation.files.FileFormat;
3941
import io.substrait.relation.files.FileOrFiles;
42+
import io.substrait.relation.physical.AbstractExchangeRel;
43+
import io.substrait.relation.physical.BroadcastExchange;
4044
import io.substrait.relation.physical.HashJoin;
45+
import io.substrait.relation.physical.ImmutableBroadcastExchange;
46+
import io.substrait.relation.physical.ImmutableExchangeTarget;
47+
import io.substrait.relation.physical.ImmutableMultiBucketExchange;
48+
import io.substrait.relation.physical.ImmutableRoundRobinExchange;
49+
import io.substrait.relation.physical.ImmutableScatterExchange;
50+
import io.substrait.relation.physical.ImmutableSingleBucketExchange;
4151
import io.substrait.relation.physical.MergeJoin;
52+
import io.substrait.relation.physical.MultiBucketExchange;
4253
import io.substrait.relation.physical.NestedLoopJoin;
54+
import io.substrait.relation.physical.RoundRobinExchange;
55+
import io.substrait.relation.physical.ScatterExchange;
56+
import io.substrait.relation.physical.SingleBucketExchange;
57+
import io.substrait.relation.physical.TargetType;
4358
import io.substrait.type.NamedStruct;
4459
import io.substrait.type.Type;
4560
import io.substrait.type.proto.ProtoTypeConverter;
@@ -163,6 +178,8 @@ public Rel from(io.substrait.proto.Rel rel) {
163178
return newDdl(rel.getDdl());
164179
case UPDATE:
165180
return newUpdate(rel.getUpdate());
181+
case EXCHANGE:
182+
return newExchange(rel.getExchange());
166183
default:
167184
throw new UnsupportedOperationException("Unsupported RelTypeCase of " + relType);
168185
}
@@ -977,6 +994,163 @@ protected ConsistentPartitionWindow newConsistentPartitionWindow(
977994
return builder.build();
978995
}
979996

997+
protected AbstractExchangeRel newExchange(ExchangeRel rel) {
998+
ExchangeRel.ExchangeKindCase exchangeKind = rel.getExchangeKindCase();
999+
switch (exchangeKind) {
1000+
case SCATTER_BY_FIELDS:
1001+
return newScatterExchange(rel);
1002+
case SINGLE_TARGET:
1003+
return newSingleBucketExchange(rel);
1004+
case MULTI_TARGET:
1005+
return newMultiBucketExchange(rel);
1006+
case BROADCAST:
1007+
return newBroadcastExchange(rel);
1008+
case ROUND_ROBIN:
1009+
return newRoundRobinExchange(rel);
1010+
default:
1011+
throw new UnsupportedOperationException("Unsupported ExchangeKindCase of " + exchangeKind);
1012+
}
1013+
}
1014+
1015+
protected ScatterExchange newScatterExchange(ExchangeRel rel) {
1016+
Rel input = from(rel.getInput());
1017+
List<AbstractExchangeRel.ExchangeTarget> targets =
1018+
rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList());
1019+
1020+
ProtoExpressionConverter converter =
1021+
new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this);
1022+
List<FieldReference> fieldReferences =
1023+
rel.getScatterByFields().getFieldsList().stream()
1024+
.map(converter::from)
1025+
.collect(Collectors.toList());
1026+
1027+
ImmutableScatterExchange.Builder builder =
1028+
ScatterExchange.builder()
1029+
.input(input)
1030+
.addAllFields(fieldReferences)
1031+
.partitionCount(rel.getPartitionCount())
1032+
.targets(targets);
1033+
1034+
builder
1035+
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
1036+
.remap(optionalRelmap(rel.getCommon()))
1037+
.hint(optionalHint(rel.getCommon()));
1038+
if (rel.hasAdvancedExtension()) {
1039+
builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension()));
1040+
}
1041+
return builder.build();
1042+
}
1043+
1044+
protected SingleBucketExchange newSingleBucketExchange(ExchangeRel rel) {
1045+
Rel input = from(rel.getInput());
1046+
List<AbstractExchangeRel.ExchangeTarget> targets =
1047+
rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList());
1048+
ProtoExpressionConverter converter =
1049+
new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this);
1050+
1051+
ImmutableSingleBucketExchange.Builder builder =
1052+
SingleBucketExchange.builder()
1053+
.input(input)
1054+
.partitionCount(rel.getPartitionCount())
1055+
.targets(targets)
1056+
.expression(converter.from(rel.getSingleTarget().getExpression()));
1057+
1058+
builder
1059+
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
1060+
.remap(optionalRelmap(rel.getCommon()))
1061+
.hint(optionalHint(rel.getCommon()));
1062+
if (rel.hasAdvancedExtension()) {
1063+
builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension()));
1064+
}
1065+
return builder.build();
1066+
}
1067+
1068+
protected MultiBucketExchange newMultiBucketExchange(ExchangeRel rel) {
1069+
Rel input = from(rel.getInput());
1070+
List<AbstractExchangeRel.ExchangeTarget> targets =
1071+
rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList());
1072+
ProtoExpressionConverter converter =
1073+
new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this);
1074+
1075+
ImmutableMultiBucketExchange.Builder builder =
1076+
MultiBucketExchange.builder()
1077+
.input(input)
1078+
.partitionCount(rel.getPartitionCount())
1079+
.targets(targets)
1080+
.expression(converter.from(rel.getMultiTarget().getExpression()))
1081+
.constrainedToCount(rel.getMultiTarget().getConstrainedToCount());
1082+
1083+
builder
1084+
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
1085+
.remap(optionalRelmap(rel.getCommon()))
1086+
.hint(optionalHint(rel.getCommon()));
1087+
if (rel.hasAdvancedExtension()) {
1088+
builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension()));
1089+
}
1090+
return builder.build();
1091+
}
1092+
1093+
protected RoundRobinExchange newRoundRobinExchange(ExchangeRel rel) {
1094+
Rel input = from(rel.getInput());
1095+
List<AbstractExchangeRel.ExchangeTarget> targets =
1096+
rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList());
1097+
1098+
ImmutableRoundRobinExchange.Builder builder =
1099+
RoundRobinExchange.builder()
1100+
.input(input)
1101+
.partitionCount(rel.getPartitionCount())
1102+
.targets(targets)
1103+
.exact(rel.getRoundRobin().getExact());
1104+
1105+
builder
1106+
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
1107+
.remap(optionalRelmap(rel.getCommon()))
1108+
.hint(optionalHint(rel.getCommon()));
1109+
if (rel.hasAdvancedExtension()) {
1110+
builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension()));
1111+
}
1112+
return builder.build();
1113+
}
1114+
1115+
protected BroadcastExchange newBroadcastExchange(ExchangeRel rel) {
1116+
Rel input = from(rel.getInput());
1117+
List<AbstractExchangeRel.ExchangeTarget> targets =
1118+
rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList());
1119+
1120+
ImmutableBroadcastExchange.Builder builder =
1121+
BroadcastExchange.builder()
1122+
.input(input)
1123+
.partitionCount(rel.getPartitionCount())
1124+
.targets(targets);
1125+
1126+
builder
1127+
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
1128+
.remap(optionalRelmap(rel.getCommon()))
1129+
.hint(optionalHint(rel.getCommon()));
1130+
if (rel.hasAdvancedExtension()) {
1131+
builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension()));
1132+
}
1133+
return builder.build();
1134+
}
1135+
1136+
protected AbstractExchangeRel.ExchangeTarget newExchangeTarget(
1137+
ExchangeRel.ExchangeTarget target) {
1138+
ImmutableExchangeTarget.Builder builder = AbstractExchangeRel.ExchangeTarget.builder();
1139+
builder.addAllPartitionIds(target.getPartitionIdList());
1140+
switch (target.getTargetTypeCase()) {
1141+
case URI:
1142+
builder.type(TargetType.Uri.builder().uri(target.getUri()).build());
1143+
break;
1144+
case EXTENDED:
1145+
builder.type(TargetType.Extended.builder().extended(target.getExtended()).build());
1146+
break;
1147+
default:
1148+
throw new UnsupportedOperationException(
1149+
"Unsupported TargetTypeCase of " + target.getTargetTypeCase());
1150+
}
1151+
return builder.build();
1152+
}
1153+
9801154
protected static Optional<Rel.Remap> optionalRelmap(io.substrait.proto.RelCommon relCommon) {
9811155
return Optional.ofNullable(
9821156
relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null);

core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,14 @@
88
import io.substrait.expression.Expression;
99
import io.substrait.expression.FieldReference;
1010
import io.substrait.expression.FunctionArg;
11+
import io.substrait.relation.physical.BroadcastExchange;
1112
import io.substrait.relation.physical.HashJoin;
1213
import io.substrait.relation.physical.MergeJoin;
14+
import io.substrait.relation.physical.MultiBucketExchange;
1315
import io.substrait.relation.physical.NestedLoopJoin;
16+
import io.substrait.relation.physical.RoundRobinExchange;
17+
import io.substrait.relation.physical.ScatterExchange;
18+
import io.substrait.relation.physical.SingleBucketExchange;
1419
import io.substrait.util.EmptyVisitationContext;
1520
import java.util.List;
1621
import java.util.Optional;
@@ -274,6 +279,91 @@ public Optional<Rel> visit(NamedUpdate update, EmptyVisitationContext context) t
274279
.build());
275280
}
276281

282+
@Override
283+
public Optional<Rel> visit(ScatterExchange exchange, EmptyVisitationContext context) throws E {
284+
Optional<Rel> input = exchange.getInput().accept(this, context);
285+
Optional<List<FieldReference>> fields =
286+
transformList(exchange.getFields(), context, this::visitFieldReference);
287+
288+
if (allEmpty(input, fields)) {
289+
return Optional.empty();
290+
}
291+
292+
return Optional.of(
293+
ScatterExchange.builder()
294+
.from(exchange)
295+
.input(input.orElse(exchange.getInput()))
296+
.fields(fields.orElse(exchange.getFields()))
297+
.build());
298+
}
299+
300+
@Override
301+
public Optional<Rel> visit(SingleBucketExchange exchange, EmptyVisitationContext context)
302+
throws E {
303+
Optional<Rel> input = exchange.getInput().accept(this, context);
304+
305+
Optional<Expression> expression =
306+
exchange.getExpression().accept(getExpressionCopyOnWriteVisitor(), context);
307+
308+
if (allEmpty(input, expression)) {
309+
return Optional.empty();
310+
}
311+
312+
return Optional.of(
313+
SingleBucketExchange.builder()
314+
.from(exchange)
315+
.input(input.orElse(exchange.getInput()))
316+
.expression(expression.orElse(exchange.getExpression()))
317+
.build());
318+
}
319+
320+
@Override
321+
public Optional<Rel> visit(MultiBucketExchange exchange, EmptyVisitationContext context)
322+
throws E {
323+
Optional<Rel> input = exchange.getInput().accept(this, context);
324+
Optional<Expression> expression =
325+
exchange.getExpression().accept(getExpressionCopyOnWriteVisitor(), context);
326+
327+
if (allEmpty(input)) {
328+
return Optional.empty();
329+
}
330+
331+
return Optional.of(
332+
MultiBucketExchange.builder()
333+
.from(exchange)
334+
.input(input.orElse(exchange.getInput()))
335+
.expression(expression.orElse(exchange.getExpression()))
336+
.build());
337+
}
338+
339+
@Override
340+
public Optional<Rel> visit(RoundRobinExchange exchange, EmptyVisitationContext context) throws E {
341+
Optional<Rel> input = exchange.getInput().accept(this, context);
342+
if (allEmpty(input)) {
343+
return Optional.empty();
344+
}
345+
346+
return Optional.of(
347+
RoundRobinExchange.builder()
348+
.from(exchange)
349+
.input(input.orElse(exchange.getInput()))
350+
.build());
351+
}
352+
353+
@Override
354+
public Optional<Rel> visit(BroadcastExchange exchange, EmptyVisitationContext context) throws E {
355+
Optional<Rel> input = exchange.getInput().accept(this, context);
356+
if (allEmpty(input)) {
357+
return Optional.empty();
358+
}
359+
360+
return Optional.of(
361+
BroadcastExchange.builder()
362+
.from(exchange)
363+
.input(input.orElse(exchange.getInput()))
364+
.build());
365+
}
366+
277367
@Override
278368
public Optional<Rel> visit(Sort sort, EmptyVisitationContext context) throws E {
279369
Optional<Rel> input = sort.getInput().accept(this, context);

0 commit comments

Comments
 (0)