From ba5a6885ec7e7da35f29ec41a4f880493b1d3f54 Mon Sep 17 00:00:00 2001 From: vishalbh Date: Fri, 6 Mar 2026 12:29:43 +0000 Subject: [PATCH] Allow MPI variables to be optional --- lib/ramble/ramble/keywords.py | 6 +-- lib/ramble/ramble/main.py | 16 ++++--- lib/ramble/ramble/test/application.py | 2 +- .../ramble/test/cmd/workspace_concretize.py | 16 ++++++- .../ramble/test/end_to_end/analyze_upload.py | 1 + .../test/end_to_end/formatted_executables.py | 1 + lib/ramble/ramble/test/experiment_set.py | 22 +++++----- lib/ramble/ramble/test/uploader.py | 2 +- lib/ramble/ramble/test/when.py | 2 +- lib/ramble/ramble/workspace/workspace.py | 35 ++++++++++++++- .../application-base/base_class.py | 43 +++++++++++++++++-- 11 files changed, 116 insertions(+), 30 deletions(-) diff --git a/lib/ramble/ramble/keywords.py b/lib/ramble/ramble/keywords.py index 2db58bd30..249fc20d9 100644 --- a/lib/ramble/ramble/keywords.py +++ b/lib/ramble/ramble/keywords.py @@ -59,9 +59,9 @@ "repeat_index": {"type": key_type.reserved, "level": output_level.variable}, "spec_name": {"type": key_type.optional, "level": output_level.variable}, "env_name": {"type": key_type.optional, "level": output_level.variable}, - "n_ranks": {"type": key_type.required, "level": output_level.key}, - "n_nodes": {"type": key_type.required, "level": output_level.key}, - "processes_per_node": {"type": key_type.required, "level": output_level.key}, + "n_ranks": {"type": key_type.optional, "level": output_level.key}, + "n_nodes": {"type": key_type.optional, "level": output_level.key}, + "processes_per_node": {"type": key_type.optional, "level": output_level.key}, "n_threads": {"type": key_type.optional, "level": output_level.key}, "batch_submit": {"type": key_type.required, "level": output_level.variable}, "mpi_command": {"type": key_type.required, "level": output_level.variable}, diff --git a/lib/ramble/ramble/main.py b/lib/ramble/ramble/main.py index 8c3e72299..3584b5f24 100644 --- a/lib/ramble/ramble/main.py +++ b/lib/ramble/ramble/main.py @@ -654,12 +654,16 @@ def setup_main_options(args): def _invoke_command(command, parser, args, unknown_args): """Run a ramble command *without* setting ramble global options.""" - if ramble.cmd.common.arguments.allows_unknown_args(command): - return_val = command(parser, args, unknown_args) - else: - if unknown_args: - logger.die(f'unrecognized arguments: {" ".join(unknown_args)}') - return_val = command(parser, args) + try: + if ramble.cmd.common.arguments.allows_unknown_args(command): + return_val = command(parser, args, unknown_args) + else: + if unknown_args: + logger.die(f'unrecognized arguments: {" ".join(unknown_args)}') + return_val = command(parser, args) + except ramble.expander.WorkloadNotDefinedError as e: + logger.error(e) + return 1 # Allow commands to return and error code if they want return 0 if return_val is None else return_val diff --git a/lib/ramble/ramble/test/application.py b/lib/ramble/ramble/test/application.py index 03764db56..d1a028200 100644 --- a/lib/ramble/ramble/test/application.py +++ b/lib/ramble/ramble/test/application.py @@ -156,7 +156,7 @@ def test_application_copy_is_deep(app_name, wl_name, mutable_mock_apps_repo): defined_internals = { "custom_executables": { - "test_exec": {"templates": ["test_exec"], "use_mpi": False, "redirect": "{log_file}"} + "test_exec": {"template": ["test_exec"], "use_mpi": False, "redirect": "{log_file}"} } } diff --git a/lib/ramble/ramble/test/cmd/workspace_concretize.py b/lib/ramble/ramble/test/cmd/workspace_concretize.py index 631b01362..85099e828 100644 --- a/lib/ramble/ramble/test/cmd/workspace_concretize.py +++ b/lib/ramble/ramble/test/cmd/workspace_concretize.py @@ -24,7 +24,13 @@ def test_workspace_concretize_additive(workspace_name): global_args = ["-w", workspace_name] workspace( - "generate-config", "gromacs", "-p", "spack", "--wf", "water_*", global_args=global_args + "generate-config", + "gromacs", + "-p", + "spack", + "--wf", + "water_*", + global_args=global_args, ) workspace("concretize", "-q", global_args=global_args) @@ -35,7 +41,13 @@ def test_workspace_concretize_additive(workspace_name): assert "wrf" not in content assert "intel-oneapi-vtune" not in content - workspace("generate-config", "wrf", "-p", "spack", global_args=global_args) + workspace( + "generate-config", + "wrf", + "-p", + "spack", + global_args=global_args, + ) workspace("concretize", "-q", global_args=global_args) with open(ws.config_file_path) as f: diff --git a/lib/ramble/ramble/test/end_to_end/analyze_upload.py b/lib/ramble/ramble/test/end_to_end/analyze_upload.py index ca00a71a7..ca47e78f0 100644 --- a/lib/ramble/ramble/test/end_to_end/analyze_upload.py +++ b/lib/ramble/ramble/test/end_to_end/analyze_upload.py @@ -32,6 +32,7 @@ def test_analyze_upload(make_workspace_from_config): mpi_command: '' batch_submit: 'batch_submit {execute_experiment}' processes_per_node: '1' + n_ranks: '{n_nodes}*{processes_per_node}' applications: hostname: workloads: diff --git a/lib/ramble/ramble/test/end_to_end/formatted_executables.py b/lib/ramble/ramble/test/end_to_end/formatted_executables.py index 854ad82c3..0c0151ece 100644 --- a/lib/ramble/ramble/test/end_to_end/formatted_executables.py +++ b/lib/ramble/ramble/test/end_to_end/formatted_executables.py @@ -31,6 +31,7 @@ def test_formatted_executables(make_workspace_from_config): batch_submit: '{execute_experiment}' processes_per_node: '16' n_threads: '1' + n_ranks: '{processes_per_node}*{n_nodes}' formatted_executables: ws_exec_def: prefix: 'from_ws ' diff --git a/lib/ramble/ramble/test/experiment_set.py b/lib/ramble/ramble/test/experiment_set.py index 9cfc0f56b..2fcbf0d5b 100644 --- a/lib/ramble/ramble/test/experiment_set.py +++ b/lib/ramble/ramble/test/experiment_set.py @@ -621,7 +621,7 @@ def test_n_ranks_correct_defaults(workspace_name): } workload_context = ramble.context.Context() - workload_context.context_name = "test_wl" + workload_context.context_name = "test_wl2" workload_context.variables = {"wl_var1": "1", "wl_var2": "2", "processes_per_node": "2"} experiment_context = ramble.context.Context() experiment_context.context_name = "series1_{n_ranks}" @@ -633,8 +633,8 @@ def test_n_ranks_correct_defaults(workspace_name): exp_set.set_experiment_context(experiment_context) exp_set.build_experiment_chains() - assert "basic.test_wl.series1_4" in exp_set.experiments - assert "basic.test_wl.series1_6" in exp_set.experiments + assert "basic.test_wl2.series1_4" in exp_set.experiments + assert "basic.test_wl2.series1_6" in exp_set.experiments def test_n_nodes_correct_defaults(workspace_name): @@ -656,7 +656,7 @@ def test_n_nodes_correct_defaults(workspace_name): } workload_context = ramble.context.Context() - workload_context.context_name = "test_wl" + workload_context.context_name = "test_wl2" workload_context.variables = {"wl_var1": "1", "wl_var2": "2", "processes_per_node": "2"} experiment_context = ramble.context.Context() experiment_context.context_name = "series1_{n_ranks}_{n_nodes}" @@ -671,8 +671,8 @@ def test_n_nodes_correct_defaults(workspace_name): exp_set.set_experiment_context(experiment_context) exp_set.build_experiment_chains() - assert "basic.test_wl.series1_4_2" in exp_set.experiments - assert "basic.test_wl.series1_6_3" in exp_set.experiments + assert "basic.test_wl2.series1_4_2" in exp_set.experiments + assert "basic.test_wl2.series1_6_3" in exp_set.experiments def test_processes_per_node_correct_defaults(workspace_name): @@ -696,7 +696,7 @@ def test_processes_per_node_correct_defaults(workspace_name): } workload_context = ramble.context.Context() - workload_context.context_name = "test_wl" + workload_context.context_name = "test_wl2" workload_context.variables = { "wl_var1": "1", "wl_var2": "2", @@ -710,8 +710,8 @@ def test_processes_per_node_correct_defaults(workspace_name): exp_set.set_experiment_context(experiment_context) exp_set.build_experiment_chains() - assert "basic.test_wl.series1_4_2" in exp_set.experiments - assert "basic.test_wl.series1_6_2" in exp_set.experiments + assert "basic.test_wl2.series1_4_2" in exp_set.experiments + assert "basic.test_wl2.series1_6_2" in exp_set.experiments @pytest.mark.parametrize("var", ["env_path"]) @@ -836,7 +836,7 @@ def test_missing_required_keyword_errors(workspace_name): } workload_context = ramble.context.Context() - workload_context.context_name = "test_wl" + workload_context.context_name = "test_wl2" workload_context.variables = { "wl_var1": "1", "wl_var2": "2", @@ -1927,7 +1927,7 @@ def test_validation_in_render_repeat_experiments(workspace_name): exp_set.set_application_context(app_context) workload_context = ramble.context.Context() - workload_context.context_name = "test_wl" + workload_context.context_name = "test_wl2" exp_set.set_workload_context(workload_context) experiment_context = ramble.context.Context() diff --git a/lib/ramble/ramble/test/uploader.py b/lib/ramble/ramble/test/uploader.py index dd87d306b..c688671a7 100644 --- a/lib/ramble/ramble/test/uploader.py +++ b/lib/ramble/ramble/test/uploader.py @@ -61,7 +61,7 @@ def test_data_preparation(request, mock_applications): "manage", "experiments", app_name, - "-w", + "--wf", wl_name, "-p", "spack", diff --git a/lib/ramble/ramble/test/when.py b/lib/ramble/ramble/test/when.py index f5b74b072..202cd4ee8 100644 --- a/lib/ramble/ramble/test/when.py +++ b/lib/ramble/ramble/test/when.py @@ -239,7 +239,7 @@ def test_fom_errors_when_context_not_found(workspace_name): with pytest.raises( RambleCommandError, - match=r"Command output:\n\n.*context 'test_context_when'.*is not found", + match=r"(?s)Command output:.*context 'test_context_when'.*is not found", ): workspace("analyze", global_args=global_args) diff --git a/lib/ramble/ramble/workspace/workspace.py b/lib/ramble/ramble/workspace/workspace.py index 8bc538944..4e6abeb62 100644 --- a/lib/ramble/ramble/workspace/workspace.py +++ b/lib/ramble/ramble/workspace/workspace.py @@ -1208,7 +1208,30 @@ def yaml_add_comment_before_key( apps_dict = self.get_applications().copy() app_inst = ramble.repository.get(application) - app_inst.validate_version() + # Set version manually as in set_variables_and_variants + _, _, maybe_version = application.partition("@") + if maybe_version and "{" not in maybe_version: + try: + app_inst.set_version(version_number=maybe_version, description=application) + app_inst.validate_version() + except (ramble.error.RambleError, ramble.error.ObjectValidationError) as e: + # If version validation fails (e.g. unknown version in strict mode), + # we still want to allow adding the experiment to the config. + # Full validation will happen during concretization/setup. + logger.debug(f"Version initialization failed for {application}: {e}") + pass + elif hasattr(app_inst, "preferred_version"): + try: + app_inst.set_version(version=app_inst.preferred_version, description=application) + app_inst.validate_version() + except (ramble.error.RambleError, ramble.error.ObjectValidationError) as e: + # If version validation fails, we still want to allow adding the experiment. + # Full validation will happen during concretization/setup. + logger.debug(f"Version initialization failed for {application}: {e}") + pass + + app_inst.variables = {} + app_inst.expander = ramble.expander.Expander({}, None) var_def_dict = {} def_regex = re.compile(r"\s*=\s*") @@ -1297,6 +1320,16 @@ def yaml_add_comment_before_key( for workload_name in workload_names: edited = True + app_inst.expander._workload_name = None + app_inst.define_variable(app_inst.keywords.workload_name, workload_name) + try: + if app_inst.is_mpi_required(workload_name): + app_inst.require_mpi_variables() + except (ramble.expander.WorkloadNotDefinedError, ramble.error.RambleError) as e: + # Workload may not be defined for the active 'when' conditions. + # Skip MPI requirement checks for now as full validation occurs later. + logger.debug(f"Skipping MPI requirement check for workload {workload_name}: {e}") + pass if workload_name not in workloads_dict: workloads_dict[workload_name] = syaml.syaml_dict() workloads_dict[workload_name][namespace.experiment] = syaml.syaml_dict() diff --git a/var/ramble/repos/builtin/base_classes/application-base/base_class.py b/var/ramble/repos/builtin/base_classes/application-base/base_class.py index 34c3e37b6..df2826248 100644 --- a/var/ramble/repos/builtin/base_classes/application-base/base_class.py +++ b/var/ramble/repos/builtin/base_classes/application-base/base_class.py @@ -287,7 +287,9 @@ def _validate_workload(self, workload_name): for when_set, workloads in self.workloads.items(): if workload_name in workloads: workload_found = True - if self.expander.satisfies( + if not self.expander: + workload = workloads[workload_name] + elif self.expander.satisfies( when_set, self.experiment_variants() ): if workload: @@ -304,7 +306,7 @@ def _validate_workload(self, workload_name): f"as a workload of application {self.name}." ) if not workload: - logger.die( + raise ramble.expander.WorkloadNotDefinedError( f"Workload {workload_name} is not defined " "for the active `when` conditions." ) @@ -325,7 +327,8 @@ def get_workload(self, workload_name=None): workload_name is not provided, retrieves the active workload. """ if not workload_name or ( - self.expander.workload_name + self.expander + and self.expander.workload_name and self.expander.workload_name == workload_name ): if not self._active_workload: @@ -3674,6 +3677,37 @@ def _objects(self, exclude_types=None): for mod_inst in self._modifier_instances: yield (ramble.repository.ObjectTypes.modifiers, mod_inst) + def require_mpi_variables(self): + self.keywords.update_keys( + { + self.keywords.n_ranks: { + "type": ramble.keywords.key_type.required, + "level": ramble.keywords.output_level.key, + }, + self.keywords.processes_per_node: { + "type": ramble.keywords.key_type.required, + "level": ramble.keywords.output_level.key, + }, + self.keywords.n_nodes: { + "type": ramble.keywords.key_type.required, + "level": ramble.keywords.output_level.key, + }, + } + ) + + def is_mpi_required(self, workload_name): + for exec_node in self._get_executable_graph(workload_name).walk(): + if isinstance( + exec_node.attribute, + ramble.util.executable.CommandExecutable, + ): + exec_cmd = exec_node.attribute + if exec_cmd.mpi and self.expander.expand_var( + str(exec_cmd.mpi), typed=True + ): + return True + return False + def set_required_variables(self): """Set required variables from all objects""" @@ -3713,7 +3747,8 @@ def define_mpi_vars(): ) self.define_variable(var_name, value) - define_mpi_vars() + if self.is_mpi_required(self.expander.workload_name): + define_mpi_vars() if self.keywords.n_threads not in self.variables: self.define_variable(self.keywords.n_threads, 1)