From b33e504a0af766247b0be7437d77fa1b0d5755a9 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 2 Apr 2026 20:33:18 -0700 Subject: [PATCH] Refactor tests to inject a runtime environment to invoke planner and legacy tests PiperOrigin-RevId: 893850002 --- .../src/test/java/dev/cel/bundle/BUILD.bazel | 1 + .../test/java/dev/cel/bundle/CelImplTest.java | 60 +++-- .../main/java/dev/cel/extensions/BUILD.bazel | 1 + .../cel/extensions/CelBindingsExtensions.java | 14 +- .../CelComprehensionsExtensions.java | 35 +-- .../extensions/CelBindingsExtensionsTest.java | 3 +- .../dev/cel/optimizer/optimizers/BUILD.bazel | 4 + .../ConstantFoldingOptimizerTest.java | 151 ++++++------- .../SubexpressionOptimizerBaselineTest.java | 170 ++++++++------ .../SubexpressionOptimizerTest.java | 207 +++++++++++++----- ...old_before_subexpression_unparsed.baseline | 2 +- .../resources/subexpression_unparsed.baseline | 2 +- .../java/dev/cel/runtime/planner/BUILD.bazel | 62 +++--- .../cel/runtime/planner/BlockMemoizer.java | 72 ++++++ .../dev/cel/runtime/planner/EvalBlock.java | 67 ++++++ .../cel/runtime/planner/ExecutionFrame.java | 12 + .../cel/runtime/planner/ProgramPlanner.java | 59 ++++- testing/BUILD.bazel | 5 + .../src/main/java/dev/cel/testing/BUILD.bazel | 9 + .../dev/cel/testing/CelRuntimeFlavor.java | 38 ++++ 20 files changed, 672 insertions(+), 302 deletions(-) create mode 100644 runtime/src/main/java/dev/cel/runtime/planner/BlockMemoizer.java create mode 100644 runtime/src/main/java/dev/cel/runtime/planner/EvalBlock.java create mode 100644 testing/src/main/java/dev/cel/testing/CelRuntimeFlavor.java diff --git a/bundle/src/test/java/dev/cel/bundle/BUILD.bazel b/bundle/src/test/java/dev/cel/bundle/BUILD.bazel index 2901e1ff9..4d1d239db 100644 --- a/bundle/src/test/java/dev/cel/bundle/BUILD.bazel +++ b/bundle/src/test/java/dev/cel/bundle/BUILD.bazel @@ -56,6 +56,7 @@ java_library( "//runtime:evaluation_listener", "//runtime:function_binding", "//runtime:unknown_attributes", + "//testing:cel_runtime_flavor", "//testing/protos:single_file_extension_java_proto", "//testing/protos:single_file_java_proto", "@cel_spec//proto/cel/expr:checked_java_proto", diff --git a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java index 22ef7e2f4..a3ad60d40 100644 --- a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java +++ b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java @@ -114,6 +114,7 @@ import dev.cel.runtime.CelUnknownSet; import dev.cel.runtime.CelVariableResolver; import dev.cel.runtime.UnknownContext; +import dev.cel.testing.CelRuntimeFlavor; import dev.cel.testing.testdata.SingleFile; import dev.cel.testing.testdata.SingleFileExtensionsProto; import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum; @@ -2144,8 +2145,9 @@ public void toBuilder_isImmutable() { } @Test - public void eval_withJsonFieldName(@TestParameter RuntimeEnv runtimeEnv) throws Exception { - Cel cel = runtimeEnv.cel; + public void eval_withJsonFieldName(@TestParameter CelRuntimeFlavor runtimeFlavor) + throws Exception { + Cel cel = setupEnv(runtimeFlavor.builder()); CelAbstractSyntaxTree ast = cel.compile( "file.int32_snake_case_json_name == 1 && " @@ -2176,8 +2178,9 @@ public void eval_withJsonFieldName(@TestParameter RuntimeEnv runtimeEnv) throws } @Test - public void eval_withJsonFieldName_fieldsFallBack(@TestParameter RuntimeEnv runtimeEnv) throws Exception { - Cel cel = runtimeEnv.cel; + public void eval_withJsonFieldName_fieldsFallBack(@TestParameter CelRuntimeFlavor runtimeFlavor) + throws Exception { + Cel cel = setupEnv(runtimeFlavor.builder()); CelAbstractSyntaxTree ast = cel.compile( "dyn(file).int32_snake_case_json_name == 1 && " @@ -2206,8 +2209,9 @@ public void eval_withJsonFieldName_fieldsFallBack(@TestParameter RuntimeEnv runt } @Test - public void eval_withJsonFieldName_extensionFields(@TestParameter RuntimeEnv runtimeEnv) throws Exception { - Cel cel = runtimeEnv.cel; + public void eval_withJsonFieldName_extensionFields(@TestParameter CelRuntimeFlavor runtimeFlavor) + throws Exception { + Cel cel = setupEnv(runtimeFlavor.builder()); CelAbstractSyntaxTree ast = cel.compile( "proto.getExt(file, dev.cel.testing.testdata.int64CamelCaseJsonName) == 5 &&" @@ -2317,33 +2321,21 @@ private static TypeProvider aliasingProvider(ImmutableMap typeAlia }; } - private enum RuntimeEnv { - LEGACY(setupEnv(CelFactory.standardCelBuilder())), - PLANNER(setupEnv(CelExperimentalFactory.plannerCelBuilder())) - ; - - private final Cel cel; - - private static Cel setupEnv(CelBuilder celBuilder) { - ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); - SingleFileExtensionsProto.registerAllExtensions(extensionRegistry); - return celBuilder - .addVar("file", StructTypeReference.create(SingleFile.getDescriptor().getFullName())) - .addMessageTypes(SingleFile.getDescriptor()) - .addFileTypes(SingleFileExtensionsProto.getDescriptor()) - .addCompilerLibraries(CelExtensions.protos()) - .setExtensionRegistry(extensionRegistry) - .setOptions( - CelOptions.current() - .enableJsonFieldNames(true) - .enableHeterogeneousNumericComparisons(true) - .enableQuotedIdentifierSyntax(true) - .build()) - .build(); - } - - RuntimeEnv(Cel cel) { - this.cel = cel; - } + private static Cel setupEnv(CelBuilder celBuilder) { + ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); + SingleFileExtensionsProto.registerAllExtensions(extensionRegistry); + return celBuilder + .addVar("file", StructTypeReference.create(SingleFile.getDescriptor().getFullName())) + .addMessageTypes(SingleFile.getDescriptor()) + .addFileTypes(SingleFileExtensionsProto.getDescriptor()) + .addCompilerLibraries(CelExtensions.protos()) + .setExtensionRegistry(extensionRegistry) + .setOptions( + CelOptions.current() + .enableJsonFieldNames(true) + .enableHeterogeneousNumericComparisons(true) + .enableQuotedIdentifierSyntax(true) + .build()) + .build(); } } diff --git a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel index ed2d19d6f..77663f2fa 100644 --- a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel @@ -142,6 +142,7 @@ java_library( deps = [ "//common:compiler_common", "//common/ast", + "//common/types", "//compiler:compiler_builder", "//extensions:extension_library", "//parser:macro", diff --git a/extensions/src/main/java/dev/cel/extensions/CelBindingsExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelBindingsExtensions.java index 5eb2c2e8c..0e6537334 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelBindingsExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelBindingsExtensions.java @@ -22,7 +22,11 @@ import com.google.errorprone.annotations.Immutable; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelIssue; +import dev.cel.common.CelOverloadDecl; import dev.cel.common.ast.CelExpr; +import dev.cel.common.types.ListType; +import dev.cel.common.types.SimpleType; +import dev.cel.common.types.TypeParamType; import dev.cel.compiler.CelCompilerLibrary; import dev.cel.parser.CelMacro; import dev.cel.parser.CelMacroExprFactory; @@ -62,7 +66,15 @@ public int version() { @Override public ImmutableSet functions() { - return ImmutableSet.of(); + // TODO: Add bindings for block once decorator support is available. + return ImmutableSet.of( + CelFunctionDecl.newFunctionDeclaration( + "cel.@block", + CelOverloadDecl.newGlobalOverload( + "cel_block_list", + TypeParamType.create("T"), + ListType.create(SimpleType.DYN), + TypeParamType.create("T")))); } @Override diff --git a/extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java index 23663f02e..7c298a773 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java @@ -118,29 +118,18 @@ public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) { @Override public void setRuntimeOptions( CelRuntimeBuilder runtimeBuilder, RuntimeEquality runtimeEquality, CelOptions celOptions) { - for (Function function : functions) { - for (CelOverloadDecl overload : function.functionDecl.overloads()) { - switch (overload.overloadId()) { - case MAP_INSERT_OVERLOAD_MAP_MAP: - runtimeBuilder.addFunctionBindings( - CelFunctionBinding.from( - MAP_INSERT_OVERLOAD_MAP_MAP, - Map.class, - Map.class, - (map1, map2) -> mapInsertMap(map1, map2, runtimeEquality))); - break; - case MAP_INSERT_OVERLOAD_KEY_VALUE: - runtimeBuilder.addFunctionBindings( - CelFunctionBinding.from( - MAP_INSERT_OVERLOAD_KEY_VALUE, - ImmutableList.of(Map.class, Object.class, Object.class), - args -> mapInsertKeyValue(args, runtimeEquality))); - break; - default: - // Nothing to add. - } - } - } + runtimeBuilder.addFunctionBindings( + CelFunctionBinding.fromOverloads( + MAP_INSERT_FUNCTION, + CelFunctionBinding.from( + MAP_INSERT_OVERLOAD_MAP_MAP, + Map.class, + Map.class, + (map1, map2) -> mapInsertMap(map1, map2, runtimeEquality)), + CelFunctionBinding.from( + MAP_INSERT_OVERLOAD_KEY_VALUE, + ImmutableList.of(Map.class, Object.class, Object.class), + args -> mapInsertKeyValue(args, runtimeEquality)))); } @Override diff --git a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java index bc98c9816..ff9e31432 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java @@ -63,7 +63,8 @@ public void library() { CelExtensions.getExtensionLibrary("bindings", CelOptions.DEFAULT); assertThat(library.name()).isEqualTo("bindings"); assertThat(library.latest().version()).isEqualTo(0); - assertThat(library.version(0).functions()).isEmpty(); + assertThat(library.version(0).functions().stream().map(CelFunctionDecl::name)) + .containsExactly("cel.@block"); assertThat(library.version(0).macros().stream().map(CelMacro::getFunction)) .containsExactly("bind"); } diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel index b0c48682a..445ec6c20 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -33,7 +33,11 @@ java_library( "//parser:unparser", "//runtime", "//runtime:function_binding", + "//runtime:partial_vars", + "//runtime:program", + "//runtime:unknown_attributes", "//testing:baseline_test_case", + "//testing:cel_runtime_flavor", "@maven//:junit_junit", "@maven//:com_google_testparameterinjector_test_parameter_injector", "//:java_truth", diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java index bbb5c6e7e..33dc2d941 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java @@ -23,8 +23,6 @@ import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; import dev.cel.bundle.CelBuilder; -import dev.cel.bundle.CelExperimentalFactory; -import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; import dev.cel.common.CelFunctionDecl; @@ -44,6 +42,8 @@ import dev.cel.parser.CelUnparser; import dev.cel.parser.CelUnparserFactory; import dev.cel.runtime.CelFunctionBinding; +import dev.cel.testing.CelRuntimeFlavor; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -57,72 +57,57 @@ public class ConstantFoldingOptimizerTest { private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); - @SuppressWarnings("ImmutableEnumChecker") // test only - private enum RuntimeEnv { - LEGACY(setupEnv(CelFactory.standardCelBuilder())), - PLANNER(setupEnv(CelExperimentalFactory.plannerCelBuilder())); - - private final Cel cel; - private final CelOptimizer celOptimizer; - - private static Cel setupEnv(CelBuilder celBuilder) { - return celBuilder - .addVar("x", SimpleType.DYN) - .addVar("y", SimpleType.DYN) - .addVar("list_var", ListType.create(SimpleType.STRING)) - .addVar("map_var", MapType.create(SimpleType.STRING, SimpleType.STRING)) - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addFunctionDeclarations( - CelFunctionDecl.newFunctionDeclaration( - "get_true", - CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)), - CelFunctionDecl.newFunctionDeclaration( - "get_list", - CelOverloadDecl.newGlobalOverload( - "get_list_overload", - ListType.create(SimpleType.INT), - ListType.create(SimpleType.INT)))) - .addFunctionBindings( - CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true)) - .addMessageTypes(TestAllTypes.getDescriptor()) - .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) - .setOptions(CEL_OPTIONS) - .addCompilerLibraries( - CelExtensions.bindings(), - CelOptionalLibrary.INSTANCE, - CelExtensions.math(CEL_OPTIONS), - CelExtensions.strings(), - CelExtensions.sets(CEL_OPTIONS), - CelExtensions.encoders(CEL_OPTIONS)) - .addRuntimeLibraries( - CelOptionalLibrary.INSTANCE, - CelExtensions.math(CEL_OPTIONS), - CelExtensions.strings(), - CelExtensions.sets(CEL_OPTIONS), - CelExtensions.encoders(CEL_OPTIONS)) - .build(); - } - - RuntimeEnv(Cel cel) { - this.cel = cel; - this.celOptimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(cel) - .addAstOptimizers(ConstantFoldingOptimizer.getInstance()) - .build(); - } - - private CelBuilder newCelBuilder() { - switch (this) { - case LEGACY: - return CelFactory.standardCelBuilder(); - case PLANNER: - return CelExperimentalFactory.plannerCelBuilder(); - } - throw new AssertionError("Unknown RuntimeEnv: " + this); - } + @TestParameter CelRuntimeFlavor runtimeFlavor; + + private Cel cel; + private CelOptimizer celOptimizer; + + @Before + public void setUp() { + this.cel = setupEnv(runtimeFlavor.builder()); + this.celOptimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(this.cel) + .addAstOptimizers(ConstantFoldingOptimizer.getInstance()) + .build(); } - @TestParameter RuntimeEnv runtimeEnv; + private static Cel setupEnv(CelBuilder celBuilder) { + return celBuilder + .addVar("x", SimpleType.DYN) + .addVar("y", SimpleType.DYN) + .addVar("list_var", ListType.create(SimpleType.STRING)) + .addVar("map_var", MapType.create(SimpleType.STRING, SimpleType.STRING)) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "get_true", + CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)), + CelFunctionDecl.newFunctionDeclaration( + "get_list", + CelOverloadDecl.newGlobalOverload( + "get_list_overload", + ListType.create(SimpleType.INT), + ListType.create(SimpleType.INT)))) + .addFunctionBindings( + CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true)) + .addMessageTypes(TestAllTypes.getDescriptor()) + .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) + .setOptions(CEL_OPTIONS) + .addCompilerLibraries( + CelExtensions.bindings(), + CelOptionalLibrary.INSTANCE, + CelExtensions.math(CEL_OPTIONS), + CelExtensions.strings(), + CelExtensions.sets(CEL_OPTIONS), + CelExtensions.encoders(CEL_OPTIONS)) + .addRuntimeLibraries( + CelOptionalLibrary.INSTANCE, + CelExtensions.math(CEL_OPTIONS), + CelExtensions.strings(), + CelExtensions.sets(CEL_OPTIONS), + CelExtensions.encoders(CEL_OPTIONS)) + .build(); + } @Test @TestParameters("{source: 'null', expected: 'null'}") @@ -270,9 +255,9 @@ private CelBuilder newCelBuilder() { // TODO: Support folding lists with mixed types. This requires mutable lists. // @TestParameters("{source: 'dyn([1]) + [1.0]'}") public void constantFold_success(String source, String expected) throws Exception { - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(source).getAst(); - CelAbstractSyntaxTree optimizedAst = runtimeEnv.celOptimizer.optimize(ast); + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(expected); } @@ -317,8 +302,8 @@ public void constantFold_success(String source, String expected) throws Exceptio public void constantFold_macros_macroCallMetadataPopulated(String source, String expected) throws Exception { Cel cel = - runtimeEnv - .newCelBuilder() + runtimeFlavor + .builder() .addVar("x", SimpleType.DYN) .addVar("y", SimpleType.DYN) .addMessageTypes(TestAllTypes.getDescriptor()) @@ -363,8 +348,8 @@ public void constantFold_macros_macroCallMetadataPopulated(String source, String @TestParameters("{source: 'false ? false : cel.bind(a, true, a)'}") public void constantFold_macros_withoutMacroCallMetadata(String source) throws Exception { Cel cel = - runtimeEnv - .newCelBuilder() + runtimeFlavor + .builder() .addVar("x", SimpleType.DYN) .addVar("y", SimpleType.DYN) .addMessageTypes(TestAllTypes.getDescriptor()) @@ -418,20 +403,20 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E @TestParameters("{source: 'get_list([1, 2]).map(x, x * 2)'}") @TestParameters("{source: '[(x - 1 > 3) ? (x - 1) : 5].exists(x, x - 1 > 3)'}") public void constantFold_noOp(String source) throws Exception { - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(source).getAst(); - CelAbstractSyntaxTree optimizedAst = runtimeEnv.celOptimizer.optimize(ast); + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(source); } @Test public void constantFold_addFoldableFunction_success() throws Exception { - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("get_true() == get_true()").getAst(); + CelAbstractSyntaxTree ast = cel.compile("get_true() == get_true()").getAst(); ConstantFoldingOptions options = ConstantFoldingOptions.newBuilder().addFoldableFunctions("get_true").build(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) + CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers(ConstantFoldingOptimizer.newInstance(options)) .build(); @@ -442,7 +427,7 @@ public void constantFold_addFoldableFunction_success() throws Exception { @Test public void constantFold_withExpectedResultTypeSet_success() throws Exception { - Cel cel = runtimeEnv.newCelBuilder().setResultType(SimpleType.STRING).build(); + Cel cel = runtimeFlavor.builder().setResultType(SimpleType.STRING).build(); CelOptimizer optimizer = CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers(ConstantFoldingOptimizer.getInstance()) @@ -458,8 +443,8 @@ public void constantFold_withExpectedResultTypeSet_success() throws Exception { public void constantFold_withMacroCallPopulated_comprehensionsAreReplacedWithNotSet() throws Exception { Cel cel = - runtimeEnv - .newCelBuilder() + runtimeFlavor + .builder() .addVar("x", SimpleType.DYN) .setStandardMacros(CelStandardMacro.STANDARD_MACROS) .setOptions(CEL_OPTIONS) @@ -532,9 +517,9 @@ public void constantFold_withMacroCallPopulated_comprehensionsAreReplacedWithNot @Test public void constantFold_astProducesConsistentlyNumberedIds() throws Exception { - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("[1] + [2] + [3]").getAst(); + CelAbstractSyntaxTree ast = cel.compile("[1] + [2] + [3]").getAst(); - CelAbstractSyntaxTree optimizedAst = runtimeEnv.celOptimizer.optimize(ast); + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); assertThat(optimizedAst.getExpr().toString()) .isEqualTo( @@ -555,8 +540,8 @@ public void iterationLimitReached_throws() throws Exception { sb.append(" + ").append(i); } // 0 + 1 + 2 + 3 + ... 200 Cel cel = - runtimeEnv - .newCelBuilder() + runtimeFlavor + .builder() .setOptions( CelOptions.current() .enableHeterogeneousNumericComparisons(true) diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java index 802ef3037..04e4e6a1d 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java @@ -24,7 +24,6 @@ // import com.google.testing.testsize.MediumTest; import dev.cel.bundle.Cel; import dev.cel.bundle.CelBuilder; -import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; import dev.cel.common.CelFunctionDecl; @@ -43,6 +42,8 @@ import dev.cel.parser.CelUnparserFactory; import dev.cel.runtime.CelFunctionBinding; import dev.cel.testing.BaselineTestCase; +import dev.cel.testing.CelRuntimeFlavor; +import java.util.EnumSet; import java.util.Optional; import org.junit.Before; import org.junit.Test; @@ -51,6 +52,43 @@ // @MediumTest @RunWith(TestParameterInjector.class) public class SubexpressionOptimizerBaselineTest extends BaselineTestCase { + private static Cel setupCelEnv(CelBuilder celBuilder) { + return celBuilder + .addMessageTypes(TestAllTypes.getDescriptor()) + .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .setOptions( + CelOptions.current() + .populateMacroCalls(true) + .enableHeterogeneousNumericComparisons(true) + .build()) + .addCompilerLibraries( + CelExtensions.optional(), CelExtensions.bindings(), CelExtensions.comprehensions()) + .addRuntimeLibraries(CelExtensions.optional(), CelExtensions.comprehensions()) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "pure_custom_func", + newGlobalOverload("pure_custom_func_overload", SimpleType.INT, SimpleType.INT)), + CelFunctionDecl.newFunctionDeclaration( + "non_pure_custom_func", + newGlobalOverload("non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT))) + .addFunctionBindings( + // This is pure, but for the purposes of excluding it as a CSE candidate, pretend that + // it isn't. + CelFunctionBinding.fromOverloads( + "non_pure_custom_func", + CelFunctionBinding.from("non_pure_custom_func_overload", Long.class, val -> val))) + .addFunctionBindings( + CelFunctionBinding.fromOverloads( + "pure_custom_func", + CelFunctionBinding.from("pure_custom_func_overload", Long.class, val -> val))) + .addVar("x", SimpleType.DYN) + .addVar("y", SimpleType.DYN) + .addVar("opt_x", OptionalType.create(SimpleType.DYN)) + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .build(); + } + private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); private static final TestAllTypes TEST_ALL_TYPES_INPUT = TestAllTypes.newBuilder() @@ -67,7 +105,6 @@ public class SubexpressionOptimizerBaselineTest extends BaselineTestCase { .putMapInt32Int64(2, 2) .putMapStringString("key", "A"))) .build(); - private static final Cel CEL = newCelBuilder().build(); private static final SubexpressionOptimizerOptions OPTIMIZER_COMMON_OPTIONS = SubexpressionOptimizerOptions.newBuilder() @@ -79,6 +116,7 @@ public class SubexpressionOptimizerBaselineTest extends BaselineTestCase { @Before public void setUp() { + this.cel = setupCelEnv(runtimeFlavor.builder()); overriddenBaseFilePath = ""; } @@ -90,45 +128,70 @@ protected String baselineFileName() { return overriddenBaseFilePath; } + @TestParameter CelRuntimeFlavor runtimeFlavor; + + private Cel cel; + @Test public void allOptimizers_producesSameEvaluationResult( @TestParameter CseTestOptimizer cseTestOptimizer, @TestParameter CseTestCase cseTestCase) throws Exception { skipBaselineVerification(); - CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(cseTestCase.source).getAst(); ImmutableMap inputMap = ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L)); - Object expectedEvalResult = CEL.createProgram(ast).eval(inputMap); + Object expectedEvalResult = cel.createProgram(ast).eval(inputMap); - CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast); + CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.newCseOptimizer(cel).optimize(ast); - Object optimizedEvalResult = CEL.createProgram(optimizedAst).eval(inputMap); + Object optimizedEvalResult = cel.createProgram(optimizedAst).eval(inputMap); + assertThat(optimizedEvalResult).isEqualTo(expectedEvalResult); + } + + @Test + public void allOptimizers_producesSameEvaluationResult_parsedOnly( + @TestParameter CseTestCase cseTestCase, @TestParameter CseTestOptimizer cseTestOptimizer) + throws Exception { + skipBaselineVerification(); + if (runtimeFlavor.equals(CelRuntimeFlavor.LEGACY)) { + return; + } + CelAbstractSyntaxTree ast = cel.compile(cseTestCase.source).getAst(); + ImmutableMap inputMap = + ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L)); + Object expectedEvalResult = cel.createProgram(ast).eval(inputMap); + + CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.newCseOptimizer(cel).optimize(ast); + CelAbstractSyntaxTree parsedOnlyOptimizedAst = + CelAbstractSyntaxTree.newParsedAst(optimizedAst.getExpr(), optimizedAst.getSource()); + + Object optimizedEvalResult = cel.createProgram(parsedOnlyOptimizedAst).eval(inputMap); assertThat(optimizedEvalResult).isEqualTo(expectedEvalResult); } @Test public void subexpression_unparsed() throws Exception { - for (CseTestCase cseTestCase : CseTestCase.values()) { + for (CseTestCase cseTestCase : EnumSet.allOf(CseTestCase.class)) { testOutput().println("Test case: " + cseTestCase.name()); testOutput().println("Source: " + cseTestCase.source); testOutput().println("=====>"); - CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(cseTestCase.source).getAst(); boolean resultPrinted = false; for (CseTestOptimizer cseTestOptimizer : CseTestOptimizer.values()) { String optimizerName = cseTestOptimizer.name(); CelAbstractSyntaxTree optimizedAst; try { - optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast); + optimizedAst = cseTestOptimizer.newCseOptimizer(cel).optimize(ast); } catch (Exception e) { testOutput().printf("[%s]: Optimization Error: %s", optimizerName, e); continue; } if (!resultPrinted) { Object optimizedEvalResult = - CEL.createProgram(optimizedAst) + cel.createProgram(optimizedAst) .eval( ImmutableMap.of( - "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L))); + "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L))); testOutput().println("Result: " + optimizedEvalResult); resultPrinted = true; } @@ -145,22 +208,22 @@ public void subexpression_unparsed() throws Exception { @Test public void constfold_before_subexpression_unparsed() throws Exception { - for (CseTestCase cseTestCase : CseTestCase.values()) { + for (CseTestCase cseTestCase : EnumSet.allOf(CseTestCase.class)) { testOutput().println("Test case: " + cseTestCase.name()); testOutput().println("Source: " + cseTestCase.source); testOutput().println("=====>"); - CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(cseTestCase.source).getAst(); boolean resultPrinted = false; - for (CseTestOptimizer cseTestOptimizer : CseTestOptimizer.values()) { + for (CseTestOptimizer cseTestOptimizer : EnumSet.allOf(CseTestOptimizer.class)) { String optimizerName = cseTestOptimizer.name(); CelAbstractSyntaxTree optimizedAst = - cseTestOptimizer.cseWithConstFoldingOptimizer.optimize(ast); + cseTestOptimizer.newCseWithConstFoldingOptimizer(cel).optimize(ast); if (!resultPrinted) { Object optimizedEvalResult = - CEL.createProgram(optimizedAst) + cel.createProgram(optimizedAst) .eval( ImmutableMap.of( - "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L))); + "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L))); testOutput().println("Result: " + optimizedEvalResult); resultPrinted = true; } @@ -179,12 +242,13 @@ public void constfold_before_subexpression_unparsed() throws Exception { public void subexpression_ast(@TestParameter CseTestOptimizer cseTestOptimizer) throws Exception { String testBasefileName = "subexpression_ast_" + Ascii.toLowerCase(cseTestOptimizer.name()); overriddenBaseFilePath = String.format("%s%s.baseline", testdataDir(), testBasefileName); - for (CseTestCase cseTestCase : CseTestCase.values()) { + for (CseTestCase cseTestCase : EnumSet.allOf(CseTestCase.class)) { testOutput().println("Test case: " + cseTestCase.name()); testOutput().println("Source: " + cseTestCase.source); testOutput().println("=====>"); - CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst(); - CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast); + CelAbstractSyntaxTree ast = cel.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree optimizedAst = + newCseOptimizer(cel, cseTestOptimizer.option).optimize(ast); testOutput().println(optimizedAst.getExpr()); } } @@ -193,7 +257,7 @@ public void subexpression_ast(@TestParameter CseTestOptimizer cseTestOptimizer) public void large_expressions_block_common_subexpr() throws Exception { CelOptimizer celOptimizer = newCseOptimizer( - CEL, SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()); + cel, SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()); runLargeTestCases(celOptimizer); } @@ -202,7 +266,7 @@ public void large_expressions_block_common_subexpr() throws Exception { public void large_expressions_block_recursion_depth_1() throws Exception { CelOptimizer celOptimizer = newCseOptimizer( - CEL, + cel, SubexpressionOptimizerOptions.newBuilder() .populateMacroCalls(true) .subexpressionMaxRecursionDepth(1) @@ -215,7 +279,7 @@ public void large_expressions_block_recursion_depth_1() throws Exception { public void large_expressions_block_recursion_depth_2() throws Exception { CelOptimizer celOptimizer = newCseOptimizer( - CEL, + cel, SubexpressionOptimizerOptions.newBuilder() .populateMacroCalls(true) .subexpressionMaxRecursionDepth(2) @@ -228,7 +292,7 @@ public void large_expressions_block_recursion_depth_2() throws Exception { public void large_expressions_block_recursion_depth_3() throws Exception { CelOptimizer celOptimizer = newCseOptimizer( - CEL, + cel, SubexpressionOptimizerOptions.newBuilder() .populateMacroCalls(true) .subexpressionMaxRecursionDepth(3) @@ -238,15 +302,14 @@ public void large_expressions_block_recursion_depth_3() throws Exception { } private void runLargeTestCases(CelOptimizer celOptimizer) throws Exception { - for (CseLargeTestCase cseTestCase : CseLargeTestCase.values()) { + for (CseLargeTestCase cseTestCase : EnumSet.allOf(CseLargeTestCase.class)) { testOutput().println("Test case: " + cseTestCase.name()); testOutput().println("Source: " + cseTestCase.source); testOutput().println("=====>"); - CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst(); - + CelAbstractSyntaxTree ast = cel.compile(cseTestCase.source).getAst(); CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); Object optimizedEvalResult = - CEL.createProgram(optimizedAst) + cel.createProgram(optimizedAst) .eval( ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L))); testOutput().println("Result: " + optimizedEvalResult); @@ -260,33 +323,6 @@ private void runLargeTestCases(CelOptimizer celOptimizer) throws Exception { } } - private static CelBuilder newCelBuilder() { - return CelFactory.standardCelBuilder() - .addMessageTypes(TestAllTypes.getDescriptor()) - .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .setOptions(CelOptions.current().populateMacroCalls(true).build()) - .addCompilerLibraries( - CelExtensions.optional(), CelExtensions.bindings(), CelExtensions.comprehensions()) - .addRuntimeLibraries(CelExtensions.optional(), CelExtensions.comprehensions()) - .addFunctionDeclarations( - CelFunctionDecl.newFunctionDeclaration( - "pure_custom_func", - newGlobalOverload("pure_custom_func_overload", SimpleType.INT, SimpleType.INT)), - CelFunctionDecl.newFunctionDeclaration( - "non_pure_custom_func", - newGlobalOverload("non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT))) - .addFunctionBindings( - // This is pure, but for the purposes of excluding it as a CSE candidate, pretend that - // it isn't. - CelFunctionBinding.from("non_pure_custom_func_overload", Long.class, val -> val), - CelFunctionBinding.from("pure_custom_func_overload", Long.class, val -> val)) - .addVar("x", SimpleType.DYN) - .addVar("y", SimpleType.DYN) - .addVar("opt_x", OptionalType.create(SimpleType.DYN)) - .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); - } - private static CelOptimizer newCseOptimizer(Cel cel, SubexpressionOptimizerOptions options) { return CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers(SubexpressionOptimizer.newInstance(options)) @@ -315,17 +351,23 @@ private enum CseTestOptimizer { BLOCK_RECURSION_DEPTH_9( OPTIMIZER_COMMON_OPTIONS.toBuilder().subexpressionMaxRecursionDepth(9).build()); - private final CelOptimizer cseOptimizer; - private final CelOptimizer cseWithConstFoldingOptimizer; + private final SubexpressionOptimizerOptions option; CseTestOptimizer(SubexpressionOptimizerOptions option) { - this.cseOptimizer = newCseOptimizer(CEL, option); - this.cseWithConstFoldingOptimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) - .addAstOptimizers( - ConstantFoldingOptimizer.getInstance(), - SubexpressionOptimizer.newInstance(option)) - .build(); + this.option = option; + } + + // Defers building the optimizer until the test runs + private CelOptimizer newCseOptimizer(Cel cel) { + return SubexpressionOptimizerBaselineTest.newCseOptimizer(cel, option); + } + + // Defers building the optimizer until the test runs + private CelOptimizer newCseWithConstFoldingOptimizer(Cel cel) { + return CelOptimizerFactory.standardCelOptimizerBuilder(cel) + .addAstOptimizers( + ConstantFoldingOptimizer.getInstance(), SubexpressionOptimizer.newInstance(option)) + .build(); } } diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 2289a7d4a..e7387d7d8 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -52,52 +52,88 @@ import dev.cel.parser.CelStandardMacro; import dev.cel.parser.CelUnparser; import dev.cel.parser.CelUnparserFactory; +import dev.cel.runtime.CelAttributePattern; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelRuntime; import dev.cel.runtime.CelRuntimeFactory; +import dev.cel.runtime.CelUnknownSet; +import dev.cel.runtime.PartialVars; +import dev.cel.runtime.Program; +import dev.cel.testing.CelRuntimeFlavor; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) public class SubexpressionOptimizerTest { - private static final Cel CEL = newCelBuilder().build(); - - private static final Cel CEL_FOR_EVALUATING_BLOCK = - CelFactory.standardCelBuilder() - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addFunctionDeclarations( - // These are test only declarations, as the actual function is made internal using @ - // symbol. - // If the main function declaration needs updating, be sure to update the test - // declaration as well. - CelFunctionDecl.newFunctionDeclaration( - "cel.block", - CelOverloadDecl.newGlobalOverload( - "block_test_only_overload", - SimpleType.DYN, - ListType.create(SimpleType.DYN), - SimpleType.DYN)), - SubexpressionOptimizer.newCelBlockFunctionDecl(SimpleType.DYN), - CelFunctionDecl.newFunctionDeclaration( - "get_true", - CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) - // Similarly, this is a test only decl (index0 -> @index0) - .addVarDeclarations( - CelVarDecl.newVarDeclaration("c0", SimpleType.DYN), - CelVarDecl.newVarDeclaration("c1", SimpleType.DYN), - CelVarDecl.newVarDeclaration("index0", SimpleType.DYN), - CelVarDecl.newVarDeclaration("index1", SimpleType.DYN), - CelVarDecl.newVarDeclaration("index2", SimpleType.DYN), - CelVarDecl.newVarDeclaration("@index0", SimpleType.DYN), - CelVarDecl.newVarDeclaration("@index1", SimpleType.DYN), - CelVarDecl.newVarDeclaration("@index2", SimpleType.DYN)) - .addMessageTypes(TestAllTypes.getDescriptor()) - .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) - .build(); + private static Cel setupCelEnv(CelBuilder celBuilder) { + return celBuilder + .addMessageTypes(TestAllTypes.getDescriptor()) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .setOptions( + CelOptions.current() + .populateMacroCalls(true) + .enableHeterogeneousNumericComparisons(true) + .build()) + .addCompilerLibraries(CelExtensions.bindings(), CelExtensions.strings()) + .addRuntimeLibraries(CelExtensions.strings()) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "non_pure_custom_func", + newGlobalOverload("non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT))) + .addVar("x", SimpleType.DYN) + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .build(); + } + + private static Cel setupCelForEvaluatingBlock(CelBuilder celBuilder) { + return celBuilder + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addFunctionDeclarations( + // These are test only declarations, as the actual function is made internal using @ + // symbol. + // If the main function declaration needs updating, be sure to update the test + // declaration as well. + CelFunctionDecl.newFunctionDeclaration( + "cel.block", + CelOverloadDecl.newGlobalOverload( + "block_test_only_overload", + SimpleType.DYN, + ListType.create(SimpleType.DYN), + SimpleType.DYN)), + SubexpressionOptimizer.newCelBlockFunctionDecl(SimpleType.DYN), + CelFunctionDecl.newFunctionDeclaration( + "get_true", + CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) + // Similarly, this is a test only decl (index0 -> @index0) + .addVarDeclarations( + CelVarDecl.newVarDeclaration("c0", SimpleType.DYN), + CelVarDecl.newVarDeclaration("c1", SimpleType.DYN), + CelVarDecl.newVarDeclaration("index0", SimpleType.DYN), + CelVarDecl.newVarDeclaration("index1", SimpleType.DYN), + CelVarDecl.newVarDeclaration("index2", SimpleType.DYN), + CelVarDecl.newVarDeclaration("@index0", SimpleType.DYN), + CelVarDecl.newVarDeclaration("@index1", SimpleType.DYN), + CelVarDecl.newVarDeclaration("@index2", SimpleType.DYN)) + .addMessageTypes(TestAllTypes.getDescriptor()) + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .build(); + } + + @TestParameter CelRuntimeFlavor runtimeFlavor; + + private Cel cel; + private Cel celForEvaluatingBlock; + + @Before + public void setUp() { + this.cel = setupCelEnv(runtimeFlavor.builder()); + this.celForEvaluatingBlock = setupCelForEvaluatingBlock(runtimeFlavor.builder()); + } private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); @@ -115,8 +151,8 @@ private static CelBuilder newCelBuilder() { .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); } - private static CelOptimizer newCseOptimizer(SubexpressionOptimizerOptions options) { - return CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + private CelOptimizer newCseOptimizer(SubexpressionOptimizerOptions options) { + return CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers(SubexpressionOptimizer.newInstance(options)) .build(); } @@ -130,15 +166,53 @@ public void cse_resultTypeSet_celBlockOptimizationSuccess() throws Exception { SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder().build())) .build(); - CelAbstractSyntaxTree ast = CEL.compile("size('a') + size('a') == 2").getAst(); + CelAbstractSyntaxTree ast = cel.compile("size('a') + size('a') == 2").getAst(); CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); - assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(true); + assertThat(cel.createProgram(optimizedAst).eval()).isEqualTo(true); assertThat(CEL_UNPARSER.unparse(optimizedAst)) .isEqualTo("cel.@block([size(\"a\")], @index0 + @index0 == 2)"); } + @Test + public void cse_indexEvaluationErrors_throws() throws Exception { + CelAbstractSyntaxTree ast = cel.compile("\"abc\".charAt(10) + \"abc\".charAt(10)").getAst(); + CelOptimizer optimizedOptimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(cel) + .addAstOptimizers(SubexpressionOptimizer.getInstance()) + .build(); + + CelAbstractSyntaxTree optimizedAst = optimizedOptimizer.optimize(ast); + + String unparsed = CEL_UNPARSER.unparse(optimizedAst); + assertThat(unparsed).isEqualTo("cel.@block([\"abc\".charAt(10)], @index0 + @index0)"); + + Program program = cel.createProgram(optimizedAst); + CelEvaluationException e = + assertThrows(CelEvaluationException.class, () -> program.eval(ImmutableMap.of())); + assertThat(e).hasMessageThat().contains("charAt failure: Index out of range: 10"); + } + + @Test + public void cse_withUnknownAttributes() throws Exception { + CelAbstractSyntaxTree ast = cel.compile("size(\"a\") == 1 ? x.y : x.y").getAst(); + CelOptimizer optimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(cel) + .addAstOptimizers(SubexpressionOptimizer.getInstance()) + .build(); + + CelAbstractSyntaxTree optimizedAst = optimizer.optimize(ast); + + assertThat(CEL_UNPARSER.unparse(optimizedAst)) + .isEqualTo("cel.@block([x.y], (size(\"a\") == 1) ? @index0 : @index0)"); + + Object result = + cel.createProgram(optimizedAst) + .eval(PartialVars.of(CelAttributePattern.fromQualifiedIdentifier("x"))); + assertThat(result).isInstanceOf(CelUnknownSet.class); + } + private enum CseNoOpTestCase { // Nothing to optimize NO_COMMON_SUBEXPR("size(\"hello\")"), @@ -169,7 +243,7 @@ private enum CseNoOpTestCase { @Test public void cse_withCelBind_noop(@TestParameter CseNoOpTestCase testCase) throws Exception { - CelAbstractSyntaxTree ast = CEL.compile(testCase.source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(testCase.source).getAst(); CelAbstractSyntaxTree optimizedAst = newCseOptimizer(SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()) @@ -181,7 +255,7 @@ public void cse_withCelBind_noop(@TestParameter CseNoOpTestCase testCase) throws @Test public void cse_withCelBlock_noop(@TestParameter CseNoOpTestCase testCase) throws Exception { - CelAbstractSyntaxTree ast = CEL.compile(testCase.source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(testCase.source).getAst(); CelAbstractSyntaxTree optimizedAst = newCseOptimizer(SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()) @@ -194,7 +268,7 @@ public void cse_withCelBlock_noop(@TestParameter CseNoOpTestCase testCase) throw @Test public void cse_withComprehensionStructureRetained() throws Exception { CelAbstractSyntaxTree ast = - CEL.compile("['foo'].map(x, [x+x]) + ['foo'].map(x, [x+x, x+x])").getAst(); + cel.compile("['foo'].map(x, [x+x]) + ['foo'].map(x, [x+x, x+x])").getAst(); CelOptimizer celOptimizer = newCseOptimizer( SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()); @@ -210,10 +284,10 @@ public void cse_withComprehensionStructureRetained() throws Exception { @Test public void cse_applyConstFoldingBefore() throws Exception { CelAbstractSyntaxTree ast = - CEL.compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") + cel.compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") .getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers( ConstantFoldingOptimizer.getInstance(), SubexpressionOptimizer.newInstance( @@ -228,10 +302,10 @@ public void cse_applyConstFoldingBefore() throws Exception { @Test public void cse_applyConstFoldingAfter() throws Exception { CelAbstractSyntaxTree ast = - CEL.compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") + cel.compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") .getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers( SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder().build()), @@ -246,9 +320,9 @@ public void cse_applyConstFoldingAfter() throws Exception { @Test public void cse_applyConstFoldingAfter_nothingToFold() throws Exception { - CelAbstractSyntaxTree ast = CEL.compile("size(x) + size(x)").getAst(); + CelAbstractSyntaxTree ast = cel.compile("size(x) + size(x)").getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers( SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()), @@ -271,7 +345,7 @@ public void iterationLimitReached_throws() throws Exception { largeExprBuilder.append("+"); } } - CelAbstractSyntaxTree ast = CEL.compile(largeExprBuilder.toString()).getAst(); + CelAbstractSyntaxTree ast = cel.compile(largeExprBuilder.toString()).getAst(); CelOptimizationException e = assertThrows( @@ -287,9 +361,9 @@ public void iterationLimitReached_throws() throws Exception { @Test public void celBlock_astExtensionTagged() throws Exception { - CelAbstractSyntaxTree ast = CEL.compile("size(x) + size(x)").getAst(); + CelAbstractSyntaxTree ast = cel.compile("size(x) + size(x)").getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers( SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()), @@ -322,7 +396,20 @@ private enum BlockTestCase { public void block_success(@TestParameter BlockTestCase testCase) throws Exception { CelAbstractSyntaxTree ast = compileUsingInternalFunctions(testCase.source); - Object evaluatedResult = CEL_FOR_EVALUATING_BLOCK.createProgram(ast).eval(); + Object evaluatedResult = celForEvaluatingBlock.createProgram(ast).eval(); + + assertThat(evaluatedResult).isNotNull(); + } + + @Test + public void block_success_parsedOnly(@TestParameter BlockTestCase testCase) throws Exception { + if (runtimeFlavor.equals(CelRuntimeFlavor.LEGACY)) { + return; + } + CelAbstractSyntaxTree ast = + compileUsingInternalFunctions(testCase.source, /* parsedOnly= */ true); + + Object evaluatedResult = celForEvaluatingBlock.createProgram(ast).eval(); assertThat(evaluatedResult).isNotNull(); } @@ -584,7 +671,7 @@ public void block_containsCycle_throws() throws Exception { CelAbstractSyntaxTree ast = compileUsingInternalFunctions("cel.block([index1,index0],index0)"); CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> CEL.createProgram(ast).eval()); + assertThrows(CelEvaluationException.class, () -> cel.createProgram(ast).eval()); assertThat(e).hasMessageThat().contains("Cycle detected: @index0"); } @@ -595,7 +682,7 @@ public void block_lazyEvaluationContainsError_cleansUpCycleState() throws Except "cel.block([1/0 > 0], (index0 && false) || (index0 && true))"); CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> CEL.createProgram(ast).eval()); + assertThrows(CelEvaluationException.class, () -> cel.createProgram(ast).eval()); assertThat(e).hasMessageThat().contains("/ by zero"); assertThat(e).hasMessageThat().doesNotContain("Cycle detected"); @@ -605,9 +692,9 @@ public void block_lazyEvaluationContainsError_cleansUpCycleState() throws Except * Converts AST containing cel.block related test functions to internal functions (e.g: cel.block * -> cel.@block) */ - private static CelAbstractSyntaxTree compileUsingInternalFunctions(String expression) + private CelAbstractSyntaxTree compileUsingInternalFunctions(String expression, boolean parsedOnly) throws CelValidationException { - CelAbstractSyntaxTree astToModify = CEL_FOR_EVALUATING_BLOCK.compile(expression).getAst(); + CelAbstractSyntaxTree astToModify = celForEvaluatingBlock.compile(expression).getAst(); CelMutableAst mutableAst = CelMutableAst.fromCelAst(astToModify); CelNavigableMutableAst.fromAst(mutableAst) .getRoot() @@ -629,6 +716,14 @@ private static CelAbstractSyntaxTree compileUsingInternalFunctions(String expres indexExpr.ident().setName(internalIdentName); }); - return CEL_FOR_EVALUATING_BLOCK.check(mutableAst.toParsedAst()).getAst(); + if (parsedOnly) { + return mutableAst.toParsedAst(); + } + return celForEvaluatingBlock.check(mutableAst.toParsedAst()).getAst(); + } + + private CelAbstractSyntaxTree compileUsingInternalFunctions(String expression) + throws CelValidationException { + return compileUsingInternalFunctions(expression, false); } } diff --git a/optimizer/src/test/resources/constfold_before_subexpression_unparsed.baseline b/optimizer/src/test/resources/constfold_before_subexpression_unparsed.baseline index 55da856cd..9139c7a35 100644 --- a/optimizer/src/test/resources/constfold_before_subexpression_unparsed.baseline +++ b/optimizer/src/test/resources/constfold_before_subexpression_unparsed.baseline @@ -526,7 +526,7 @@ Result: [[[foofoo, foofoo, foofoo, foofoo], [foofoo, foofoo, foofoo, foofoo]], [ Test case: MACRO_SHADOWED_VARIABLE_COMP_V2_1 Source: [x - y - 1 > 3 ? x - y - 1 : 5].exists(x, y, x - y - 1 > 3) || x - y - 1 > 3 =====> -Result: CelUnknownSet{attributes=[], unknownExprIds=[6]} +Result: false [BLOCK_COMMON_SUBEXPR_ONLY]: cel.@block([x - y - 1, @index0 > 3], [@index1 ? @index0 : 5].exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index1) [BLOCK_RECURSION_DEPTH_1]: cel.@block([x - y, @index0 - 1, @index1 > 3, @index2 ? @index1 : 5, [@index3]], @index4.exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index2) [BLOCK_RECURSION_DEPTH_2]: cel.@block([x - y - 1, @index0 > 3, [@index1 ? @index0 : 5]], @index2.exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index1) diff --git a/optimizer/src/test/resources/subexpression_unparsed.baseline b/optimizer/src/test/resources/subexpression_unparsed.baseline index e0edc8987..780664a14 100644 --- a/optimizer/src/test/resources/subexpression_unparsed.baseline +++ b/optimizer/src/test/resources/subexpression_unparsed.baseline @@ -526,7 +526,7 @@ Result: [[[foofoo, foofoo, foofoo, foofoo], [foofoo, foofoo, foofoo, foofoo]], [ Test case: MACRO_SHADOWED_VARIABLE_COMP_V2_1 Source: [x - y - 1 > 3 ? x - y - 1 : 5].exists(x, y, x - y - 1 > 3) || x - y - 1 > 3 =====> -Result: CelUnknownSet{attributes=[], unknownExprIds=[6]} +Result: false [BLOCK_COMMON_SUBEXPR_ONLY]: cel.@block([x - y - 1, @index0 > 3], [@index1 ? @index0 : 5].exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index1) [BLOCK_RECURSION_DEPTH_1]: cel.@block([x - y, @index0 - 1, @index1 > 3, @index2 ? @index1 : 5, [@index3]], @index4.exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index2) [BLOCK_RECURSION_DEPTH_2]: cel.@block([x - y - 1, @index0 > 3, [@index1 ? @index0 : 5]], @index2.exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index1) diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel index fc70118e4..cb2ad5a82 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel @@ -18,6 +18,7 @@ java_library( ":eval_and", ":eval_attribute", ":eval_binary", + ":eval_block", ":eval_conditional", ":eval_const", ":eval_create_list", @@ -67,7 +68,6 @@ java_library( srcs = ["PlannedProgram.java"], deps = [ ":error_metadata", - ":execution_frame", ":localized_evaluation_exception", ":planned_interpretable", "//:auto_value", @@ -92,11 +92,9 @@ java_library( name = "eval_const", srcs = ["EvalConstant.java"], deps = [ - ":execution_frame", ":planned_interpretable", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", - "@maven//:com_google_guava_guava", ], ) @@ -123,7 +121,6 @@ java_library( deps = [ ":activation_wrapper", ":eval_helpers", - ":execution_frame", ":planned_interpretable", ":qualifier", "//common:container", @@ -183,8 +180,8 @@ java_library( srcs = ["EvalAttribute.java"], deps = [ ":attribute", - ":execution_frame", ":interpretable_attribute", + ":planned_interpretable", ":qualifier", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", @@ -195,8 +192,8 @@ java_library( name = "eval_test_only", srcs = ["EvalTestOnly.java"], deps = [ - ":execution_frame", ":interpretable_attribute", + ":planned_interpretable", ":presence_test_qualifier", ":qualifier", "//runtime:evaluation_exception", @@ -210,7 +207,6 @@ java_library( srcs = ["EvalZeroArity.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:evaluation_exception", @@ -224,7 +220,6 @@ java_library( srcs = ["EvalUnary.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:evaluation_exception", @@ -238,7 +233,6 @@ java_library( srcs = ["EvalBinary.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:accumulated_unknowns", @@ -253,7 +247,6 @@ java_library( srcs = ["EvalVarArgsCall.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:accumulated_unknowns", @@ -268,7 +261,6 @@ java_library( srcs = ["EvalLateBoundCall.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/exceptions:overload_not_found", "//common/values", @@ -285,7 +277,6 @@ java_library( srcs = ["EvalOr.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:accumulated_unknowns", @@ -299,7 +290,6 @@ java_library( srcs = ["EvalAnd.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:accumulated_unknowns", @@ -312,7 +302,6 @@ java_library( name = "eval_conditional", srcs = ["EvalConditional.java"], deps = [ - ":execution_frame", ":planned_interpretable", "//runtime:accumulated_unknowns", "//runtime:evaluation_exception", @@ -326,7 +315,6 @@ java_library( srcs = ["EvalCreateStruct.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/types:type_providers", "//common/values", @@ -344,7 +332,6 @@ java_library( srcs = ["EvalCreateList.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//runtime:accumulated_unknowns", "//runtime:evaluation_exception", @@ -359,7 +346,6 @@ java_library( srcs = ["EvalCreateMap.java"], deps = [ ":eval_helpers", - ":execution_frame", ":localized_evaluation_exception", ":planned_interpretable", "//common/exceptions:duplicate_key", @@ -377,7 +363,6 @@ java_library( srcs = ["EvalFold.java"], deps = [ ":activation_wrapper", - ":execution_frame", ":planned_interpretable", "//runtime:accumulated_unknowns", "//runtime:concatenated_list_view", @@ -389,24 +374,10 @@ java_library( ], ) -java_library( - name = "execution_frame", - srcs = ["ExecutionFrame.java"], - deps = [ - "//common:options", - "//common/exceptions:iteration_budget_exceeded", - "//runtime:evaluation_exception", - "//runtime:function_resolver", - "//runtime:partial_vars", - "//runtime:resolved_overload", - ], -) - java_library( name = "eval_helpers", srcs = ["EvalHelpers.java"], deps = [ - ":execution_frame", ":localized_evaluation_exception", ":planned_interpretable", "//common:error_codes", @@ -440,11 +411,20 @@ java_library( java_library( name = "planned_interpretable", - srcs = ["PlannedInterpretable.java"], + srcs = [ + "BlockMemoizer.java", + "ExecutionFrame.java", + "PlannedInterpretable.java", + ], deps = [ - ":execution_frame", + ":localized_evaluation_exception", + "//common:options", + "//common/exceptions:iteration_budget_exceeded", "//runtime:evaluation_exception", + "//runtime:function_resolver", "//runtime:interpretable", + "//runtime:partial_vars", + "//runtime:resolved_overload", "@maven//:com_google_errorprone_error_prone_annotations", ], ) @@ -454,7 +434,6 @@ java_library( srcs = ["EvalOptionalOr.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/exceptions:overload_not_found", "//runtime:accumulated_unknowns", @@ -469,7 +448,6 @@ java_library( srcs = ["EvalOptionalOrValue.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/exceptions:overload_not_found", "//runtime:accumulated_unknowns", @@ -484,7 +462,6 @@ java_library( srcs = ["EvalOptionalSelectField.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:accumulated_unknowns", @@ -493,3 +470,14 @@ java_library( "@maven//:com_google_guava_guava", ], ) + +java_library( + name = "eval_block", + srcs = ["EvalBlock.java"], + deps = [ + ":planned_interpretable", + "//runtime:evaluation_exception", + "//runtime:interpretable", + "@maven//:com_google_errorprone_error_prone_annotations", + ], +) diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BlockMemoizer.java b/runtime/src/main/java/dev/cel/runtime/planner/BlockMemoizer.java new file mode 100644 index 000000000..978029b3d --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/planner/BlockMemoizer.java @@ -0,0 +1,72 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.GlobalResolver; +import java.util.Arrays; + +/** Handles memoization, lazy evaluation, and cycle detection for cel.@block slots. */ +final class BlockMemoizer { + + private static final Object IN_PROGRESS = new Object(); + private static final Object UNSET = new Object(); + + private final PlannedInterpretable[] slotExprs; + private final Object[] slotVals; + private final ExecutionFrame frame; + + static BlockMemoizer create(PlannedInterpretable[] slotExprs, ExecutionFrame frame) { + return new BlockMemoizer(slotExprs, frame); + } + + private BlockMemoizer(PlannedInterpretable[] slotExprs, ExecutionFrame frame) { + this.slotExprs = slotExprs; + this.frame = frame; + this.slotVals = new Object[slotExprs.length]; + Arrays.fill(this.slotVals, UNSET); + } + + Object resolveSlot(int idx, GlobalResolver resolver) { + Object val = slotVals[idx]; + + // Already evaluated + if (val != UNSET && val != IN_PROGRESS) { + if (val instanceof RuntimeException) { + throw (RuntimeException) val; + } + return val; + } + + if (val == IN_PROGRESS) { + throw new IllegalStateException("Cycle detected: @index" + idx); + } + + slotVals[idx] = IN_PROGRESS; + try { + Object result = slotExprs[idx].eval(resolver, frame); + slotVals[idx] = result; + return result; + } catch (CelEvaluationException e) { + LocalizedEvaluationException localizedException = + new LocalizedEvaluationException(e, e.getErrorCode(), slotExprs[idx].exprId()); + slotVals[idx] = localizedException; + throw localizedException; + } catch (RuntimeException e) { + slotVals[idx] = e; + throw e; + } + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalBlock.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalBlock.java new file mode 100644 index 000000000..41ad4034e --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalBlock.java @@ -0,0 +1,67 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import com.google.errorprone.annotations.Immutable; +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.GlobalResolver; + +/** Eval implementation of {@code cel.@block}. */ +@Immutable +final class EvalBlock extends PlannedInterpretable { + + @SuppressWarnings("Immutable") // Array not mutated after creation + private final PlannedInterpretable[] slotExprs; + + private final PlannedInterpretable resultExpr; + + static EvalBlock create( + long exprId, PlannedInterpretable[] slotExprs, PlannedInterpretable resultExpr) { + return new EvalBlock(exprId, slotExprs, resultExpr); + } + + private EvalBlock( + long exprId, PlannedInterpretable[] slotExprs, PlannedInterpretable resultExpr) { + super(exprId); + this.slotExprs = slotExprs; + this.resultExpr = resultExpr; + } + + @Override + public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + BlockMemoizer memoizer = BlockMemoizer.create(slotExprs, frame); + frame.setBlockMemoizer(memoizer); + return resultExpr.eval(resolver, frame); + } + + @Immutable + static final class EvalBlockSlot extends PlannedInterpretable { + private final int slotIndex; + + static EvalBlockSlot create(long exprId, int slotIndex) { + return new EvalBlockSlot(exprId, slotIndex); + } + + private EvalBlockSlot(long exprId, int slotIndex) { + super(exprId); + this.slotIndex = slotIndex; + } + + @Override + public Object eval(GlobalResolver resolver, ExecutionFrame frame) { + return frame.getBlockMemoizer().resolveSlot(slotIndex, resolver); + } + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java b/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java index e29c68dd8..282b7c83a 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java @@ -30,6 +30,7 @@ final class ExecutionFrame { private final CelFunctionResolver functionResolver; private final PartialVars partialVars; private int iterationCount; + private BlockMemoizer blockMemoizer; Optional findOverload( String functionName, Collection overloadIds, Object[] args) @@ -49,6 +50,17 @@ void incrementIterations() { } } + void setBlockMemoizer(BlockMemoizer blockMemoizer) { + if (this.blockMemoizer != null) { + throw new IllegalStateException("BlockMemoizer is already initialized"); + } + this.blockMemoizer = blockMemoizer; + } + + BlockMemoizer getBlockMemoizer() { + return blockMemoizer; + } + static ExecutionFrame create( CelFunctionResolver functionResolver, PartialVars partialVars, CelOptions celOptions) { return new ExecutionFrame( diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java index add918f64..9bd5f3ecd 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java @@ -164,6 +164,11 @@ private PlannedInterpretable planIdent(CelExpr celExpr, PlannerContext ctx) { } String identName = celExpr.ident().name(); + PlannedInterpretable blockSlot = maybeInterceptBlockSlot(celExpr.id(), identName).orElse(null); + if (blockSlot != null) { + return blockSlot; + } + if (ctx.isLocalVar(identName)) { return EvalAttribute.create(celExpr.id(), attributeFactory.newAbsoluteAttribute(identName)); } @@ -196,11 +201,42 @@ private PlannedInterpretable planCheckedIdent( return EvalConstant.create(id, identType); } + String identName = identRef.name(); + PlannedInterpretable blockSlot = maybeInterceptBlockSlot(id, identName).orElse(null); + if (blockSlot != null) { + return blockSlot; + } + return EvalAttribute.create(id, attributeFactory.newAbsoluteAttribute(identRef.name())); } + private Optional maybeInterceptBlockSlot(long id, String identName) { + if (!identName.startsWith("@index")) { + return Optional.empty(); + } + if (identName.length() <= 6) { + throw new IllegalArgumentException("Malformed block slot identifier: " + identName); + } + try { + int slotIndex = Integer.parseInt(identName.substring(6)); + if (slotIndex < 0) { + throw new IllegalArgumentException("Negative block slot index: " + identName); + } + return Optional.of(EvalBlock.EvalBlockSlot.create(id, slotIndex)); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid block slot index: " + identName, e); + } + } + private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) { ResolvedFunction resolvedFunction = resolveFunction(expr, ctx.referenceMap()); + String functionName = resolvedFunction.functionName(); + + PlannedInterpretable blockCall = maybeInterceptBlockCall(functionName, expr, ctx).orElse(null); + if (blockCall != null) { + return blockCall; + } + CelExpr target = resolvedFunction.target().orElse(null); int argCount = expr.call().args().size(); if (target != null) { @@ -220,7 +256,6 @@ private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) { evaluatedArgs[argIndex + offset] = plan(args.get(argIndex), ctx); } - String functionName = resolvedFunction.functionName(); Operator operator = Operator.findReverse(functionName).orElse(null); if (operator != null) { switch (operator) { @@ -285,6 +320,28 @@ private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) { } } + private Optional maybeInterceptBlockCall( + String functionName, CelExpr expr, PlannerContext ctx) { + if (!functionName.equals("cel.@block")) { + return Optional.empty(); + } + + CelCall blockCall = expr.call(); + + if (blockCall.args().size() != 2) { + throw new IllegalArgumentException( + "Expected 2 arguments for cel.@block call. Got: " + blockCall.args().size()); + } + + CelList exprList = blockCall.args().get(0).list(); + PlannedInterpretable[] slotExprs = new PlannedInterpretable[exprList.elements().size()]; + for (int i = 0; i < slotExprs.length; i++) { + slotExprs[i] = plan(exprList.elements().get(i), ctx); + } + PlannedInterpretable resultExpr = plan(blockCall.args().get(1), ctx); + return Optional.of(EvalBlock.create(expr.id(), slotExprs, resultExpr)); + } + /** * Intercepts a potential optional function call. * diff --git a/testing/BUILD.bazel b/testing/BUILD.bazel index c1b2a92b4..b9e68f003 100644 --- a/testing/BUILD.bazel +++ b/testing/BUILD.bazel @@ -11,6 +11,11 @@ java_library( exports = ["//testing/src/main/java/dev/cel/testing:adorner"], ) +java_library( + name = "cel_runtime_flavor", + exports = ["//testing/src/main/java/dev/cel/testing:cel_runtime_flavor"], +) + java_library( name = "line_differ", exports = ["//testing/src/main/java/dev/cel/testing:line_differ"], diff --git a/testing/src/main/java/dev/cel/testing/BUILD.bazel b/testing/src/main/java/dev/cel/testing/BUILD.bazel index 5ee142200..0d94bc8fc 100644 --- a/testing/src/main/java/dev/cel/testing/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/BUILD.bazel @@ -105,3 +105,12 @@ java_library( "@maven_android//:com_google_protobuf_protobuf_javalite", ], ) + +java_library( + name = "cel_runtime_flavor", + srcs = ["CelRuntimeFlavor.java"], + deps = [ + "//bundle:cel", + "//bundle:cel_experimental_factory", + ], +) diff --git a/testing/src/main/java/dev/cel/testing/CelRuntimeFlavor.java b/testing/src/main/java/dev/cel/testing/CelRuntimeFlavor.java new file mode 100644 index 000000000..576e0c1d3 --- /dev/null +++ b/testing/src/main/java/dev/cel/testing/CelRuntimeFlavor.java @@ -0,0 +1,38 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.testing; + +import dev.cel.bundle.CelBuilder; +import dev.cel.bundle.CelExperimentalFactory; +import dev.cel.bundle.CelFactory; + +/** Enumeration of supported CEL runtime environments for testing. */ +public enum CelRuntimeFlavor { + LEGACY { + @Override + public CelBuilder builder() { + return CelFactory.standardCelBuilder(); + } + }, + PLANNER { + @Override + public CelBuilder builder() { + return CelExperimentalFactory.plannerCelBuilder(); + } + }; + + /** Returns a new {@link CelBuilder} instance for this runtime flavor. */ + public abstract CelBuilder builder(); +}