Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion polaris/benchmark/_benchmark_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class BenchmarkV2Specification(

Attributes:
dataset: The dataset the benchmark specification is based on.
splits: The predefined train-test splits to use for evaluation.
n_classes: The number of classes for each of the target columns.
readme: Markdown text that can be used to provide a formatted description of the benchmark.
artifact_version: The version of the benchmark.
Expand Down
73 changes: 61 additions & 12 deletions polaris/prediction/_predictions_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,18 @@
import zarr
from pydantic import (
PrivateAttr,
Field,
model_validator,
)
from numcodecs import MsgPack, VLenBytes
from fastpdb import struc
from rdkit import Chem

from polaris.utils.zarr._manifest import generate_zarr_manifest, calculate_file_md5
from polaris.utils.zarr.codecs import (
convert_atomarray_to_dict,
convert_mol_to_bytes,
)
from polaris.evaluate import ResultsMetadataV2
from polaris.evaluate._predictions import BenchmarkPredictions

Expand All @@ -33,7 +41,8 @@ class BenchmarkPredictionsV2(BenchmarkPredictions, ResultsMetadataV2):
For additional metadata attributes, see the base classes.
"""

dataset_zarr_root: zarr.Group
predictions: dict = Field(exclude=True) # NumPy arrays cannot be JSON serialized
dataset_zarr_root: zarr.Group = Field(exclude=True) # Zarr Group cannot be JSON serialized
benchmark_artifact_id: str
_artifact_type = "prediction"
_zarr_root_path: str | None = PrivateAttr(None)
Expand Down Expand Up @@ -61,7 +70,9 @@ def to_zarr(self) -> Path:

This method should be called explicitly when ready to write predictions to disk.
"""
root = self.zarr_root
# Get zarr root for writing
store = zarr.DirectoryStore(self.zarr_root_path)
root = zarr.group(store=store)
dataset_root = self.dataset_zarr_root

for test_set_label, test_set_predictions in self.predictions.items():
Expand All @@ -70,15 +81,52 @@ def to_zarr(self) -> Path:
for col in self.target_labels:
data = test_set_predictions[col]
template = dataset_root[col]
test_set_group.array(
name=col,
data=data,
dtype=template.dtype,
compressor=template.compressor,
filters=template.filters,
chunks=template.chunks,
overwrite=True,
)

# Handle object data conversion
if template.dtype == object:
sample = next((item for item in data if item is not None), None)

# Define object type handlers
if isinstance(sample, Chem.Mol):
object_codec, final_data, filters = (
VLenBytes(),
[convert_mol_to_bytes(item) for item in data],
None,
)
elif isinstance(sample, struc.AtomArray):
object_codec, final_data, filters = (
MsgPack(),
[convert_atomarray_to_dict(item) for item in data],
None,
)
else:
object_codec, final_data, filters = None, list(data), template.filters

# Create array with object_codec for object types (Zarr v3 compatibility)
test_set_group.array(
name=col,
data=final_data,
dtype=template.dtype,
compressor=template.compressor,
filters=filters,
object_codec=object_codec,
chunks=template.chunks,
overwrite=True,
)
else:
# Non-object data uses original data and template filters
final_data = data
filters = template.filters

test_set_group.array(
name=col,
data=final_data,
dtype=template.dtype,
compressor=template.compressor,
filters=filters,
chunks=template.chunks,
overwrite=True,
)

return Path(self.zarr_root_path)

Expand All @@ -87,7 +135,8 @@ def zarr_root(self) -> zarr.Group:
"""Get the zarr Group object corresponding to the root, creating it if it doesn't exist."""
if self._zarr_root is None:
store = zarr.DirectoryStore(self.zarr_root_path)
self._zarr_root = zarr.group(store=store)
root = zarr.group(store=store)
self._zarr_root = root
return self._zarr_root

@property
Expand Down
68 changes: 68 additions & 0 deletions polaris/utils/zarr/codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def encode(self, buf: np.ndarray):
for idx, atom_array in enumerate(buf):
# A chunk can have missing values
if atom_array is None:
to_pack[idx] = None # Explicitly set None values
continue

if not isinstance(atom_array, struc.AtomArray):
Expand Down Expand Up @@ -149,3 +150,70 @@ def decode(self, buf, out=None):

register_codec(RDKitMolCodec)
register_codec(AtomArrayCodec)


def convert_atomarray_to_dict(atom_array: struc.AtomArray | None) -> dict[str, list] | None:
    """Flatten an AtomArray into a dict of plain lists for the standard MsgPack codec.

    Args:
        atom_array: The structure to convert, or ``None`` for a missing value.

    Returns:
        ``None`` when ``atom_array`` is ``None``; otherwise a dict mapping each
        exported attribute name to the ``.tolist()`` of that attribute.
    """
    if atom_array is None:
        return None

    # Attributes exported to the dict, in the order the keys are inserted.
    exported_fields = (
        "coord",
        "chain_id",
        "res_id",
        "ins_code",
        "res_name",
        "hetero",
        "atom_name",
        "element",
        "atom_id",
        "b_factor",
        "occupancy",
        "charge",
    )
    return {name: getattr(atom_array, name).tolist() for name in exported_fields}


def convert_dict_to_atomarray(data: dict | None) -> struc.AtomArray | None:
    """Rebuild an AtomArray from its dict-of-lists representation.

    Args:
        data: Dict mapping attribute names to per-atom lists, or ``None`` for
            a missing value.

    Returns:
        ``None`` when ``data`` is ``None``; otherwise the result of
        ``struc.array`` applied to one ``struc.Atom`` per entry of
        ``data["coord"]``.
    """
    if data is None:
        return None

    # Keyword arguments handed to ``struc.Atom``, in the order they are looked up.
    atom_fields = (
        "coord",
        "chain_id",
        "res_id",
        "ins_code",
        "res_name",
        "hetero",
        "atom_name",
        "element",
        "b_factor",
        "occupancy",
        "charge",
        "atom_id",
    )

    # The length of ``coord`` decides how many atoms are rebuilt.
    atom_count = len(data["coord"])
    atoms = [
        struc.Atom(**{name: data[name][idx] for name in atom_fields})
        for idx in range(atom_count)
    ]
    return struc.array(atoms)


def convert_mol_to_bytes(mol: Chem.Mol | None) -> bytes:
    """Serialize an RDKit Mol to bytes for the standard VLenBytes codec.

    Args:
        mol: The molecule to serialize, or ``None`` for a missing value.

    Returns:
        An empty bytes object when ``mol`` is ``None``; otherwise the output of
        ``mol.ToBinary`` called with ``PropertyPickleOptions.AllProps``.
    """
    if mol is not None:
        pickle_options = Chem.PropertyPickleOptions.AllProps
        return mol.ToBinary(pickle_options)

    # Missing values are written as an empty payload.
    return b""


def convert_bytes_to_mol(mol_bytes: bytes | None) -> Chem.Mol | None:
    """Convert bytes back to an RDKit Mol.

    Args:
        mol_bytes: A binary molecule payload. An empty payload or ``None``
            denotes a missing value.

    Returns:
        ``None`` when ``mol_bytes`` is empty or ``None``; otherwise
        ``Chem.Mol(mol_bytes)``.
    """
    # A truthiness check covers both the empty payload and ``None`` (or any
    # other falsy placeholder), so a missing entry no longer raises
    # ``TypeError`` from ``len()``.
    if not mol_bytes:
        return None

    return Chem.Mol(mol_bytes)
44 changes: 31 additions & 13 deletions tests/test_benchmark_predictions_v2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from polaris.prediction._predictions_v2 import BenchmarkPredictionsV2
from polaris.utils.zarr.codecs import RDKitMolCodec, AtomArrayCodec
from polaris.utils.zarr.codecs import convert_bytes_to_mol, convert_dict_to_atomarray

from rdkit import Chem
import numpy as np
import pytest
Expand Down Expand Up @@ -35,20 +36,29 @@ def test_v2_rdkit_object_codec(v2_benchmark_with_rdkit_object_dtype):
assert bp.predictions["test"]["expt"].dtype == object
assert_deep_equal(bp.predictions, {"test": {"expt": np.array(mols, dtype=object)}})

# Check Zarr archive
# Check Zarr archive by reading through the BenchmarkPredictionsV2 object
zarr_path = bp.to_zarr()
assert zarr_path.exists()
root = zarr.open(str(zarr_path), mode="r")
arr = root["test"]["expt"][:]

# Read raw zarr content and decode via standard codec
raw_root = zarr.open(str(zarr_path), mode="r")
zarr_array = raw_root["test"]["expt"]
data = zarr_array[:]
# Decode using our helper to bytes->Mol
arr = [convert_bytes_to_mol(x) for x in data]
arr_smiles = [Chem.MolToSmiles(m) for m in arr]
mols_smiles = [Chem.MolToSmiles(m) for m in mols]
assert arr_smiles == mols_smiles

# Check that object_codec is correctly set as a filter (Zarr stores object_codec as filters)
zarr_array = root["test"]["expt"]
# Check that standard codec is used (direct zarr access shows raw format)
raw_root = zarr.open(str(zarr_path), mode="r")
zarr_array = raw_root["test"]["expt"]
assert zarr_array.filters is not None
assert len(zarr_array.filters) > 0
assert any(isinstance(f, RDKitMolCodec) for f in zarr_array.filters)
# Now we expect VLenBytes instead of custom codec
from numcodecs import VLenBytes

assert any(isinstance(f, VLenBytes) for f in zarr_array.filters)


def test_v2_atomarray_object_codec(v2_benchmark_with_atomarray_object_dtype, pdbs_structs):
Expand All @@ -66,20 +76,28 @@ def test_v2_atomarray_object_codec(v2_benchmark_with_atomarray_object_dtype, pdb
assert bp.predictions["test"]["expt"].dtype == object
assert_deep_equal(bp.predictions, {"test": {"expt": np.array(pdbs_structs[:2], dtype=object)}})

# Check Zarr archive (dtype and shape only)
# Check Zarr archive by reading through the BenchmarkPredictionsV2 object
zarr_path = bp.to_zarr()
assert zarr_path.exists()
root = zarr.open(str(zarr_path), mode="r")
arr = root["test"]["expt"][:]

# Read raw zarr content and decode via standard codec
raw_root = zarr.open(str(zarr_path), mode="r")
zarr_array = raw_root["test"]["expt"]
data = zarr_array[:]
arr = np.array([convert_dict_to_atomarray(x) for x in data], dtype=object)
assert arr.dtype == object
assert arr.shape == (2,)
assert all(isinstance(x, struc.AtomArray) for x in arr)

# Check that object_codec is correctly set as a filter (Zarr stores object_codec as filters)
zarr_array = root["test"]["expt"]
# Check that standard codec is used (direct zarr access shows raw format)
raw_root = zarr.open(str(zarr_path), mode="r")
zarr_array = raw_root["test"]["expt"]
assert zarr_array.filters is not None
assert len(zarr_array.filters) > 0
assert any(isinstance(f, AtomArrayCodec) for f in zarr_array.filters)
# Now we expect MsgPack instead of custom codec
from numcodecs import MsgPack

assert any(isinstance(f, MsgPack) for f in zarr_array.filters)


def test_v2_dtype_mismatch_raises(test_benchmark_v2):
Expand Down
Loading