From cec0240a34aa1ceafd5727a4fa3f97ea1984f99b Mon Sep 17 00:00:00 2001 From: Jack Li Date: Wed, 6 Aug 2025 10:24:34 -0400 Subject: [PATCH 01/15] fixes to prediction upload --- polaris/benchmark/_benchmark_v2.py | 1 - polaris/hub/client.py | 2 +- polaris/prediction/_predictions_v2.py | 25 ++++++++++++++++---- polaris/utils/zarr/codecs.py | 33 +++++++++++++++++++++++++++ 4 files changed, 55 insertions(+), 6 deletions(-) diff --git a/polaris/benchmark/_benchmark_v2.py b/polaris/benchmark/_benchmark_v2.py index 5a568153..3a4af28b 100644 --- a/polaris/benchmark/_benchmark_v2.py +++ b/polaris/benchmark/_benchmark_v2.py @@ -30,7 +30,6 @@ class BenchmarkV2Specification( Attributes: dataset: The dataset the benchmark specification is based on. - splits: The predefined train-test splits to use for evaluation. n_classes: The number of classes for each of the target columns. readme: Markdown text that can be used to provide a formatted description of the benchmark. artifact_version: The version of the benchmark. 
diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 3345bcc0..47602f77 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -766,7 +766,7 @@ def submit_benchmark_predictions( # Set owner prediction.owner = HubOwner.normalize(owner or prediction.owner) - prediction_json = prediction.model_dump(by_alias=True, exclude_none=True) + prediction_json = prediction.model_dump(by_alias=True, exclude_none=True, exclude={"predictions"}) # Step 1: Upload metadata to Hub with track_progress(description="Uploading prediction metadata", total=1) as (progress, task): diff --git a/polaris/prediction/_predictions_v2.py b/polaris/prediction/_predictions_v2.py index 2bb0a250..b97770c4 100644 --- a/polaris/prediction/_predictions_v2.py +++ b/polaris/prediction/_predictions_v2.py @@ -9,10 +9,12 @@ import zarr from pydantic import ( PrivateAttr, + Field, model_validator, ) from polaris.utils.zarr._manifest import generate_zarr_manifest, calculate_file_md5 +from polaris.utils.zarr.codecs import detect_object_codec_and_chunking from polaris.evaluate import ResultsMetadataV2 from polaris.evaluate._predictions import BenchmarkPredictions @@ -33,7 +35,7 @@ class BenchmarkPredictionsV2(BenchmarkPredictions, ResultsMetadataV2): For additional metadata attributes, see the base classes. 
""" - dataset_zarr_root: zarr.Group + dataset_zarr_root: zarr.Group = Field(exclude=True) # Zarr Group cannot be JSON serialized benchmark_artifact_id: str _artifact_type = "prediction" _zarr_root_path: str | None = PrivateAttr(None) @@ -70,13 +72,28 @@ def to_zarr(self) -> Path: for col in self.target_labels: data = test_set_predictions[col] template = dataset_root[col] + + # Use utility function to detect codec and chunking compatibility + chunks = template.chunks + if template.dtype == object: + object_codec, filters, chunks_compatible = detect_object_codec_and_chunking( + template.filters + ) + # Disable chunking if codec doesn't support it + if not chunks_compatible: + chunks = None + else: + object_codec = None + filters = list(template.filters) if template.filters else [] + test_set_group.array( name=col, data=data, dtype=template.dtype, compressor=template.compressor, - filters=template.filters, - chunks=template.chunks, + filters=filters, + chunks=chunks, + object_codec=object_codec, overwrite=True, ) @@ -142,7 +159,7 @@ def has_zarr_manifest_md5sum(self): return self._zarr_manifest_md5sum is not None def __repr__(self): - return self.model_dump_json(by_alias=True, indent=2) + return self.model_dump_json(by_alias=True, exclude={"predictions"}, indent=2) def __str__(self): return self.__repr__() diff --git a/polaris/utils/zarr/codecs.py b/polaris/utils/zarr/codecs.py index 0459b684..4c1e4141 100644 --- a/polaris/utils/zarr/codecs.py +++ b/polaris/utils/zarr/codecs.py @@ -17,6 +17,7 @@ class RDKitMolCodec(VLenBytes): """ codec_id = "rdkit_mol" + supports_chunking = True def encode(self, buf: np.ndarray): """ @@ -71,6 +72,7 @@ class AtomArrayCodec(MsgPack): """ codec_id = "atom_array" + supports_chunking = False def encode(self, buf: np.ndarray): """ @@ -82,6 +84,7 @@ def encode(self, buf: np.ndarray): for idx, atom_array in enumerate(buf): # A chunk can have missing values if atom_array is None: + to_pack[idx] = None # Explicitly set None values continue if 
not isinstance(atom_array, struc.AtomArray): @@ -149,3 +152,33 @@ def decode(self, buf, out=None): register_codec(RDKitMolCodec) register_codec(AtomArrayCodec) + + +def detect_object_codec_and_chunking(template_filters=None): + """ + Detect the appropriate object codec and chunking settings from template filters. + + Returns: + tuple: (object_codec, filters, chunks_compatible) + """ + from numcodecs import MsgPack + + filters = list(template_filters) if template_filters else [] + object_codec = None + chunks_compatible = True + + # Check if codec exists in filters (Zarr stores object_codec as part of filters) + for filter_codec in filters: + if hasattr(filter_codec, 'supports_chunking'): # Our custom codecs + object_codec = filter_codec + chunks_compatible = getattr(filter_codec, 'supports_chunking', True) + # Remove from filters since we'll pass it as object_codec + filters = [f for f in filters if f is not filter_codec] + + # Remove MsgPack filters for MsgPack-based codecs to avoid double encoding + if isinstance(filter_codec, MsgPack): + filters = [f for f in filters if not isinstance(f, MsgPack)] + + return object_codec, filters, chunks_compatible + + return object_codec, filters, chunks_compatible From 2b626f74d61dc719885a159e49d665c1fe7e4aeb Mon Sep 17 00:00:00 2001 From: Jack Li Date: Wed, 6 Aug 2025 10:25:31 -0400 Subject: [PATCH 02/15] formatting --- polaris/prediction/_predictions_v2.py | 4 ++-- polaris/utils/zarr/codecs.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/polaris/prediction/_predictions_v2.py b/polaris/prediction/_predictions_v2.py index b97770c4..9fe18228 100644 --- a/polaris/prediction/_predictions_v2.py +++ b/polaris/prediction/_predictions_v2.py @@ -72,7 +72,7 @@ def to_zarr(self) -> Path: for col in self.target_labels: data = test_set_predictions[col] template = dataset_root[col] - + # Use utility function to detect codec and chunking compatibility chunks = template.chunks if template.dtype == 
object: @@ -85,7 +85,7 @@ def to_zarr(self) -> Path: else: object_codec = None filters = list(template.filters) if template.filters else [] - + test_set_group.array( name=col, data=data, diff --git a/polaris/utils/zarr/codecs.py b/polaris/utils/zarr/codecs.py index 4c1e4141..b3614db5 100644 --- a/polaris/utils/zarr/codecs.py +++ b/polaris/utils/zarr/codecs.py @@ -157,28 +157,28 @@ def decode(self, buf, out=None): def detect_object_codec_and_chunking(template_filters=None): """ Detect the appropriate object codec and chunking settings from template filters. - + Returns: tuple: (object_codec, filters, chunks_compatible) """ from numcodecs import MsgPack - + filters = list(template_filters) if template_filters else [] object_codec = None chunks_compatible = True - + # Check if codec exists in filters (Zarr stores object_codec as part of filters) for filter_codec in filters: - if hasattr(filter_codec, 'supports_chunking'): # Our custom codecs + if hasattr(filter_codec, "supports_chunking"): # Our custom codecs object_codec = filter_codec - chunks_compatible = getattr(filter_codec, 'supports_chunking', True) + chunks_compatible = getattr(filter_codec, "supports_chunking", True) # Remove from filters since we'll pass it as object_codec filters = [f for f in filters if f is not filter_codec] - + # Remove MsgPack filters for MsgPack-based codecs to avoid double encoding if isinstance(filter_codec, MsgPack): filters = [f for f in filters if not isinstance(f, MsgPack)] - + return object_codec, filters, chunks_compatible - + return object_codec, filters, chunks_compatible From 56672b2b3e6b75740a2e5d6367fca2aeb8df2902 Mon Sep 17 00:00:00 2001 From: Jack Li Date: Wed, 6 Aug 2025 16:11:04 -0400 Subject: [PATCH 03/15] refactor for the zarr archive to use standard codecs with conversion to and from custom codecs handled by the client --- polaris/prediction/_predictions_v2.py | 116 +++++++++++++++++++++---- polaris/utils/zarr/codecs.py | 99 +++++++++++++++------ 
tests/test_benchmark_predictions_v2.py | 37 +++++--- 3 files changed, 194 insertions(+), 58 deletions(-) diff --git a/polaris/prediction/_predictions_v2.py b/polaris/prediction/_predictions_v2.py index 9fe18228..82354ef4 100644 --- a/polaris/prediction/_predictions_v2.py +++ b/polaris/prediction/_predictions_v2.py @@ -12,9 +12,17 @@ Field, model_validator, ) +from numcodecs import MsgPack, VLenBytes +from fastpdb import struc +from rdkit import Chem from polaris.utils.zarr._manifest import generate_zarr_manifest, calculate_file_md5 -from polaris.utils.zarr.codecs import detect_object_codec_and_chunking +from polaris.utils.zarr.codecs import ( + convert_dict_to_atomarray, + convert_bytes_to_mol, + convert_atomarray_to_dict, + convert_mol_to_bytes, +) from polaris.evaluate import ResultsMetadataV2 from polaris.evaluate._predictions import BenchmarkPredictions @@ -63,48 +71,122 @@ def to_zarr(self) -> Path: This method should be called explicitly when ready to write predictions to disk. """ - root = self.zarr_root + # Get raw zarr root for writing (not the converting wrapper) + store = zarr.DirectoryStore(self.zarr_root_path) + raw_root = zarr.group(store=store) dataset_root = self.dataset_zarr_root for test_set_label, test_set_predictions in self.predictions.items(): # Create a group for each test set - test_set_group = root.require_group(test_set_label) + test_set_group = raw_root.require_group(test_set_label) for col in self.target_labels: data = test_set_predictions[col] template = dataset_root[col] - # Use utility function to detect codec and chunking compatibility - chunks = template.chunks + # Handle object data conversion if template.dtype == object: - object_codec, filters, chunks_compatible = detect_object_codec_and_chunking( - template.filters - ) - # Disable chunking if codec doesn't support it - if not chunks_compatible: - chunks = None + # Find first non-None item to determine conversion type + sample_item = next((item for item in data if item is not 
None), None) + + # Define conversion mapping + conversion_map = { + struc.AtomArray: (convert_atomarray_to_dict, MsgPack()), + Chem.Mol: (convert_mol_to_bytes, VLenBytes()), + } + + if sample_item is not None: + # Find matching conversion for sample item type + converter_func, codec = None, MsgPack() # Default fallback + for obj_type, (conv_func, conv_codec) in conversion_map.items(): + if isinstance(sample_item, obj_type): + converter_func, codec = conv_func, conv_codec + break + + # Apply conversion if we found a matching converter + if converter_func is not None: + final_data = [converter_func(item) if item is not None else None for item in data] + else: + # No converter found - store as-is (shouldn't happen with current types) + final_data = list(data) + else: + # All items are None + codec = MsgPack() + final_data = list(data) + + # Object data uses converted data and custom filters + filters = [codec] else: - object_codec = None - filters = list(template.filters) if template.filters else [] + # Non-object data uses original data and template filters + final_data = data + filters = template.filters + # Single array creation for both cases test_set_group.array( name=col, - data=data, + data=final_data, dtype=template.dtype, compressor=template.compressor, filters=filters, - chunks=chunks, - object_codec=object_codec, + chunks=template.chunks, overwrite=True, ) return Path(self.zarr_root_path) + def get_converted_predictions(self) -> dict: + """Get all predictions with automatic conversion back to original object types. 
+ + Returns: + Full predictions dictionary with same structure as self.predictions, + but with object data converted back to original types (AtomArray, RDKit Mol) + """ + converted_predictions = {} + + for test_set_label in self.test_set_labels: + if test_set_label not in self.zarr_root: + continue + + test_set_group = self.zarr_root[test_set_label] + converted_predictions[test_set_label] = {} + + for target in self.target_labels: + if target not in test_set_group: + continue + + zarr_array = test_set_group[target] + data = zarr_array[:] + + # Check if this needs conversion (object data) + template = self.dataset_zarr_root[target] + if template.dtype == object: + # Use filters to determine conversion (simple and reliable) + filters = zarr_array.filters or [] + if any(isinstance(f, MsgPack) for f in filters): + converter = convert_dict_to_atomarray + elif any(isinstance(f, VLenBytes) for f in filters): + converter = convert_bytes_to_mol + else: + converter = None + + # Apply conversion + if converter: + converted = [converter(item) if item is not None else None for item in data] + converted_predictions[test_set_label][target] = np.array(converted, dtype=object) + else: + converted_predictions[test_set_label][target] = data + else: + # Non-object data, use as-is + converted_predictions[test_set_label][target] = data + + return converted_predictions + @property def zarr_root(self) -> zarr.Group: """Get the zarr Group object corresponding to the root, creating it if it doesn't exist.""" if self._zarr_root is None: store = zarr.DirectoryStore(self.zarr_root_path) - self._zarr_root = zarr.group(store=store) + raw_root = zarr.group(store=store) + self._zarr_root = raw_root return self._zarr_root @property diff --git a/polaris/utils/zarr/codecs.py b/polaris/utils/zarr/codecs.py index b3614db5..22e3ae06 100644 --- a/polaris/utils/zarr/codecs.py +++ b/polaris/utils/zarr/codecs.py @@ -154,31 +154,74 @@ def decode(self, buf, out=None): register_codec(AtomArrayCodec) -def 
detect_object_codec_and_chunking(template_filters=None): - """ - Detect the appropriate object codec and chunking settings from template filters. - - Returns: - tuple: (object_codec, filters, chunks_compatible) - """ - from numcodecs import MsgPack - - filters = list(template_filters) if template_filters else [] - object_codec = None - chunks_compatible = True - - # Check if codec exists in filters (Zarr stores object_codec as part of filters) - for filter_codec in filters: - if hasattr(filter_codec, "supports_chunking"): # Our custom codecs - object_codec = filter_codec - chunks_compatible = getattr(filter_codec, "supports_chunking", True) - # Remove from filters since we'll pass it as object_codec - filters = [f for f in filters if f is not filter_codec] - - # Remove MsgPack filters for MsgPack-based codecs to avoid double encoding - if isinstance(filter_codec, MsgPack): - filters = [f for f in filters if not isinstance(f, MsgPack)] - - return object_codec, filters, chunks_compatible - - return object_codec, filters, chunks_compatible +def convert_atomarray_to_dict(atom_array): + """Convert AtomArray to a dict that can be stored with standard MsgPack codec.""" + if atom_array is None: + return None + + if not isinstance(atom_array, struc.AtomArray): + raise ValueError(f"Expected an AtomArray, but got {type(atom_array)} instead") + + data = { + "coord": atom_array.coord, + "chain_id": atom_array.chain_id, + "res_id": atom_array.res_id, + "ins_code": atom_array.ins_code, + "res_name": atom_array.res_name, + "hetero": atom_array.hetero, + "atom_name": atom_array.atom_name, + "element": atom_array.element, + "atom_id": atom_array.atom_id, + "b_factor": atom_array.b_factor, + "occupancy": atom_array.occupancy, + "charge": atom_array.charge, + } + return {k: v.tolist() for k, v in data.items()} + + +def convert_dict_to_atomarray(data): + """Convert dict back to AtomArray.""" + if data is None: + return None + + atom_array = [] + array_length = len(data["coord"]) + + 
for ind in range(array_length): + atom = struc.Atom( + coord=data["coord"][ind], + chain_id=data["chain_id"][ind], + res_id=data["res_id"][ind], + ins_code=data["ins_code"][ind], + res_name=data["res_name"][ind], + hetero=data["hetero"][ind], + atom_name=data["atom_name"][ind], + element=data["element"][ind], + b_factor=data["b_factor"][ind], + occupancy=data["occupancy"][ind], + charge=data["charge"][ind], + atom_id=data["atom_id"][ind], + ) + atom_array.append(atom) + + return struc.array(atom_array) + + +def convert_mol_to_bytes(mol): + """Convert RDKit Mol to bytes that can be stored with standard VLenBytes codec.""" + if mol is None or (isinstance(mol, bytes) and len(mol) == 0): + return b"" + + if not isinstance(mol, Chem.Mol): + raise ValueError(f"Expected an RDKitMol, but got {type(mol)} instead.") + + props = Chem.PropertyPickleOptions.AllProps + return mol.ToBinary(props) + + +def convert_bytes_to_mol(mol_bytes): + """Convert bytes back to RDKit Mol.""" + if len(mol_bytes) == 0: + return None + + return Chem.Mol(mol_bytes) diff --git a/tests/test_benchmark_predictions_v2.py b/tests/test_benchmark_predictions_v2.py index 83bd5330..801095ae 100644 --- a/tests/test_benchmark_predictions_v2.py +++ b/tests/test_benchmark_predictions_v2.py @@ -1,5 +1,4 @@ from polaris.prediction._predictions_v2 import BenchmarkPredictionsV2 -from polaris.utils.zarr.codecs import RDKitMolCodec, AtomArrayCodec from rdkit import Chem import numpy as np import pytest @@ -35,20 +34,26 @@ def test_v2_rdkit_object_codec(v2_benchmark_with_rdkit_object_dtype): assert bp.predictions["test"]["expt"].dtype == object assert_deep_equal(bp.predictions, {"test": {"expt": np.array(mols, dtype=object)}}) - # Check Zarr archive + # Check Zarr archive by reading through the BenchmarkPredictionsV2 object zarr_path = bp.to_zarr() assert zarr_path.exists() - root = zarr.open(str(zarr_path), mode="r") - arr = root["test"]["expt"][:] + + # Use the get_converted_predictions method to get all converted 
data + converted_predictions = bp.get_converted_predictions() + arr = converted_predictions["test"]["expt"] arr_smiles = [Chem.MolToSmiles(m) for m in arr] mols_smiles = [Chem.MolToSmiles(m) for m in mols] assert arr_smiles == mols_smiles - # Check that object_codec is correctly set as a filter (Zarr stores object_codec as filters) - zarr_array = root["test"]["expt"] + # Check that standard codec is used (direct zarr access shows raw format) + raw_root = zarr.open(str(zarr_path), mode="r") + zarr_array = raw_root["test"]["expt"] assert zarr_array.filters is not None assert len(zarr_array.filters) > 0 - assert any(isinstance(f, RDKitMolCodec) for f in zarr_array.filters) + # Now we expect VLenBytes instead of custom codec + from numcodecs import VLenBytes + + assert any(isinstance(f, VLenBytes) for f in zarr_array.filters) def test_v2_atomarray_object_codec(v2_benchmark_with_atomarray_object_dtype, pdbs_structs): @@ -66,20 +71,26 @@ def test_v2_atomarray_object_codec(v2_benchmark_with_atomarray_object_dtype, pdb assert bp.predictions["test"]["expt"].dtype == object assert_deep_equal(bp.predictions, {"test": {"expt": np.array(pdbs_structs[:2], dtype=object)}}) - # Check Zarr archive (dtype and shape only) + # Check Zarr archive by reading through the BenchmarkPredictionsV2 object zarr_path = bp.to_zarr() assert zarr_path.exists() - root = zarr.open(str(zarr_path), mode="r") - arr = root["test"]["expt"][:] + + # Use the get_converted_predictions method to get all converted data + converted_predictions = bp.get_converted_predictions() + arr = converted_predictions["test"]["expt"] assert arr.dtype == object assert arr.shape == (2,) assert all(isinstance(x, struc.AtomArray) for x in arr) - # Check that object_codec is correctly set as a filter (Zarr stores object_codec as filters) - zarr_array = root["test"]["expt"] + # Check that standard codec is used (direct zarr access shows raw format) + raw_root = zarr.open(str(zarr_path), mode="r") + zarr_array = 
raw_root["test"]["expt"] assert zarr_array.filters is not None assert len(zarr_array.filters) > 0 - assert any(isinstance(f, AtomArrayCodec) for f in zarr_array.filters) + # Now we expect MsgPack instead of custom codec + from numcodecs import MsgPack + + assert any(isinstance(f, MsgPack) for f in zarr_array.filters) def test_v2_dtype_mismatch_raises(test_benchmark_v2): From f4121e0e5169dc98d28fae9e2804316867018aa6 Mon Sep 17 00:00:00 2001 From: Jack Li Date: Wed, 6 Aug 2025 16:19:29 -0400 Subject: [PATCH 04/15] removed chunking check --- polaris/utils/zarr/codecs.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/polaris/utils/zarr/codecs.py b/polaris/utils/zarr/codecs.py index 22e3ae06..50666fc3 100644 --- a/polaris/utils/zarr/codecs.py +++ b/polaris/utils/zarr/codecs.py @@ -17,7 +17,6 @@ class RDKitMolCodec(VLenBytes): """ codec_id = "rdkit_mol" - supports_chunking = True def encode(self, buf: np.ndarray): """ @@ -72,7 +71,6 @@ class AtomArrayCodec(MsgPack): """ codec_id = "atom_array" - supports_chunking = False def encode(self, buf: np.ndarray): """ From ac3700fe0e42a5bf6e77ab42bebe9303e53494c3 Mon Sep 17 00:00:00 2001 From: Jack Li <73399568+j279li@users.noreply.github.com> Date: Thu, 7 Aug 2025 12:29:50 -0400 Subject: [PATCH 05/15] Update polaris/utils/zarr/codecs.py Co-authored-by: Cas Wognum --- polaris/utils/zarr/codecs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polaris/utils/zarr/codecs.py b/polaris/utils/zarr/codecs.py index 50666fc3..2b89f42e 100644 --- a/polaris/utils/zarr/codecs.py +++ b/polaris/utils/zarr/codecs.py @@ -152,7 +152,7 @@ def decode(self, buf, out=None): register_codec(AtomArrayCodec) -def convert_atomarray_to_dict(atom_array): +def convert_atomarray_to_dict(atom_array: AtomArray | None) -> dict[str, list] | None: """Convert AtomArray to a dict that can be stored with standard MsgPack codec.""" if atom_array is None: return None From 16927fe546ed4b8a2f0b54c3ac342bc9cfafdad1 Mon Sep 17 00:00:00 2001 From: 
Jack Li <73399568+j279li@users.noreply.github.com> Date: Thu, 7 Aug 2025 12:29:59 -0400 Subject: [PATCH 06/15] Update polaris/utils/zarr/codecs.py Co-authored-by: Cas Wognum --- polaris/utils/zarr/codecs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polaris/utils/zarr/codecs.py b/polaris/utils/zarr/codecs.py index 2b89f42e..24801f21 100644 --- a/polaris/utils/zarr/codecs.py +++ b/polaris/utils/zarr/codecs.py @@ -177,7 +177,7 @@ def convert_atomarray_to_dict(atom_array: AtomArray | None) -> dict[str, list] | return {k: v.tolist() for k, v in data.items()} -def convert_dict_to_atomarray(data): +def convert_dict_to_atomarray(data: dict) -> AtomArray: """Convert dict back to AtomArray.""" if data is None: return None From 0375efb5daccc52240ad594df99bf801779b4d6c Mon Sep 17 00:00:00 2001 From: Jack Li <73399568+j279li@users.noreply.github.com> Date: Thu, 7 Aug 2025 12:30:06 -0400 Subject: [PATCH 07/15] Update polaris/utils/zarr/codecs.py Co-authored-by: Cas Wognum --- polaris/utils/zarr/codecs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polaris/utils/zarr/codecs.py b/polaris/utils/zarr/codecs.py index 24801f21..334c8cb9 100644 --- a/polaris/utils/zarr/codecs.py +++ b/polaris/utils/zarr/codecs.py @@ -205,7 +205,7 @@ def convert_dict_to_atomarray(data: dict) -> AtomArray: return struc.array(atom_array) -def convert_mol_to_bytes(mol): +def convert_mol_to_bytes(mol: rdkit.Chem.Mol | None) -> bytes: """Convert RDKit Mol to bytes that can be stored with standard VLenBytes codec.""" if mol is None or (isinstance(mol, bytes) and len(mol) == 0): return b"" From 88caf59c1d1cf6fe154e67fe586d32401bda43d2 Mon Sep 17 00:00:00 2001 From: Jack Li <73399568+j279li@users.noreply.github.com> Date: Thu, 7 Aug 2025 12:30:27 -0400 Subject: [PATCH 08/15] Update polaris/utils/zarr/codecs.py Co-authored-by: Cas Wognum --- polaris/utils/zarr/codecs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polaris/utils/zarr/codecs.py 
b/polaris/utils/zarr/codecs.py index 334c8cb9..40b069ea 100644 --- a/polaris/utils/zarr/codecs.py +++ b/polaris/utils/zarr/codecs.py @@ -207,7 +207,7 @@ def convert_dict_to_atomarray(data: dict) -> AtomArray: def convert_mol_to_bytes(mol: rdkit.Chem.Mol | None) -> bytes: """Convert RDKit Mol to bytes that can be stored with standard VLenBytes codec.""" - if mol is None or (isinstance(mol, bytes) and len(mol) == 0): + if mol is None: return b"" if not isinstance(mol, Chem.Mol): From 94e6cc3489ae49ccfcfac33f0011a9bacce7d44e Mon Sep 17 00:00:00 2001 From: Jack Li <73399568+j279li@users.noreply.github.com> Date: Thu, 7 Aug 2025 12:30:34 -0400 Subject: [PATCH 09/15] Update polaris/utils/zarr/codecs.py Co-authored-by: Cas Wognum --- polaris/utils/zarr/codecs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polaris/utils/zarr/codecs.py b/polaris/utils/zarr/codecs.py index 40b069ea..537cf244 100644 --- a/polaris/utils/zarr/codecs.py +++ b/polaris/utils/zarr/codecs.py @@ -217,7 +217,7 @@ def convert_mol_to_bytes(mol: rdkit.Chem.Mol | None) -> bytes: return mol.ToBinary(props) -def convert_bytes_to_mol(mol_bytes): +def convert_bytes_to_mol(mol_bytes: bytes) -> rdkit.Chem.Mol: """Convert bytes back to RDKit Mol.""" if len(mol_bytes) == 0: return None From 5c49dc0e2e02f577b1f00943aaf9c01209d37b75 Mon Sep 17 00:00:00 2001 From: Jack Li Date: Thu, 7 Aug 2025 13:21:39 -0400 Subject: [PATCH 10/15] various fixes --- polaris/hub/client.py | 2 +- polaris/prediction/_predictions_v2.py | 94 +++++++++++++-------------- polaris/utils/zarr/codecs.py | 14 ++-- 3 files changed, 51 insertions(+), 59 deletions(-) diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 47602f77..3345bcc0 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -766,7 +766,7 @@ def submit_benchmark_predictions( # Set owner prediction.owner = HubOwner.normalize(owner or prediction.owner) - prediction_json = prediction.model_dump(by_alias=True, exclude_none=True, 
exclude={"predictions"}) + prediction_json = prediction.model_dump(by_alias=True, exclude_none=True) # Step 1: Upload metadata to Hub with track_progress(description="Uploading prediction metadata", total=1) as (progress, task): diff --git a/polaris/prediction/_predictions_v2.py b/polaris/prediction/_predictions_v2.py index 82354ef4..4c208a52 100644 --- a/polaris/prediction/_predictions_v2.py +++ b/polaris/prediction/_predictions_v2.py @@ -4,6 +4,7 @@ import shutil from pathlib import Path import tempfile +from enum import StrEnum import numpy as np import zarr @@ -28,6 +29,16 @@ logger = logging.getLogger(__name__) +# Reserved metadata key for storing original Python type +RESERVED_TYPE_KEY = "python_type" + + +class ReservedTypes(StrEnum): + """Reserved type identifiers for object data stored in Zarr arrays.""" + + RDKIT_MOL = "rdkit.Chem.Mol" + ATOM_ARRAY = "biotite.structure.AtomArray" + class BenchmarkPredictionsV2(BenchmarkPredictions, ResultsMetadataV2): """ @@ -43,6 +54,7 @@ class BenchmarkPredictionsV2(BenchmarkPredictions, ResultsMetadataV2): For additional metadata attributes, see the base classes. """ + predictions: dict = Field(exclude=True) # NumPy arrays cannot be JSON serialized dataset_zarr_root: zarr.Group = Field(exclude=True) # Zarr Group cannot be JSON serialized benchmark_artifact_id: str _artifact_type = "prediction" @@ -71,57 +83,45 @@ def to_zarr(self) -> Path: This method should be called explicitly when ready to write predictions to disk. 
""" - # Get raw zarr root for writing (not the converting wrapper) + # Get zarr root for writing store = zarr.DirectoryStore(self.zarr_root_path) - raw_root = zarr.group(store=store) + root = zarr.group(store=store) dataset_root = self.dataset_zarr_root for test_set_label, test_set_predictions in self.predictions.items(): # Create a group for each test set - test_set_group = raw_root.require_group(test_set_label) + test_set_group = root.require_group(test_set_label) for col in self.target_labels: data = test_set_predictions[col] template = dataset_root[col] # Handle object data conversion if template.dtype == object: - # Find first non-None item to determine conversion type - sample_item = next((item for item in data if item is not None), None) - - # Define conversion mapping - conversion_map = { - struc.AtomArray: (convert_atomarray_to_dict, MsgPack()), - Chem.Mol: (convert_mol_to_bytes, VLenBytes()), - } - - if sample_item is not None: - # Find matching conversion for sample item type - converter_func, codec = None, MsgPack() # Default fallback - for obj_type, (conv_func, conv_codec) in conversion_map.items(): - if isinstance(sample_item, obj_type): - converter_func, codec = conv_func, conv_codec - break - - # Apply conversion if we found a matching converter - if converter_func is not None: - final_data = [converter_func(item) if item is not None else None for item in data] - else: - # No converter found - store as-is (shouldn't happen with current types) - final_data = list(data) - else: - # All items are None + sample = next((item for item in data if item is not None), None) + + if isinstance(sample, Chem.Mol): + codec = VLenBytes() + final_data = [convert_mol_to_bytes(item) for item in data] + filters = [codec] + attributes = {RESERVED_TYPE_KEY: ReservedTypes.RDKIT_MOL} + elif isinstance(sample, struc.AtomArray): codec = MsgPack() + final_data = [convert_atomarray_to_dict(item) for item in data] + filters = [codec] + attributes = {RESERVED_TYPE_KEY: 
ReservedTypes.ATOM_ARRAY} + else: + # Fall back to dataset template for unknown types final_data = list(data) - - # Object data uses converted data and custom filters - filters = [codec] + filters = template.filters + attributes = {} else: # Non-object data uses original data and template filters final_data = data filters = template.filters + attributes = {} # Single array creation for both cases - test_set_group.array( + zarr_array = test_set_group.array( name=col, data=final_data, dtype=template.dtype, @@ -131,6 +131,10 @@ def to_zarr(self) -> Path: overwrite=True, ) + # Set attributes after creation if we have any + if attributes: + zarr_array.attrs.update(attributes) + return Path(self.zarr_root_path) def get_converted_predictions(self) -> dict: @@ -143,34 +147,28 @@ def get_converted_predictions(self) -> dict: converted_predictions = {} for test_set_label in self.test_set_labels: - if test_set_label not in self.zarr_root: - continue - - test_set_group = self.zarr_root[test_set_label] + test_set_group = self.zarr_root.require_group(test_set_label) converted_predictions[test_set_label] = {} for target in self.target_labels: - if target not in test_set_group: - continue - zarr_array = test_set_group[target] data = zarr_array[:] # Check if this needs conversion (object data) template = self.dataset_zarr_root[target] if template.dtype == object: - # Use filters to determine conversion (simple and reliable) - filters = zarr_array.filters or [] - if any(isinstance(f, MsgPack) for f in filters): - converter = convert_dict_to_atomarray - elif any(isinstance(f, VLenBytes) for f in filters): + # Use metadata to determine conversion type + python_type = zarr_array.attrs.get(RESERVED_TYPE_KEY) + if python_type == ReservedTypes.RDKIT_MOL: converter = convert_bytes_to_mol + elif python_type == ReservedTypes.ATOM_ARRAY: + converter = convert_dict_to_atomarray else: converter = None # Apply conversion if converter: - converted = [converter(item) if item is not None else None 
for item in data] + converted = [converter(item) for item in data] converted_predictions[test_set_label][target] = np.array(converted, dtype=object) else: converted_predictions[test_set_label][target] = data @@ -185,8 +183,8 @@ def zarr_root(self) -> zarr.Group: """Get the zarr Group object corresponding to the root, creating it if it doesn't exist.""" if self._zarr_root is None: store = zarr.DirectoryStore(self.zarr_root_path) - raw_root = zarr.group(store=store) - self._zarr_root = raw_root + root = zarr.group(store=store) + self._zarr_root = root return self._zarr_root @property @@ -241,7 +239,7 @@ def has_zarr_manifest_md5sum(self): return self._zarr_manifest_md5sum is not None def __repr__(self): - return self.model_dump_json(by_alias=True, exclude={"predictions"}, indent=2) + return self.model_dump_json(by_alias=True, indent=2) def __str__(self): return self.__repr__() diff --git a/polaris/utils/zarr/codecs.py b/polaris/utils/zarr/codecs.py index 537cf244..d2f60b3a 100644 --- a/polaris/utils/zarr/codecs.py +++ b/polaris/utils/zarr/codecs.py @@ -152,14 +152,11 @@ def decode(self, buf, out=None): register_codec(AtomArrayCodec) -def convert_atomarray_to_dict(atom_array: AtomArray | None) -> dict[str, list] | None: +def convert_atomarray_to_dict(atom_array: struc.AtomArray | None) -> dict[str, list] | None: """Convert AtomArray to a dict that can be stored with standard MsgPack codec.""" if atom_array is None: return None - if not isinstance(atom_array, struc.AtomArray): - raise ValueError(f"Expected an AtomArray, but got {type(atom_array)} instead") - data = { "coord": atom_array.coord, "chain_id": atom_array.chain_id, @@ -177,7 +174,7 @@ def convert_atomarray_to_dict(atom_array: AtomArray | None) -> dict[str, list] | return {k: v.tolist() for k, v in data.items()} -def convert_dict_to_atomarray(data: dict) -> AtomArray: +def convert_dict_to_atomarray(data: dict | None) -> struc.AtomArray | None: """Convert dict back to AtomArray.""" if data is None: return None 
@@ -205,19 +202,16 @@ def convert_dict_to_atomarray(data: dict) -> AtomArray: return struc.array(atom_array) -def convert_mol_to_bytes(mol: rdkit.Chem.Mol | None) -> bytes: +def convert_mol_to_bytes(mol: Chem.Mol | None) -> bytes: """Convert RDKit Mol to bytes that can be stored with standard VLenBytes codec.""" if mol is None: return b"" - if not isinstance(mol, Chem.Mol): - raise ValueError(f"Expected an RDKitMol, but got {type(mol)} instead.") - props = Chem.PropertyPickleOptions.AllProps return mol.ToBinary(props) -def convert_bytes_to_mol(mol_bytes: bytes) -> rdkit.Chem.Mol: +def convert_bytes_to_mol(mol_bytes: bytes) -> Chem.Mol | None: """Convert bytes back to RDKit Mol.""" if len(mol_bytes) == 0: return None From 829f1305f3caaade01a2b42bfa3ba905d7b398fe Mon Sep 17 00:00:00 2001 From: Jack Li Date: Thu, 7 Aug 2025 13:27:28 -0400 Subject: [PATCH 11/15] use enum instead of strEnum --- polaris/prediction/_predictions_v2.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/polaris/prediction/_predictions_v2.py b/polaris/prediction/_predictions_v2.py index 4c208a52..4b2b51ea 100644 --- a/polaris/prediction/_predictions_v2.py +++ b/polaris/prediction/_predictions_v2.py @@ -4,7 +4,7 @@ import shutil from pathlib import Path import tempfile -from enum import StrEnum +from enum import Enum import numpy as np import zarr @@ -32,10 +32,8 @@ # Reserved metadata key for storing original Python type RESERVED_TYPE_KEY = "python_type" - -class ReservedTypes(StrEnum): +class ReservedTypes(str, Enum): """Reserved type identifiers for object data stored in Zarr arrays.""" - RDKIT_MOL = "rdkit.Chem.Mol" ATOM_ARRAY = "biotite.structure.AtomArray" From fce35ee92009ff4a88a090b6cdde11041812500f Mon Sep 17 00:00:00 2001 From: Jack Li Date: Thu, 7 Aug 2025 13:27:42 -0400 Subject: [PATCH 12/15] format --- polaris/prediction/_predictions_v2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/polaris/prediction/_predictions_v2.py 
b/polaris/prediction/_predictions_v2.py index 4b2b51ea..95df1731 100644 --- a/polaris/prediction/_predictions_v2.py +++ b/polaris/prediction/_predictions_v2.py @@ -32,8 +32,10 @@ # Reserved metadata key for storing original Python type RESERVED_TYPE_KEY = "python_type" + class ReservedTypes(str, Enum): """Reserved type identifiers for object data stored in Zarr arrays.""" + RDKIT_MOL = "rdkit.Chem.Mol" ATOM_ARRAY = "biotite.structure.AtomArray" From 1c079ed6c51337574dcb0338db546116d7e9db04 Mon Sep 17 00:00:00 2001 From: Jack Li Date: Thu, 14 Aug 2025 17:23:27 -0400 Subject: [PATCH 13/15] remove unnecessary get prediction method and setting metadata from upload --- polaris/prediction/_predictions_v2.py | 64 +------------------------- tests/test_benchmark_predictions_v2.py | 19 +++++--- 2 files changed, 14 insertions(+), 69 deletions(-) diff --git a/polaris/prediction/_predictions_v2.py b/polaris/prediction/_predictions_v2.py index 95df1731..46ef7b27 100644 --- a/polaris/prediction/_predictions_v2.py +++ b/polaris/prediction/_predictions_v2.py @@ -4,7 +4,6 @@ import shutil from pathlib import Path import tempfile -from enum import Enum import numpy as np import zarr @@ -19,8 +18,6 @@ from polaris.utils.zarr._manifest import generate_zarr_manifest, calculate_file_md5 from polaris.utils.zarr.codecs import ( - convert_dict_to_atomarray, - convert_bytes_to_mol, convert_atomarray_to_dict, convert_mol_to_bytes, ) @@ -29,16 +26,6 @@ logger = logging.getLogger(__name__) -# Reserved metadata key for storing original Python type -RESERVED_TYPE_KEY = "python_type" - - -class ReservedTypes(str, Enum): - """Reserved type identifiers for object data stored in Zarr arrays.""" - - RDKIT_MOL = "rdkit.Chem.Mol" - ATOM_ARRAY = "biotite.structure.AtomArray" - class BenchmarkPredictionsV2(BenchmarkPredictions, ResultsMetadataV2): """ @@ -103,25 +90,21 @@ def to_zarr(self) -> Path: codec = VLenBytes() final_data = [convert_mol_to_bytes(item) for item in data] filters = [codec] - attributes 
= {RESERVED_TYPE_KEY: ReservedTypes.RDKIT_MOL} elif isinstance(sample, struc.AtomArray): codec = MsgPack() final_data = [convert_atomarray_to_dict(item) for item in data] filters = [codec] - attributes = {RESERVED_TYPE_KEY: ReservedTypes.ATOM_ARRAY} else: # Fall back to dataset template for unknown types final_data = list(data) filters = template.filters - attributes = {} else: # Non-object data uses original data and template filters final_data = data filters = template.filters - attributes = {} # Single array creation for both cases - zarr_array = test_set_group.array( + test_set_group.array( name=col, data=final_data, dtype=template.dtype, @@ -131,53 +114,8 @@ def to_zarr(self) -> Path: overwrite=True, ) - # Set attributes after creation if we have any - if attributes: - zarr_array.attrs.update(attributes) - return Path(self.zarr_root_path) - def get_converted_predictions(self) -> dict: - """Get all predictions with automatic conversion back to original object types. - - Returns: - Full predictions dictionary with same structure as self.predictions, - but with object data converted back to original types (AtomArray, RDKit Mol) - """ - converted_predictions = {} - - for test_set_label in self.test_set_labels: - test_set_group = self.zarr_root.require_group(test_set_label) - converted_predictions[test_set_label] = {} - - for target in self.target_labels: - zarr_array = test_set_group[target] - data = zarr_array[:] - - # Check if this needs conversion (object data) - template = self.dataset_zarr_root[target] - if template.dtype == object: - # Use metadata to determine conversion type - python_type = zarr_array.attrs.get(RESERVED_TYPE_KEY) - if python_type == ReservedTypes.RDKIT_MOL: - converter = convert_bytes_to_mol - elif python_type == ReservedTypes.ATOM_ARRAY: - converter = convert_dict_to_atomarray - else: - converter = None - - # Apply conversion - if converter: - converted = [converter(item) for item in data] - converted_predictions[test_set_label][target] = 
np.array(converted, dtype=object) - else: - converted_predictions[test_set_label][target] = data - else: - # Non-object data, use as-is - converted_predictions[test_set_label][target] = data - - return converted_predictions - @property def zarr_root(self) -> zarr.Group: """Get the zarr Group object corresponding to the root, creating it if it doesn't exist.""" diff --git a/tests/test_benchmark_predictions_v2.py b/tests/test_benchmark_predictions_v2.py index 801095ae..07cad65c 100644 --- a/tests/test_benchmark_predictions_v2.py +++ b/tests/test_benchmark_predictions_v2.py @@ -1,4 +1,6 @@ from polaris.prediction._predictions_v2 import BenchmarkPredictionsV2 +from polaris.utils.zarr.codecs import convert_bytes_to_mol, convert_dict_to_atomarray + from rdkit import Chem import numpy as np import pytest @@ -38,9 +40,12 @@ def test_v2_rdkit_object_codec(v2_benchmark_with_rdkit_object_dtype): zarr_path = bp.to_zarr() assert zarr_path.exists() - # Use the get_converted_predictions method to get all converted data - converted_predictions = bp.get_converted_predictions() - arr = converted_predictions["test"]["expt"] + # Read raw zarr content and decode via standard codec + raw_root = zarr.open(str(zarr_path), mode="r") + zarr_array = raw_root["test"]["expt"] + data = zarr_array[:] + # Decode using our helper to bytes->Mol + arr = [convert_bytes_to_mol(x) for x in data] arr_smiles = [Chem.MolToSmiles(m) for m in arr] mols_smiles = [Chem.MolToSmiles(m) for m in mols] assert arr_smiles == mols_smiles @@ -75,9 +80,11 @@ def test_v2_atomarray_object_codec(v2_benchmark_with_atomarray_object_dtype, pdb zarr_path = bp.to_zarr() assert zarr_path.exists() - # Use the get_converted_predictions method to get all converted data - converted_predictions = bp.get_converted_predictions() - arr = converted_predictions["test"]["expt"] + # Read raw zarr content and decode via standard codec + raw_root = zarr.open(str(zarr_path), mode="r") + zarr_array = raw_root["test"]["expt"] + data = 
zarr_array[:] + arr = np.array([convert_dict_to_atomarray(x) for x in data], dtype=object) assert arr.dtype == object assert arr.shape == (2,) assert all(isinstance(x, struc.AtomArray) for x in arr) From de3b8ae9a211dcf32fa3ea7330d2c9f5c8aee5b6 Mon Sep 17 00:00:00 2001 From: Jack Li Date: Fri, 15 Aug 2025 11:16:18 -0400 Subject: [PATCH 14/15] updated to_zarr for future zarr v3 compatibility --- polaris/prediction/_predictions_v2.py | 48 +++++++++++++++------------ 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/polaris/prediction/_predictions_v2.py b/polaris/prediction/_predictions_v2.py index 46ef7b27..2b4ce664 100644 --- a/polaris/prediction/_predictions_v2.py +++ b/polaris/prediction/_predictions_v2.py @@ -85,34 +85,40 @@ def to_zarr(self) -> Path: # Handle object data conversion if template.dtype == object: sample = next((item for item in data if item is not None), None) - + + # Define object type handlers if isinstance(sample, Chem.Mol): - codec = VLenBytes() - final_data = [convert_mol_to_bytes(item) for item in data] - filters = [codec] + object_codec, final_data, filters = VLenBytes(), [convert_mol_to_bytes(item) for item in data], None elif isinstance(sample, struc.AtomArray): - codec = MsgPack() - final_data = [convert_atomarray_to_dict(item) for item in data] - filters = [codec] + object_codec, final_data, filters = MsgPack(), [convert_atomarray_to_dict(item) for item in data], None else: - # Fall back to dataset template for unknown types - final_data = list(data) - filters = template.filters + object_codec, final_data, filters = None, list(data), template.filters + + # Create array with object_codec for object types (Zarr v3 compatibility) + test_set_group.array( + name=col, + data=final_data, + dtype=template.dtype, + compressor=template.compressor, + filters=filters, + object_codec=object_codec, + chunks=template.chunks, + overwrite=True, + ) else: # Non-object data uses original data and template filters final_data = data filters = 
template.filters - - # Single array creation for both cases - test_set_group.array( - name=col, - data=final_data, - dtype=template.dtype, - compressor=template.compressor, - filters=filters, - chunks=template.chunks, - overwrite=True, - ) + + test_set_group.array( + name=col, + data=final_data, + dtype=template.dtype, + compressor=template.compressor, + filters=filters, + chunks=template.chunks, + overwrite=True, + ) return Path(self.zarr_root_path) From 009d66eb9d99ed90d66969e68051c79c73bd2138 Mon Sep 17 00:00:00 2001 From: Jack Li Date: Fri, 15 Aug 2025 11:27:42 -0400 Subject: [PATCH 15/15] format --- polaris/prediction/_predictions_v2.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/polaris/prediction/_predictions_v2.py b/polaris/prediction/_predictions_v2.py index 2b4ce664..d1aa06e0 100644 --- a/polaris/prediction/_predictions_v2.py +++ b/polaris/prediction/_predictions_v2.py @@ -85,12 +85,20 @@ def to_zarr(self) -> Path: # Handle object data conversion if template.dtype == object: sample = next((item for item in data if item is not None), None) - + # Define object type handlers if isinstance(sample, Chem.Mol): - object_codec, final_data, filters = VLenBytes(), [convert_mol_to_bytes(item) for item in data], None + object_codec, final_data, filters = ( + VLenBytes(), + [convert_mol_to_bytes(item) for item in data], + None, + ) elif isinstance(sample, struc.AtomArray): - object_codec, final_data, filters = MsgPack(), [convert_atomarray_to_dict(item) for item in data], None + object_codec, final_data, filters = ( + MsgPack(), + [convert_atomarray_to_dict(item) for item in data], + None, + ) else: object_codec, final_data, filters = None, list(data), template.filters @@ -109,7 +117,7 @@ def to_zarr(self) -> Path: # Non-object data uses original data and template filters final_data = data filters = template.filters - + test_set_group.array( name=col, data=final_data,