Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 0 additions & 66 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,72 +121,6 @@ void SerializedPlanParser::adjustOutput(const DB::QueryPlanPtr & query_plan, con
aliases.emplace_back(DB::NameWithAlias(input_iter->name, *output_name));
});
}

// fixes: issue-1874, to keep the nullability as expected.
const auto & output_schema = root_rel.root().output_schema();
if (output_schema.types_size())
{
const auto & origin_header = *query_plan->getCurrentHeader();
const auto & origin_columns = origin_header.getColumnsWithTypeAndName();

if (static_cast<size_t>(output_schema.types_size()) != origin_columns.size())
{
debug::dumpPlan(*query_plan, "clickhouse plan", true);
debug::dumpMessage(plan, "substrait::Plan", true);
throw Exception(
ErrorCodes::LOGICAL_ERROR,
"Missmatch result columns size. plan column size {}, subtrait plan output schema size {}, subtrait plan name size {}.",
origin_columns.size(),
output_schema.types_size(),
root_rel.root().names_size());
}

bool need_final_project = false;
ColumnsWithTypeAndName final_columns;
for (int i = 0; i < output_schema.types_size(); ++i)
{
const auto & origin_column = origin_columns[i];
const auto & origin_type = origin_column.type;
auto final_type = TypeParser::parseType(output_schema.types(i));

/// Intermediate aggregate data is special, no check here.
if (typeid_cast<const DataTypeAggregateFunction *>(origin_column.type.get()) || origin_type->equals(*final_type))
final_columns.push_back(origin_column);
else
{
need_final_project = true;
if (origin_column.column && isColumnConst(*origin_column.column))
{
/// For a const column, we need to cast it individually. Otherwise, the const column will be converted to a full column in
/// ActionsDAG::makeConvertingActions.
/// Note: creating final_column with the Field of origin_column will cause an Exception in some cases.
const DB::ContextPtr context = DB::CurrentThread::get().getQueryContext();
const FunctionOverloadResolverPtr & cast_resolver = FunctionFactory::instance().get("CAST", context);
const DataTypePtr string_type = std::make_shared<DataTypeString>();
ColumnWithTypeAndName to_type_column = {string_type->createColumnConst(1, final_type->getName()), string_type, "__cast_const__"};
FunctionBasePtr cast_function = cast_resolver->build({origin_column, to_type_column});
ColumnPtr const_col = ColumnConst::create(cast_function->execute({origin_column, to_type_column}, final_type, 1, false), 1);
ColumnWithTypeAndName final_column(const_col, final_type, origin_column.name);
final_columns.emplace_back(std::move(final_column));
}
else
{
ColumnWithTypeAndName final_column(final_type->createColumn(), final_type, origin_column.name);
final_columns.emplace_back(std::move(final_column));
}
}
}

if (need_final_project)
{
ActionsDAG final_project
= ActionsDAG::makeConvertingActions(origin_columns, final_columns, ActionsDAG::MatchColumnsMode::Position, true);
QueryPlanStepPtr final_project_step
= std::make_unique<ExpressionStep>(query_plan->getCurrentHeader(), std::move(final_project));
final_project_step->setStepDescription("Project for output schema");
query_plan->addStep(std::move(final_project_step));
}
}
}

