Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions lib/ramble/ramble/keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@
"repeat_index": {"type": key_type.reserved, "level": output_level.variable},
"spec_name": {"type": key_type.optional, "level": output_level.variable},
"env_name": {"type": key_type.optional, "level": output_level.variable},
"n_ranks": {"type": key_type.required, "level": output_level.key},
"n_nodes": {"type": key_type.required, "level": output_level.key},
"processes_per_node": {"type": key_type.required, "level": output_level.key},
"n_ranks": {"type": key_type.optional, "level": output_level.key},
"n_nodes": {"type": key_type.optional, "level": output_level.key},
"processes_per_node": {"type": key_type.optional, "level": output_level.key},
"n_threads": {"type": key_type.optional, "level": output_level.key},
"batch_submit": {"type": key_type.required, "level": output_level.variable},
"mpi_command": {"type": key_type.required, "level": output_level.variable},
Expand Down
16 changes: 10 additions & 6 deletions lib/ramble/ramble/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,12 +654,16 @@ def setup_main_options(args):

def _invoke_command(command, parser, args, unknown_args):
"""Run a ramble command *without* setting ramble global options."""
if ramble.cmd.common.arguments.allows_unknown_args(command):
return_val = command(parser, args, unknown_args)
else:
if unknown_args:
logger.die(f'unrecognized arguments: {" ".join(unknown_args)}')
return_val = command(parser, args)
try:
if ramble.cmd.common.arguments.allows_unknown_args(command):
return_val = command(parser, args, unknown_args)
else:
if unknown_args:
logger.die(f'unrecognized arguments: {" ".join(unknown_args)}')
return_val = command(parser, args)
except ramble.expander.WorkloadNotDefinedError as e:
logger.error(e)
return 1

# Allow commands to return and error code if they want
return 0 if return_val is None else return_val
Expand Down
2 changes: 1 addition & 1 deletion lib/ramble/ramble/test/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def test_application_copy_is_deep(app_name, wl_name, mutable_mock_apps_repo):

defined_internals = {
"custom_executables": {
"test_exec": {"templates": ["test_exec"], "use_mpi": False, "redirect": "{log_file}"}
"test_exec": {"template": ["test_exec"], "use_mpi": False, "redirect": "{log_file}"}
}
}

Expand Down
16 changes: 14 additions & 2 deletions lib/ramble/ramble/test/cmd/workspace_concretize.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@ def test_workspace_concretize_additive(workspace_name):
global_args = ["-w", workspace_name]

workspace(
"generate-config", "gromacs", "-p", "spack", "--wf", "water_*", global_args=global_args
"generate-config",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of editing the tests, can we use the helper method that I mentioned below in the base class, and then conditionally defined these in workspace.add_experiments rather than setting them explicitly here?

The main reason is because we want the user experience from generate-config or workspace manage experiments to be that users get a mostly functional workspace when they run these. And we should be able to identify if we need to define the MPI vars or not when they are creating experiments.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option actually is to write two helper functions. One to determine if MPI is required, and another to update the keywords to make the MPI variables required (if MPI is required).

Then this code:
https://github.com/GoogleCloudPlatform/ramble/blob/develop/lib/ramble/ramble/workspace/workspace.py#L1352

should automatically handle defining them.

"gromacs",
"-p",
"spack",
"--wf",
"water_*",
global_args=global_args,
)
workspace("concretize", "-q", global_args=global_args)

Expand All @@ -35,7 +41,13 @@ def test_workspace_concretize_additive(workspace_name):
assert "wrf" not in content
assert "intel-oneapi-vtune" not in content

workspace("generate-config", "wrf", "-p", "spack", global_args=global_args)
workspace(
"generate-config",
"wrf",
"-p",
"spack",
global_args=global_args,
)
workspace("concretize", "-q", global_args=global_args)

with open(ws.config_file_path) as f:
Expand Down
1 change: 1 addition & 0 deletions lib/ramble/ramble/test/end_to_end/analyze_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_analyze_upload(make_workspace_from_config):
mpi_command: ''
batch_submit: 'batch_submit {execute_experiment}'
processes_per_node: '1'
n_ranks: '{n_nodes}*{processes_per_node}'
applications:
hostname:
workloads:
Expand Down
1 change: 1 addition & 0 deletions lib/ramble/ramble/test/end_to_end/formatted_executables.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_formatted_executables(make_workspace_from_config):
batch_submit: '{execute_experiment}'
processes_per_node: '16'
n_threads: '1'
n_ranks: '{processes_per_node}*{n_nodes}'
formatted_executables:
ws_exec_def:
prefix: 'from_ws '
Expand Down
22 changes: 11 additions & 11 deletions lib/ramble/ramble/test/experiment_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def test_n_ranks_correct_defaults(workspace_name):
}

