Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyhealth/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,11 @@ def global_event_df(self) -> pl.LazyFrame:

if self._global_event_df is None:
ret_path = self.cache_dir / "global_event_df.parquet"
if not ret_path.exists():
cache_valid = ret_path.is_dir() and any(ret_path.glob("*.parquet"))
if not cache_valid:
if ret_path.exists():
logger.warning(f"Incomplete parquet cache at {ret_path} (directory exists but contains no parquet files). Removing and rebuilding.")
shutil.rmtree(ret_path)
logger.info(f"No cached event dataframe found. Creating: {ret_path}")
self._event_transform(ret_path)
else:
Expand Down
1 change: 1 addition & 0 deletions pyhealth/datasets/configs/mimic3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ tables:
timestamp:
- "charttime"
attributes:
- "row_id"
- "hadm_id"
- "text"
- "category"
Expand Down
37 changes: 37 additions & 0 deletions tests/core/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ def setUp(self):
self.dataset = MockDataset(cache_dir=self.temp_dir.name)
self.task = MockTask()

def tearDown(self):
    """Release the per-test temporary directory.

    Calls ``cleanup()`` on ``self.temp_dir`` so whatever the test left under
    it does not outlive the test.
    """
    scratch_dir = self.temp_dir
    scratch_dir.cleanup()

def test_set_task_signature(self):
"""Test that set_task has the correct method signature."""
import inspect
Expand Down Expand Up @@ -296,5 +299,39 @@ def test_datasets_with_diff_dev_values_get_diff_caches(self):

self.assertNotEqual(dataset1.cache_dir, dataset2.cache_dir)

def test_incomplete_parquet_cache_triggers_rebuild(self):
    """Check that a parquet-less cache directory is discarded and rebuilt.

    The ``global_event_df.parquet`` directory is seeded with a single
    non-parquet file. Accessing ``global_event_df`` is then expected to log
    an "Incomplete parquet cache" warning, drop the seeded file, and leave
    the directory populated with ``*.parquet`` files.
    """
    cache_path = self.dataset.cache_dir / "global_event_df.parquet"
    leftover = cache_path / "incomplete.tmp"

    # Mimic an interrupted write: the directory is present, but the only
    # thing inside it is a non-parquet leftover.
    cache_path.mkdir(parents=True, exist_ok=True)
    leftover.touch()

    # Confirm the precondition before exercising the property.
    self.assertTrue(cache_path.is_dir())
    self.assertFalse(any(cache_path.glob("*.parquet")))

    # Reading the property is what should trigger detection and rebuild;
    # the returned value itself is not inspected.
    with self.assertLogs(
        "pyhealth.datasets.base_dataset", level="WARNING"
    ) as captured:
        _ = self.dataset.global_event_df

    warned = any(
        "Incomplete parquet cache" in line for line in captured.output
    )
    self.assertTrue(
        warned,
        msg=f"Expected 'Incomplete parquet cache' warning; got: {captured.output}",
    )

    # The rebuilt cache must be a directory holding at least one parquet file.
    self.assertTrue(cache_path.is_dir())
    has_parquet = any(cache_path.glob("*.parquet"))
    self.assertTrue(
        has_parquet,
        msg="Expected *.parquet files to be present after cache rebuild",
    )

    # The seeded leftover must not survive the remove-and-recreate cycle.
    self.assertFalse(leftover.exists())


# Run the test suite only when this file is executed directly, not on import.
if __name__ == "__main__":
    unittest.main()