Add --tmp-dir option for archive creation
#6946
base: main
Changes from all commits
52325fc
3331f6d
0e3d38a
58b8d91
d3737aa
f7a7111
74dc0c6
f38ecd0
f0ae6f8
c71d1ea
2379d21
4756f5e
ea4007f
a515785
864c141
99eebb9
src/aiida/tools/archive/create.py
@@ -12,6 +12,7 @@
stored in a single file.
"""

import os
import shutil
import tempfile
from collections.abc import Sequence

@@ -46,6 +47,7 @@
def create_archive(
    entities: Optional[Iterable[Union[orm.Computer, orm.Node, orm.Group, orm.User]]],
    # TODO: This is different from the CLI implementation, where OUTPUT_FILENAME is a required argument
    filename: Union[None, str, Path] = None,
    *,
    archive_format: Optional[ArchiveFormatAbstract] = None,

@@ -61,6 +63,7 @@ def create_archive(
    compression: int = 6,
    test_run: bool = False,
    backend: Optional[StorageBackend] = None,
    tmp_dir: Optional[Union[str, Path]] = None,
    **traversal_rules: bool,
) -> Path:
    """Export AiiDA data to an archive file.

@@ -144,6 +147,11 @@ def create_archive(
    :param backend: the backend to export from. If not specified, the default backend is used.

    :param tmp_dir: Location where the temporary directory will be written during archive creation.
        The directory must exist and be writable, and defaults to the parent directory of the output file.
        This parameter is useful when the output directory has limited space or when you want to use a specific
        filesystem (e.g., faster storage) for temporary operations.

    :param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules`
        what rule names are toggleable and what the defaults are.

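For illustration, a minimal sketch of how the new keyword might be called from the Python API. The import path, the profile setup, and the query are assumptions based on the usual aiida-core layout and are not taken from this diff.

```python
# Sketch only: export all nodes while keeping the temporary working directory
# on a separate (e.g. faster or larger) filesystem.
from pathlib import Path

from aiida import load_profile, orm
from aiida.tools.archive import create_archive

load_profile()  # assumes a configured AiiDA profile

nodes = orm.QueryBuilder().append(orm.Node).all(flat=True)

create_archive(
    nodes,
    filename=Path('/data/exports/all_nodes.aiida'),
    tmp_dir='/scratch/aiida-tmp',  # must exist and be writable
)
```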
@@ -244,7 +252,7 @@ def querybuilder():
            entity_ids[EntityTypes.USER].add(entry.pk)
        else:
            raise ArchiveExportError(
                f'I was given {entry} ({type(entry)}),' ' which is not a User, Node, Computer, or Group instance'
                f'I was given {entry} ({type(entry)}), which is not a User, Node, Computer, or Group instance'
            )
    group_nodes, link_data = _collect_required_entities(
        querybuilder,

@@ -286,96 +294,118 @@ def querybuilder():

    EXPORT_LOGGER.report(f'Creating archive with:\n{tabulate(count_summary)}')

    # Create temporary directory in the same folder as the output file
    parent_dir = filename.parent

    if not parent_dir.exists():
        msg = "Parent directory of the export file doesn't exist."
        raise ArchiveExportError(msg)
    # Check if directory is writable
    # Taken from: https://stackoverflow.com/a/2113511
    if not os.access(parent_dir, os.W_OK | os.X_OK):
        msg = f"Specified temporary directory '{tmp_dir}' is not writable"
        raise ArchiveExportError(msg)
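As an aside, a small self-contained illustration of the writability check used above (this snippet is editorial and not part of the diff):

```python
import os
from pathlib import Path


def is_writable_dir(path: Path) -> bool:
    """Return True if ``path`` is an existing directory in which new files can be created."""
    # W_OK: permission to create/remove entries; X_OK: permission to enter the directory.
    return path.is_dir() and os.access(path, os.W_OK | os.X_OK)


print(is_writable_dir(Path('/tmp')))          # typically True
print(is_writable_dir(Path('/nonexistent')))  # False
```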
    # Create and open the archive for writing.
    # We create in a temp dir then move to final place at end,
    # so that the user cannot end up with a half written archive on errors
    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_filename = Path(tmpdir) / 'export.zip'
        with archive_format.open(tmp_filename, mode='x', compression=compression) as writer:
            # add metadata
            writer.update_metadata(
                {
                    'ctime': datetime.now().isoformat(),
                    'creation_parameters': {
                        'entities_starting_set': None
                        if entities is None
                        else {etype.value: list(unique) for etype, unique in starting_uuids.items() if unique},
                        'include_authinfos': include_authinfos,
                        'include_comments': include_comments,
                        'include_logs': include_logs,
                        'graph_traversal_rules': full_traversal_rules,
                    },
                }
            )
            # stream entity data to the archive
            with get_progress_reporter()(desc='Archiving database: ', total=sum(entity_counts.values())) as progress:
                for etype, ids in entity_ids.items():
                    if etype == EntityTypes.NODE and strip_checkpoints:

                        def transform(row):
                            data = row['entity']
                            if data.get('node_type', '').startswith('process.'):
                                data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None)
                            return data
                    else:

                        def transform(row):
                            return row['entity']

                    progress.set_description_str(f'Archiving database: {etype.value}s')
                    if ids:
                        for nrows, rows in batch_iter(
                            querybuilder()
                            .append(
                                entity_type_to_orm[etype], filters={'id': {'in': ids}}, tag='entity', project=['**']
                            )
                            .iterdict(batch_size=batch_size),
                            batch_size,
                            transform,
                        ):
                            writer.bulk_insert(etype, rows)
                            progress.update(nrows)

                # stream links
                progress.set_description_str(f'Archiving database: {EntityTypes.LINK.value}s')

                def transform(d):
                    return {
                        'input_id': d.source_id,
                        'output_id': d.target_id,
                        'label': d.link_label,
                        'type': d.link_type,
    try:
        with tempfile.TemporaryDirectory(dir=tmp_dir, prefix='.aiida-export-') as tmpdir:
            tmp_filename = Path(tmpdir) / 'export.zip'
            with archive_format.open(tmp_filename, mode='x', compression=compression) as writer:
                # add metadata
                writer.update_metadata(
                    {
                        'ctime': datetime.now().isoformat(),
                        'creation_parameters': {
                            'entities_starting_set': None
                            if entities is None
                            else {etype.value: list(unique) for etype, unique in starting_uuids.items() if unique},
                            'include_authinfos': include_authinfos,
                            'include_comments': include_comments,
                            'include_logs': include_logs,
                            'graph_traversal_rules': full_traversal_rules,
                        },
                    }

                for nrows, rows in batch_iter(link_data, batch_size, transform):
                    writer.bulk_insert(EntityTypes.LINK, rows, allow_defaults=True)
                    progress.update(nrows)
                del link_data  # release memory

                # stream group_nodes
                progress.set_description_str(f'Archiving database: {EntityTypes.GROUP_NODE.value}s')

                def transform(d):
                    return {'dbgroup_id': d[0], 'dbnode_id': d[1]}

                for nrows, rows in batch_iter(group_nodes, batch_size, transform):
                    writer.bulk_insert(EntityTypes.GROUP_NODE, rows, allow_defaults=True)
                    progress.update(nrows)
                del group_nodes  # release memory
                )
                # stream entity data to the archive
                with get_progress_reporter()(
                    desc='Archiving database: ', total=sum(entity_counts.values())
                ) as progress:
                    for etype, ids in entity_ids.items():
                        if etype == EntityTypes.NODE and strip_checkpoints:

                            def transform(row):
                                data = row['entity']
                                if data.get('node_type', '').startswith('process.'):
                                    data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None)
                                return data
                        else:

                            def transform(row):
                                return row['entity']

                        progress.set_description_str(f'Archiving database: {etype.value}s')
                        if ids:
                            for nrows, rows in batch_iter(
                                querybuilder()
                                .append(
                                    entity_type_to_orm[etype], filters={'id': {'in': ids}}, tag='entity', project=['**']
                                )
                                .iterdict(batch_size=batch_size),
                                batch_size,
                                transform,
                            ):
                                writer.bulk_insert(etype, rows)
                                progress.update(nrows)

                    # stream links
                    progress.set_description_str(f'Archiving database: {EntityTypes.LINK.value}s')

                    def transform(d):
                        return {
                            'input_id': d.source_id,
                            'output_id': d.target_id,
                            'label': d.link_label,
                            'type': d.link_type,
                        }

                    for nrows, rows in batch_iter(link_data, batch_size, transform):
                        writer.bulk_insert(EntityTypes.LINK, rows, allow_defaults=True)
                        progress.update(nrows)
                    del link_data  # release memory

                    # stream group_nodes
                    progress.set_description_str(f'Archiving database: {EntityTypes.GROUP_NODE.value}s')

                    def transform(d):
                        return {'dbgroup_id': d[0], 'dbnode_id': d[1]}

                    for nrows, rows in batch_iter(group_nodes, batch_size, transform):
                        writer.bulk_insert(EntityTypes.GROUP_NODE, rows, allow_defaults=True)
                        progress.update(nrows)
                    del group_nodes  # release memory

            # stream node repository files to the archive
            if entity_ids[EntityTypes.NODE]:
                _stream_repo_files(
                    archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size, filter_size
                )

            EXPORT_LOGGER.report('Finalizing archive creation...')
                EXPORT_LOGGER.report('Finalizing archive creation...')

        if filename.exists():
Member
I know it's not part you just put in the …
            filename.unlink()

            if filename.exists():
                filename.unlink()
            filename.parent.mkdir(parents=True, exist_ok=True)
Member
Again, not your code, but something to think about: Apparently you can just write the archive to a folder that doesn't exist and the command will create it and all non-existent parents. I think that's ok. But the behaviour is then different from the …
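For context on the reviewer's point, a small editorial example (not part of the diff) of what `Path.mkdir(parents=True, exist_ok=True)` does:

```python
import tempfile
from pathlib import Path

# Work inside a throwaway directory so the example has no side effects.
with tempfile.TemporaryDirectory() as base:
    target = Path(base) / 'does' / 'not' / 'exist' / 'yet' / 'export.aiida'
    # Creates every missing ancestor of the output file in one call,
    # and does not fail if the directory already exists.
    target.parent.mkdir(parents=True, exist_ok=True)
    print(target.parent.is_dir())  # True
```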
            shutil.move(tmp_filename, filename)
    except OSError as e:
        if e.errno == 28:  # No space left on device
            msg = f"Insufficient disk space in temporary directory '{tmp_dir}'. "
            raise ArchiveExportError(msg) from e

        filename.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(tmp_filename, filename)
        msg = f'Failed to create temporary directory: {e}'
        raise ArchiveExportError(msg) from e

    EXPORT_LOGGER.report('Archive created successfully')

@@ -680,7 +710,7 @@ def _check_unsealed_nodes(querybuilder: QbType, node_ids: set[int], batch_size:
    if unsealed_node_pks:
        raise ExportValidationError(
            'All ProcessNodes must be sealed before they can be exported. '
            f"Node(s) with PK(s): {', '.join(str(pk) for pk in unsealed_node_pks)} is/are not sealed."
            f'Node(s) with PK(s): {", ".join(str(pk) for pk in unsealed_node_pks)} is/are not sealed.'
        )

@@ -775,18 +805,18 @@ def get_init_summary(
    """Get summary for archive initialisation"""
    parameters: list[list[Any]] = [['Path', str(outfile)], ['Version', archive_version], ['Compression', compression]]

    result = f"\n{tabulate(parameters, headers=['Archive Parameters', ''])}"
    result = f'\n{tabulate(parameters, headers=["Archive Parameters", ""])}'

    inclusions: list[list[Any]] = [
        ['Computers/Nodes/Groups/Users', 'All' if collect_all else 'Selected'],
        ['Computer Authinfos', include_authinfos],
        ['Node Comments', include_comments],
        ['Node Logs', include_logs],
    ]
    result += f"\n\n{tabulate(inclusions, headers=['Inclusion rules', ''])}"
    result += f'\n\n{tabulate(inclusions, headers=["Inclusion rules", ""])}'

    if not collect_all:
        rules_table = [[f"Follow links {' '.join(name.split('_'))}s", value] for name, value in traversal_rules.items()]
        result += f"\n\n{tabulate(rules_table, headers=['Traversal rules', ''])}"
        rules_table = [[f'Follow links {" ".join(name.split("_"))}s', value] for name, value in traversal_rules.items()]
        result += f'\n\n{tabulate(rules_table, headers=["Traversal rules", ""])}'

    return result + '\n'
Diff is large here because of the additional try-except and the indent that goes with it (to capture disk-space errors), but the actual code inside should be the same!
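To make the structural change easier to see, here is a runnable, editorial sketch of the pattern this hunk introduces: build the output in a temporary directory (optionally on another filesystem), then move it into place, converting "no space left on device" into a clearer error. The function, its arguments, and the error type below are generic stand-ins, not the actual aiida-core implementation.

```python
import errno
import shutil
import tempfile
from pathlib import Path


def write_with_tmp_dir(filename: Path, payload: bytes, tmp_dir: str | None = None) -> Path:
    """Write ``payload`` to ``filename`` via a temporary directory (``tmp_dir`` if given)."""
    try:
        with tempfile.TemporaryDirectory(dir=tmp_dir, prefix='.aiida-export-') as tmpdir:
            tmp_filename = Path(tmpdir) / 'export.zip'
            tmp_filename.write_bytes(payload)  # stands in for the real archive writer
            if filename.exists():
                filename.unlink()
            filename.parent.mkdir(parents=True, exist_ok=True)
            shutil.move(tmp_filename, filename)
    except OSError as exc:
        if exc.errno == errno.ENOSPC:  # no space left on device
            raise RuntimeError(f"Insufficient disk space in temporary directory '{tmp_dir}'.") from exc
        raise
    return filename


print(write_with_tmp_dir(Path('/tmp/demo/export.aiida'), b'archive bytes', tmp_dir='/tmp'))
```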
`git diff --ignore-all-space main src/aiida/tools/archive/create.py` gives instead only:
Thanks for the clarification @GeigerJ2! This'll help in my review <3