diff --git a/polaris/benchmark/_benchmark_v2.py b/polaris/benchmark/_benchmark_v2.py index 5a568153..3a4af28b 100644 --- a/polaris/benchmark/_benchmark_v2.py +++ b/polaris/benchmark/_benchmark_v2.py @@ -30,7 +30,6 @@ class BenchmarkV2Specification( Attributes: dataset: The dataset the benchmark specification is based on. - splits: The predefined train-test splits to use for evaluation. n_classes: The number of classes for each of the target columns. readme: Markdown text that can be used to provide a formatted description of the benchmark. artifact_version: The version of the benchmark. diff --git a/polaris/prediction/_predictions_v2.py b/polaris/prediction/_predictions_v2.py index 2bb0a250..d1aa06e0 100644 --- a/polaris/prediction/_predictions_v2.py +++ b/polaris/prediction/_predictions_v2.py @@ -9,10 +9,18 @@ import zarr from pydantic import ( PrivateAttr, + Field, model_validator, ) +from numcodecs import MsgPack, VLenBytes +from fastpdb import struc +from rdkit import Chem from polaris.utils.zarr._manifest import generate_zarr_manifest, calculate_file_md5 +from polaris.utils.zarr.codecs import ( + convert_atomarray_to_dict, + convert_mol_to_bytes, +) from polaris.evaluate import ResultsMetadataV2 from polaris.evaluate._predictions import BenchmarkPredictions @@ -33,7 +41,8 @@ class BenchmarkPredictionsV2(BenchmarkPredictions, ResultsMetadataV2): For additional metadata attributes, see the base classes. """ - dataset_zarr_root: zarr.Group + predictions: dict = Field(exclude=True) # NumPy arrays cannot be JSON serialized + dataset_zarr_root: zarr.Group = Field(exclude=True) # Zarr Group cannot be JSON serialized benchmark_artifact_id: str _artifact_type = "prediction" _zarr_root_path: str | None = PrivateAttr(None) @@ -61,7 +70,9 @@ def to_zarr(self) -> Path: This method should be called explicitly when ready to write predictions to disk. """ - root = self.zarr_root + # Get zarr root for writing + store = zarr.DirectoryStore(self.zarr_root_path) + root = zarr.group(store=store) dataset_root = self.dataset_zarr_root for test_set_label, test_set_predictions in self.predictions.items(): @@ -70,15 +81,52 @@ def to_zarr(self) -> Path: for col in self.target_labels: data = test_set_predictions[col] template = dataset_root[col] - test_set_group.array( - name=col, - data=data, - dtype=template.dtype, - compressor=template.compressor, - filters=template.filters, - chunks=template.chunks, - overwrite=True, - ) + + # Handle object data conversion + if template.dtype == object: + sample = next((item for item in data if item is not None), None) + + # Define object type handlers + if isinstance(sample, Chem.Mol): + object_codec, final_data, filters = ( + VLenBytes(), + [convert_mol_to_bytes(item) for item in data], + None, + ) + elif isinstance(sample, struc.AtomArray): + object_codec, final_data, filters = ( + MsgPack(), + [convert_atomarray_to_dict(item) for item in data], + None, + ) + else: + object_codec, final_data, filters = None, list(data), template.filters + + # Create array with object_codec for object types (Zarr v3 compatibility) + test_set_group.array( + name=col, + data=final_data, + dtype=template.dtype, + compressor=template.compressor, + filters=filters, + object_codec=object_codec, + chunks=template.chunks, + overwrite=True, + ) + else: + # Non-object data uses original data and template filters + final_data = data + filters = template.filters + + test_set_group.array( + name=col, + data=final_data, + dtype=template.dtype, + compressor=template.compressor, + filters=filters, + chunks=template.chunks, + overwrite=True, + ) return Path(self.zarr_root_path) @@ -87,7 +135,8 @@ def zarr_root(self) -> zarr.Group: """Get the zarr Group object corresponding to the root, creating it if it doesn't exist.""" if self._zarr_root is None: store = zarr.DirectoryStore(self.zarr_root_path) - self._zarr_root = zarr.group(store=store) + root = zarr.group(store=store) + self._zarr_root = root return self._zarr_root @property diff --git a/polaris/utils/zarr/codecs.py b/polaris/utils/zarr/codecs.py index 0459b684..d2f60b3a 100644 --- a/polaris/utils/zarr/codecs.py +++ b/polaris/utils/zarr/codecs.py @@ -82,6 +82,7 @@ def encode(self, buf: np.ndarray): for idx, atom_array in enumerate(buf): # A chunk can have missing values if atom_array is None: + to_pack[idx] = None # Explicitly set None values continue if not isinstance(atom_array, struc.AtomArray): @@ -149,3 +150,70 @@ def decode(self, buf, out=None): register_codec(RDKitMolCodec) register_codec(AtomArrayCodec) + + +def convert_atomarray_to_dict(atom_array: struc.AtomArray | None) -> dict[str, list] | None: + """Convert AtomArray to a dict that can be stored with standard MsgPack codec.""" + if atom_array is None: + return None + + data = { + "coord": atom_array.coord, + "chain_id": atom_array.chain_id, + "res_id": atom_array.res_id, + "ins_code": atom_array.ins_code, + "res_name": atom_array.res_name, + "hetero": atom_array.hetero, + "atom_name": atom_array.atom_name, + "element": atom_array.element, + "atom_id": atom_array.atom_id, + "b_factor": atom_array.b_factor, + "occupancy": atom_array.occupancy, + "charge": atom_array.charge, + } + return {k: v.tolist() for k, v in data.items()} + + +def convert_dict_to_atomarray(data: dict | None) -> struc.AtomArray | None: + """Convert dict back to AtomArray.""" + if data is None: + return None + + atom_array = [] + array_length = len(data["coord"]) + + for ind in range(array_length): + atom = struc.Atom( + coord=data["coord"][ind], + chain_id=data["chain_id"][ind], + res_id=data["res_id"][ind], + ins_code=data["ins_code"][ind], + res_name=data["res_name"][ind], + hetero=data["hetero"][ind], + atom_name=data["atom_name"][ind], + element=data["element"][ind], + b_factor=data["b_factor"][ind], + occupancy=data["occupancy"][ind], + charge=data["charge"][ind], + atom_id=data["atom_id"][ind], + ) + atom_array.append(atom) + + return struc.array(atom_array) + + +def convert_mol_to_bytes(mol: Chem.Mol | None) -> bytes: + """Convert RDKit Mol to bytes that can be stored with standard VLenBytes codec.""" + if mol is None: + return b"" + + props = Chem.PropertyPickleOptions.AllProps + return mol.ToBinary(props) + + +def convert_bytes_to_mol(mol_bytes: bytes) -> Chem.Mol | None: + """Convert bytes back to RDKit Mol.""" + if len(mol_bytes) == 0: + return None + + return Chem.Mol(mol_bytes) diff --git a/tests/test_benchmark_predictions_v2.py b/tests/test_benchmark_predictions_v2.py index 83bd5330..07cad65c 100644 --- a/tests/test_benchmark_predictions_v2.py +++ b/tests/test_benchmark_predictions_v2.py @@ -1,5 +1,6 @@ from polaris.prediction._predictions_v2 import BenchmarkPredictionsV2 -from polaris.utils.zarr.codecs import RDKitMolCodec, AtomArrayCodec +from polaris.utils.zarr.codecs import convert_bytes_to_mol, convert_dict_to_atomarray + from rdkit import Chem import numpy as np import pytest @@ -35,20 +36,29 @@ def test_v2_rdkit_object_codec(v2_benchmark_with_rdkit_object_dtype): assert bp.predictions["test"]["expt"].dtype == object assert_deep_equal(bp.predictions, {"test": {"expt": np.array(mols, dtype=object)}}) - # Check Zarr archive + # Check Zarr archive by reading through the BenchmarkPredictionsV2 object zarr_path = bp.to_zarr() assert zarr_path.exists() - root = zarr.open(str(zarr_path), mode="r") - arr = root["test"]["expt"][:] + + # Read raw zarr content and decode via standard codec + raw_root = zarr.open(str(zarr_path), mode="r") + zarr_array = raw_root["test"]["expt"] + data = zarr_array[:] + # Decode using our helper to bytes->Mol + arr = [convert_bytes_to_mol(x) for x in data] arr_smiles = [Chem.MolToSmiles(m) for m in arr] mols_smiles = [Chem.MolToSmiles(m) for m in mols] assert arr_smiles == mols_smiles - # Check that object_codec is correctly set as a filter (Zarr stores object_codec as filters) - zarr_array = root["test"]["expt"] + # Check that standard codec is used (direct zarr access shows raw format) + raw_root = zarr.open(str(zarr_path), mode="r") + zarr_array = raw_root["test"]["expt"] assert zarr_array.filters is not None assert len(zarr_array.filters) > 0 - assert any(isinstance(f, RDKitMolCodec) for f in zarr_array.filters) + # Now we expect VLenBytes instead of custom codec + from numcodecs import VLenBytes + + assert any(isinstance(f, VLenBytes) for f in zarr_array.filters) def test_v2_atomarray_object_codec(v2_benchmark_with_atomarray_object_dtype, pdbs_structs): @@ -66,20 +76,28 @@ def test_v2_atomarray_object_codec(v2_benchmark_with_atomarray_object_dtype, pdb assert bp.predictions["test"]["expt"].dtype == object assert_deep_equal(bp.predictions, {"test": {"expt": np.array(pdbs_structs[:2], dtype=object)}}) - # Check Zarr archive (dtype and shape only) + # Check Zarr archive by reading through the BenchmarkPredictionsV2 object zarr_path = bp.to_zarr() assert zarr_path.exists() - root = zarr.open(str(zarr_path), mode="r") - arr = root["test"]["expt"][:] + + # Read raw zarr content and decode via standard codec + raw_root = zarr.open(str(zarr_path), mode="r") + zarr_array = raw_root["test"]["expt"] + data = zarr_array[:] + arr = np.array([convert_dict_to_atomarray(x) for x in data], dtype=object) assert arr.dtype == object assert arr.shape == (2,) assert all(isinstance(x, struc.AtomArray) for x in arr) - # Check that object_codec is correctly set as a filter (Zarr stores object_codec as filters) - zarr_array = root["test"]["expt"] + # Check that standard codec is used (direct zarr access shows raw format) + raw_root = zarr.open(str(zarr_path), mode="r") + zarr_array = raw_root["test"]["expt"] assert zarr_array.filters is not None assert len(zarr_array.filters) > 0 - assert any(isinstance(f, AtomArrayCodec) for f in zarr_array.filters) + # Now we expect MsgPack instead of custom codec + from numcodecs import MsgPack + + assert any(isinstance(f, MsgPack) for f in zarr_array.filters) def test_v2_dtype_mismatch_raises(test_benchmark_v2):