workload_context = ramble.context.Context()
workload_context.context_name = "test_wl"
workload_context.context_name = "test_wl2"
workload_context.variables = {"wl_var1": "1", "wl_var2": "2", "processes_per_node": "2"}
experiment_context = ramble.context.Context()
experiment_context.context_name = "series1_{n_ranks}"
Expand All @@ -633,8 +633,8 @@ def test_n_ranks_correct_defaults(workspace_name):
exp_set.set_experiment_context(experiment_context)
exp_set.build_experiment_chains()

assert "basic.test_wl.series1_4" in exp_set.experiments
assert "basic.test_wl.series1_6" in exp_set.experiments
assert "basic.test_wl2.series1_4" in exp_set.experiments
assert "basic.test_wl2.series1_6" in exp_set.experiments


def test_n_nodes_correct_defaults(workspace_name):
Expand All @@ -656,7 +656,7 @@ def test_n_nodes_correct_defaults(workspace_name):
}

workload_context = ramble.context.Context()
workload_context.context_name = "test_wl"
workload_context.context_name = "test_wl2"
workload_context.variables = {"wl_var1": "1", "wl_var2": "2", "processes_per_node": "2"}
experiment_context = ramble.context.Context()
experiment_context.context_name = "series1_{n_ranks}_{n_nodes}"
Expand All @@ -671,8 +671,8 @@ def test_n_nodes_correct_defaults(workspace_name):
exp_set.set_experiment_context(experiment_context)
exp_set.build_experiment_chains()

assert "basic.test_wl.series1_4_2" in exp_set.experiments
assert "basic.test_wl.series1_6_3" in exp_set.experiments
assert "basic.test_wl2.series1_4_2" in exp_set.experiments
assert "basic.test_wl2.series1_6_3" in exp_set.experiments


def test_processes_per_node_correct_defaults(workspace_name):
Expand All @@ -696,7 +696,7 @@ def test_processes_per_node_correct_defaults(workspace_name):
}

workload_context = ramble.context.Context()
workload_context.context_name = "test_wl"
workload_context.context_name = "test_wl2"
workload_context.variables = {
"wl_var1": "1",
"wl_var2": "2",
Expand All @@ -710,8 +710,8 @@ def test_processes_per_node_correct_defaults(workspace_name):
exp_set.set_experiment_context(experiment_context)
exp_set.build_experiment_chains()

assert "basic.test_wl.series1_4_2" in exp_set.experiments
assert "basic.test_wl.series1_6_2" in exp_set.experiments
assert "basic.test_wl2.series1_4_2" in exp_set.experiments
assert "basic.test_wl2.series1_6_2" in exp_set.experiments


@pytest.mark.parametrize("var", ["env_path"])
Expand Down Expand Up @@ -836,7 +836,7 @@ def test_missing_required_keyword_errors(workspace_name):
}

