|
1 | 1 | package io.substrait.relation; |
2 | 2 |
|
3 | 3 | import io.substrait.expression.Expression; |
| 4 | +import io.substrait.expression.FieldReference; |
4 | 5 | import io.substrait.expression.proto.ProtoExpressionConverter; |
5 | 6 | import io.substrait.extension.AdvancedExtension; |
6 | 7 | import io.substrait.extension.DefaultExtensionCatalog; |
|
18 | 19 | import io.substrait.proto.ConsistentPartitionWindowRel; |
19 | 20 | import io.substrait.proto.CrossRel; |
20 | 21 | import io.substrait.proto.DdlRel; |
| 22 | +import io.substrait.proto.ExchangeRel; |
21 | 23 | import io.substrait.proto.ExpandRel; |
22 | 24 | import io.substrait.proto.ExtensionLeafRel; |
23 | 25 | import io.substrait.proto.ExtensionMultiRel; |
|
37 | 39 | import io.substrait.relation.extensions.EmptyDetail; |
38 | 40 | import io.substrait.relation.files.FileFormat; |
39 | 41 | import io.substrait.relation.files.FileOrFiles; |
| 42 | +import io.substrait.relation.physical.AbstractExchangeRel; |
| 43 | +import io.substrait.relation.physical.BroadcastExchange; |
40 | 44 | import io.substrait.relation.physical.HashJoin; |
| 45 | +import io.substrait.relation.physical.ImmutableBroadcastExchange; |
| 46 | +import io.substrait.relation.physical.ImmutableExchangeTarget; |
| 47 | +import io.substrait.relation.physical.ImmutableMultiBucketExchange; |
| 48 | +import io.substrait.relation.physical.ImmutableRoundRobinExchange; |
| 49 | +import io.substrait.relation.physical.ImmutableScatterExchange; |
| 50 | +import io.substrait.relation.physical.ImmutableSingleBucketExchange; |
41 | 51 | import io.substrait.relation.physical.MergeJoin; |
| 52 | +import io.substrait.relation.physical.MultiBucketExchange; |
42 | 53 | import io.substrait.relation.physical.NestedLoopJoin; |
| 54 | +import io.substrait.relation.physical.RoundRobinExchange; |
| 55 | +import io.substrait.relation.physical.ScatterExchange; |
| 56 | +import io.substrait.relation.physical.SingleBucketExchange; |
| 57 | +import io.substrait.relation.physical.TargetType; |
43 | 58 | import io.substrait.type.NamedStruct; |
44 | 59 | import io.substrait.type.Type; |
45 | 60 | import io.substrait.type.proto.ProtoTypeConverter; |
@@ -163,6 +178,8 @@ public Rel from(io.substrait.proto.Rel rel) { |
163 | 178 | return newDdl(rel.getDdl()); |
164 | 179 | case UPDATE: |
165 | 180 | return newUpdate(rel.getUpdate()); |
| 181 | + case EXCHANGE: |
| 182 | + return newExchange(rel.getExchange()); |
166 | 183 | default: |
167 | 184 | throw new UnsupportedOperationException("Unsupported RelTypeCase of " + relType); |
168 | 185 | } |
@@ -977,6 +994,163 @@ protected ConsistentPartitionWindow newConsistentPartitionWindow( |
977 | 994 | return builder.build(); |
978 | 995 | } |
979 | 996 |
|
| 997 | + protected AbstractExchangeRel newExchange(ExchangeRel rel) { |
| 998 | + ExchangeRel.ExchangeKindCase exchangeKind = rel.getExchangeKindCase(); |
| 999 | + switch (exchangeKind) { |
| 1000 | + case SCATTER_BY_FIELDS: |
| 1001 | + return newScatterExchange(rel); |
| 1002 | + case SINGLE_TARGET: |
| 1003 | + return newSingleBucketExchange(rel); |
| 1004 | + case MULTI_TARGET: |
| 1005 | + return newMultiBucketExchange(rel); |
| 1006 | + case BROADCAST: |
| 1007 | + return newBroadcastExchange(rel); |
| 1008 | + case ROUND_ROBIN: |
| 1009 | + return newRoundRobinExchange(rel); |
| 1010 | + default: |
| 1011 | + throw new UnsupportedOperationException("Unsupported ExchangeKindCase of " + exchangeKind); |
| 1012 | + } |
| 1013 | + } |
| 1014 | + |
| 1015 | + protected ScatterExchange newScatterExchange(ExchangeRel rel) { |
| 1016 | + Rel input = from(rel.getInput()); |
| 1017 | + List<AbstractExchangeRel.ExchangeTarget> targets = |
| 1018 | + rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList()); |
| 1019 | + |
| 1020 | + ProtoExpressionConverter converter = |
| 1021 | + new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); |
| 1022 | + List<FieldReference> fieldReferences = |
| 1023 | + rel.getScatterByFields().getFieldsList().stream() |
| 1024 | + .map(converter::from) |
| 1025 | + .collect(Collectors.toList()); |
| 1026 | + |
| 1027 | + ImmutableScatterExchange.Builder builder = |
| 1028 | + ScatterExchange.builder() |
| 1029 | + .input(input) |
| 1030 | + .addAllFields(fieldReferences) |
| 1031 | + .partitionCount(rel.getPartitionCount()) |
| 1032 | + .targets(targets); |
| 1033 | + |
| 1034 | + builder |
| 1035 | + .commonExtension(optionalAdvancedExtension(rel.getCommon())) |
| 1036 | + .remap(optionalRelmap(rel.getCommon())) |
| 1037 | + .hint(optionalHint(rel.getCommon())); |
| 1038 | + if (rel.hasAdvancedExtension()) { |
| 1039 | + builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); |
| 1040 | + } |
| 1041 | + return builder.build(); |
| 1042 | + } |
| 1043 | + |
| 1044 | + protected SingleBucketExchange newSingleBucketExchange(ExchangeRel rel) { |
| 1045 | + Rel input = from(rel.getInput()); |
| 1046 | + List<AbstractExchangeRel.ExchangeTarget> targets = |
| 1047 | + rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList()); |
| 1048 | + ProtoExpressionConverter converter = |
| 1049 | + new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); |
| 1050 | + |
| 1051 | + ImmutableSingleBucketExchange.Builder builder = |
| 1052 | + SingleBucketExchange.builder() |
| 1053 | + .input(input) |
| 1054 | + .partitionCount(rel.getPartitionCount()) |
| 1055 | + .targets(targets) |
| 1056 | + .expression(converter.from(rel.getSingleTarget().getExpression())); |
| 1057 | + |
| 1058 | + builder |
| 1059 | + .commonExtension(optionalAdvancedExtension(rel.getCommon())) |
| 1060 | + .remap(optionalRelmap(rel.getCommon())) |
| 1061 | + .hint(optionalHint(rel.getCommon())); |
| 1062 | + if (rel.hasAdvancedExtension()) { |
| 1063 | + builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); |
| 1064 | + } |
| 1065 | + return builder.build(); |
| 1066 | + } |
| 1067 | + |
| 1068 | + protected MultiBucketExchange newMultiBucketExchange(ExchangeRel rel) { |
| 1069 | + Rel input = from(rel.getInput()); |
| 1070 | + List<AbstractExchangeRel.ExchangeTarget> targets = |
| 1071 | + rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList()); |
| 1072 | + ProtoExpressionConverter converter = |
| 1073 | + new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); |
| 1074 | + |
| 1075 | + ImmutableMultiBucketExchange.Builder builder = |
| 1076 | + MultiBucketExchange.builder() |
| 1077 | + .input(input) |
| 1078 | + .partitionCount(rel.getPartitionCount()) |
| 1079 | + .targets(targets) |
| 1080 | + .expression(converter.from(rel.getMultiTarget().getExpression())) |
| 1081 | + .constrainedToCount(rel.getMultiTarget().getConstrainedToCount()); |
| 1082 | + |
| 1083 | + builder |
| 1084 | + .commonExtension(optionalAdvancedExtension(rel.getCommon())) |
| 1085 | + .remap(optionalRelmap(rel.getCommon())) |
| 1086 | + .hint(optionalHint(rel.getCommon())); |
| 1087 | + if (rel.hasAdvancedExtension()) { |
| 1088 | + builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); |
| 1089 | + } |
| 1090 | + return builder.build(); |
| 1091 | + } |
| 1092 | + |
| 1093 | + protected RoundRobinExchange newRoundRobinExchange(ExchangeRel rel) { |
| 1094 | + Rel input = from(rel.getInput()); |
| 1095 | + List<AbstractExchangeRel.ExchangeTarget> targets = |
| 1096 | + rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList()); |
| 1097 | + |
| 1098 | + ImmutableRoundRobinExchange.Builder builder = |
| 1099 | + RoundRobinExchange.builder() |
| 1100 | + .input(input) |
| 1101 | + .partitionCount(rel.getPartitionCount()) |
| 1102 | + .targets(targets) |
| 1103 | + .exact(rel.getRoundRobin().getExact()); |
| 1104 | + |
| 1105 | + builder |
| 1106 | + .commonExtension(optionalAdvancedExtension(rel.getCommon())) |
| 1107 | + .remap(optionalRelmap(rel.getCommon())) |
| 1108 | + .hint(optionalHint(rel.getCommon())); |
| 1109 | + if (rel.hasAdvancedExtension()) { |
| 1110 | + builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); |
| 1111 | + } |
| 1112 | + return builder.build(); |
| 1113 | + } |
| 1114 | + |
| 1115 | + protected BroadcastExchange newBroadcastExchange(ExchangeRel rel) { |
| 1116 | + Rel input = from(rel.getInput()); |
| 1117 | + List<AbstractExchangeRel.ExchangeTarget> targets = |
| 1118 | + rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList()); |
| 1119 | + |
| 1120 | + ImmutableBroadcastExchange.Builder builder = |
| 1121 | + BroadcastExchange.builder() |
| 1122 | + .input(input) |
| 1123 | + .partitionCount(rel.getPartitionCount()) |
| 1124 | + .targets(targets); |
| 1125 | + |
| 1126 | + builder |
| 1127 | + .commonExtension(optionalAdvancedExtension(rel.getCommon())) |
| 1128 | + .remap(optionalRelmap(rel.getCommon())) |
| 1129 | + .hint(optionalHint(rel.getCommon())); |
| 1130 | + if (rel.hasAdvancedExtension()) { |
| 1131 | + builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); |
| 1132 | + } |
| 1133 | + return builder.build(); |
| 1134 | + } |
| 1135 | + |
| 1136 | + protected AbstractExchangeRel.ExchangeTarget newExchangeTarget( |
| 1137 | + ExchangeRel.ExchangeTarget target) { |
| 1138 | + ImmutableExchangeTarget.Builder builder = AbstractExchangeRel.ExchangeTarget.builder(); |
| 1139 | + builder.addAllPartitionIds(target.getPartitionIdList()); |
| 1140 | + switch (target.getTargetTypeCase()) { |
| 1141 | + case URI: |
| 1142 | + builder.type(TargetType.Uri.builder().uri(target.getUri()).build()); |
| 1143 | + break; |
| 1144 | + case EXTENDED: |
| 1145 | + builder.type(TargetType.Extended.builder().extended(target.getExtended()).build()); |
| 1146 | + break; |
| 1147 | + default: |
| 1148 | + throw new UnsupportedOperationException( |
| 1149 | + "Unsupported TargetTypeCase of " + target.getTargetTypeCase()); |
| 1150 | + } |
| 1151 | + return builder.build(); |
| 1152 | + } |
| 1153 | + |
980 | 1154 | protected static Optional<Rel.Remap> optionalRelmap(io.substrait.proto.RelCommon relCommon) { |
981 | 1155 | return Optional.ofNullable( |
982 | 1156 | relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null); |
|
0 commit comments