QueryPlanPtr SerializedPlanParser::parse(const substrait::Plan & plan)
Expand Down
1 change: 0 additions & 1 deletion docs/developers/SubstraitModifications.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ alternatives like `AdvancedExtension` could be considered.
* Changed join type `JOIN_TYPE_SEMI` to `JOIN_TYPE_LEFT_SEMI` and `JOIN_TYPE_RIGHT_SEMI`([#408](https://github.com/apache/incubator-gluten/pull/408)).
* Added `WindowRel`, added `column_name` and `window_type` in `WindowFunction`,
changed `Unbounded` in `WindowFunction` into `Unbounded_Preceding` and `Unbounded_Following`, and added `WindowType`([#485](https://github.com/apache/incubator-gluten/pull/485)).
* Added `output_schema` in RelRoot([#1901](https://github.com/apache/incubator-gluten/pull/1901)).
* Added `ExpandRel`([#1361](https://github.com/apache/incubator-gluten/pull/1361)).
* Added `GenerateRel`([#574](https://github.com/apache/incubator-gluten/pull/574)).
* Added `PartitionColumn` in `LocalFiles`([#2405](https://github.com/apache/incubator-gluten/pull/2405)).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.gluten.substrait.extensions.ExtensionBuilder;
import org.apache.gluten.substrait.extensions.FunctionMappingNode;
import org.apache.gluten.substrait.rel.RelNode;
import org.apache.gluten.substrait.type.TypeNode;

import com.google.common.base.Preconditions;

Expand All @@ -39,21 +38,19 @@ public static PlanNode makePlan(
List<FunctionMappingNode> mappingNodes,
List<RelNode> relNodes,
List<String> outNames,
TypeNode outputSchema,
AdvancedExtensionNode extension) {
return new PlanNode(mappingNodes, relNodes, outNames, outputSchema, extension);
return new PlanNode(mappingNodes, relNodes, outNames, extension);
}

public static PlanNode makePlan(
SubstraitContext subCtx, List<RelNode> relNodes, List<String> outNames) {
return makePlan(subCtx, relNodes, outNames, null, null);
return makePlan(subCtx, relNodes, outNames, null);
}

public static PlanNode makePlan(
SubstraitContext subCtx,
List<RelNode> relNodes,
List<String> outNames,
TypeNode outputSchema,
AdvancedExtensionNode extension) {
Preconditions.checkNotNull(
subCtx, "Cannot execute doTransform due to the SubstraitContext is null.");
Expand All @@ -64,7 +61,7 @@ public static PlanNode makePlan(
ExtensionBuilder.makeFunctionMapping(entry.getKey(), entry.getValue());
mappingNodes.add(mappingNode);
}
return makePlan(mappingNodes, relNodes, outNames, outputSchema, extension);
return makePlan(mappingNodes, relNodes, outNames, extension);
}

public static PlanNode makePlan(SubstraitContext subCtx, ArrayList<RelNode> relNodes) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.apache.gluten.substrait.extensions.AdvancedExtensionNode;
import org.apache.gluten.substrait.extensions.FunctionMappingNode;
import org.apache.gluten.substrait.rel.RelNode;
import org.apache.gluten.substrait.type.TypeNode;

import io.substrait.proto.Plan;
import io.substrait.proto.PlanRel;
Expand All @@ -33,19 +32,16 @@ public class PlanNode implements Serializable {
private final List<RelNode> relNodes;
private final List<String> outNames;

private TypeNode outputSchema = null;
private AdvancedExtensionNode extension = null;

PlanNode(
List<FunctionMappingNode> mappingNodes,
List<RelNode> relNodes,
List<String> outNames,
TypeNode outputSchema,
AdvancedExtensionNode extension) {
this.mappingNodes = mappingNodes;
this.relNodes = relNodes;
this.outNames = outNames;
this.outputSchema = outputSchema;
this.extension = extension;
}

Expand All @@ -64,9 +60,6 @@ public Plan toProtobuf() {
for (String name : outNames) {
relRootBuilder.addNames(name);
}
if (outputSchema != null) {
relRootBuilder.setOutputSchema(outputSchema.toProtobuf().getStruct());
}
planRelBuilder.setRoot(relRootBuilder.build());

planBuilder.addRelations(planRelBuilder.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,6 @@ message RelRoot {
Rel input = 1;
// Field names in depth-first order
repeated string names = 2;
Type.Struct output_schema = 3;
}

// A relation (used internally in a plan)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.expression._
import org.apache.gluten.extension.columnar.transition.Convention
import org.apache.gluten.metrics.{GlutenTimeMetric, MetricsUpdater}
import org.apache.gluten.substrait.`type`.{TypeBuilder, TypeNode}
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode}
import org.apache.gluten.substrait.plan.{PlanBuilder, PlanNode}
import org.apache.gluten.substrait.rel.{LocalFilesNode, RelNode, SplitInfo}
import org.apache.gluten.substrait.rel.{LocalFilesNode, RelBuilder, RelNode, SplitInfo}
import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
import org.apache.gluten.utils.SubstraitPlanPrinterUtil

Expand Down Expand Up @@ -172,24 +172,53 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
@transient
private var wholeStageTransformerContext: Option[WholeStageTransformContext] = None

private var outputSchemaForPlan: Option[TypeNode] = None
private var expectedOutputForPlan: Option[Seq[Attribute]] = None

private def inferSchemaFromAttributes(attrs: Seq[Attribute]): TypeNode = {
val outputTypeNodeList = new java.util.ArrayList[TypeNode]()
for (attr <- attrs) {
outputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
def setOutputSchemaForPlan(expectOutput: Seq[Attribute]): Unit = {
if (expectedOutputForPlan.isDefined) {
return
}

TypeBuilder.makeStruct(false, outputTypeNodeList)
// Fixes issue-1874: store expected output attributes for generating a ProjectRel with casts.
expectedOutputForPlan = Some(expectOutput)
}

def setOutputSchemaForPlan(expectOutput: Seq[Attribute]): Unit = {
if (outputSchemaForPlan.isDefined) {
return
/**
* Creates a ProjectRel that casts each input column to the expected output type. This is used to
* enforce nullability and type constraints when the child plan's output may not match the
* expected schema (e.g., in union operations). Returns the input unchanged if no casts are
* needed.
*/
private def createOutputCastProjectRel(
input: RelNode,
inputAttrs: Seq[Attribute],
expectedAttrs: Seq[Attribute],
substraitContext: SubstraitContext): RelNode = {
val castExpressions = new java.util.ArrayList[ExpressionNode]()
var needsCast = false
for (i <- inputAttrs.indices) {
val inputAttr = inputAttrs(i)
val expectedAttr = expectedAttrs(i)
val fieldRef = ExpressionBuilder.makeSelection(i)
// If types differ (including nullability), add a cast; otherwise pass through.
if (
inputAttr.dataType != expectedAttr.dataType ||
inputAttr.nullable != expectedAttr.nullable
) {
val targetType = ConverterUtils.getTypeNode(expectedAttr.dataType, expectedAttr.nullable)
castExpressions.add(ExpressionBuilder.makeCast(targetType, fieldRef, false))
needsCast = true
} else {
castExpressions.add(fieldRef)
}
}
// Only create a ProjectRel if casts are actually needed.
if (needsCast) {
// Use emitStartIndex = 0 to emit only the projected expressions (not input + expressions).
RelBuilder.makeProjectRel(input, castExpressions, substraitContext, -1L, 0)
} else {
input
}

// Fixes issue-1874
outputSchemaForPlan = Some(inferSchemaFromAttributes(expectOutput))
}

def substraitPlan: PlanNode = {
Expand Down Expand Up @@ -241,21 +270,27 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
throw new IllegalStateException(s"WholeStageTransformer can't do Transform on $child")
}

val outNames = childCtx.outputAttributes.map(ConverterUtils.genColumnNameWithExprId).asJava

val planNode = if (BackendsApiManager.getSettings.needOutputSchemaForPlan()) {
val outputSchema =
outputSchemaForPlan.getOrElse(inferSchemaFromAttributes(childCtx.outputAttributes))
val (finalRoot, finalOutputAttrs) =
if (BackendsApiManager.getSettings.needOutputSchemaForPlan()) {
// If expected output schema differs from child's output, wrap in a ProjectRel with casts.
// This fixes issue-1874 by explicitly converting types (including nullability) in the plan.
expectedOutputForPlan match {
case Some(expectedAttrs) =>
val projectRel = createOutputCastProjectRel(
childCtx.root,
childCtx.outputAttributes,
expectedAttrs,
substraitContext)
(projectRel, expectedAttrs)
case None =>
(childCtx.root, childCtx.outputAttributes)
}
} else {
(childCtx.root, childCtx.outputAttributes)
}

PlanBuilder.makePlan(
substraitContext,
Lists.newArrayList(childCtx.root),
outNames,
outputSchema,
null)
} else {
PlanBuilder.makePlan(substraitContext, Lists.newArrayList(childCtx.root), outNames)
}
val outNames = finalOutputAttrs.map(ConverterUtils.genColumnNameWithExprId).asJava
val planNode = PlanBuilder.makePlan(substraitContext, Lists.newArrayList(finalRoot), outNames)

WholeStageTransformContext(planNode, substraitContext, isCudf)
}
Expand Down
Loading