workload_context = ramble.context.Context()
workload_context.context_name = "test_wl"
workload_context.context_name = "test_wl2"
workload_context.variables = {
"wl_var1": "1",
"wl_var2": "2",
Expand Down Expand Up @@ -1927,7 +1927,7 @@ def test_validation_in_render_repeat_experiments(workspace_name):
exp_set.set_application_context(app_context)

workload_context = ramble.context.Context()
workload_context.context_name = "test_wl"
workload_context.context_name = "test_wl2"
exp_set.set_workload_context(workload_context)

experiment_context = ramble.context.Context()
Expand Down
2 changes: 1 addition & 1 deletion lib/ramble/ramble/test/uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_data_preparation(request, mock_applications):
"manage",
"experiments",
app_name,
"-w",
"--wf",
wl_name,
"-p",
"spack",
Expand Down
2 changes: 1 addition & 1 deletion lib/ramble/ramble/test/when.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def test_fom_errors_when_context_not_found(workspace_name):

with pytest.raises(
RambleCommandError,
match=r"Command output:\n\n.*context 'test_context_when'.*is not found",
match=r"(?s)Command output:.*context 'test_context_when'.*is not found",
):
workspace("analyze", global_args=global_args)

Expand Down
35 changes: 34 additions & 1 deletion lib/ramble/ramble/workspace/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,7 +1208,30 @@ def yaml_add_comment_before_key(
apps_dict = self.get_applications().copy()

app_inst = ramble.repository.get(application)
app_inst.validate_version()
# Set version manually as in set_variables_and_variants
_, _, maybe_version = application.partition("@")
if maybe_version and "{" not in maybe_version:
try:
app_inst.set_version(version_number=maybe_version, description=application)
app_inst.validate_version()
except (ramble.error.RambleError, ramble.error.ObjectValidationError) as e:
# If version validation fails (e.g. unknown version in strict mode),
# we still want to allow adding the experiment to the config.
# Full validation will happen during concretization/setup.
logger.debug(f"Version initialization failed for {application}: {e}")
pass
elif hasattr(app_inst, "preferred_version"):
try:
app_inst.set_version(version=app_inst.preferred_version, description=application)
app_inst.validate_version()
except (ramble.error.RambleError, ramble.error.ObjectValidationError) as e:
# If version validation fails, we still want to allow adding the experiment.
# Full validation will happen during concretization/setup.
logger.debug(f"Version initialization failed for {application}: {e}")
pass

app_inst.variables = {}
app_inst.expander = ramble.expander.Expander({}, None)

var_def_dict = {}
def_regex = re.compile(r"\s*=\s*")
Expand Down Expand Up @@ -1297,6 +1320,16 @@ def yaml_add_comment_before_key(

for workload_name in workload_names:
edited = True
app_inst.expander._workload_name = None
app_inst.define_variable(app_inst.keywords.workload_name, workload_name)
try:
if app_inst.is_mpi_required(workload_name):
app_inst.require_mpi_variables()
except (ramble.expander.WorkloadNotDefinedError, ramble.error.RambleError) as e:
# Workload may not be defined for the active 'when' conditions.
# Skip MPI requirement checks for now as full validation occurs later.
logger.debug(f"Skipping MPI requirement check for workload {workload_name}: {e}")
pass
if workload_name not in workloads_dict:
workloads_dict[workload_name] = syaml.syaml_dict()
workloads_dict[workload_name][namespace.experiment] = syaml.syaml_dict()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,9 @@ def _validate_workload(self, workload_name):
for when_set, workloads in self.workloads.items():
if workload_name in workloads:
workload_found = True
if self.expander.satisfies(
if not self.expander:
workload = workloads[workload_name]
elif self.expander.satisfies(
when_set, self.experiment_variants()
):
if workload:
Expand All @@ -304,7 +306,7 @@ def _validate_workload(self, workload_name):
f"as a workload of application {self.name}."
)
if not workload:
logger.die(
raise ramble.expander.WorkloadNotDefinedError(
f"Workload {workload_name} is not defined "
"for the active `when` conditions."
)
Expand All @@ -325,7 +327,8 @@ def get_workload(self, workload_name=None):
workload_name is not provided, retrieves the active workload.
"""
if not workload_name or (
self.expander.workload_name
self.expander
and self.expander.workload_name
and self.expander.workload_name == workload_name
):
if not self._active_workload:
Expand Down Expand Up @@ -3674,6 +3677,37 @@ def _objects(self, exclude_types=None):
for mod_inst in self._modifier_instances:
yield (ramble.repository.ObjectTypes.modifiers, mod_inst)

def require_mpi_variables(self):
self.keywords.update_keys(
{
self.keywords.n_ranks: {
"type": ramble.keywords.key_type.required,
"level": ramble.keywords.output_level.key,
},
self.keywords.processes_per_node: {
"type": ramble.keywords.key_type.required,
"level": ramble.keywords.output_level.key,
},
self.keywords.n_nodes: {
"type": ramble.keywords.key_type.required,
"level": ramble.keywords.output_level.key,
},
}
)

def is_mpi_required(self, workload_name):
for exec_node in self._get_executable_graph(workload_name).walk():
if isinstance(
exec_node.attribute,
ramble.util.executable.CommandExecutable,
):
exec_cmd = exec_node.attribute
if exec_cmd.mpi and self.expander.expand_var(
str(exec_cmd.mpi), typed=True
):
return True
return False

def set_required_variables(self):
"""Set required variables from all objects"""

Expand Down Expand Up @@ -3713,7 +3747,8 @@ def define_mpi_vars():
)
self.define_variable(var_name, value)

define_mpi_vars()
if self.is_mpi_required(self.expander.workload_name):
define_mpi_vars()

if self.keywords.n_threads not in self.variables:
self.define_variable(self.keywords.n_threads, 1)
Expand Down
Loading