Rebuild the global event dataframe cache when the cache path exists but is
not a directory containing parquet files (empty/partial directory, plain
file, or symlink), add "row_id" to a MIMIC-III table's attributes, and cover
the cache rebuild with regression tests.

NOTE(review): assumes base_dataset.py already imports shutil -- confirm.
NOTE(review): the post-image hashes on the "index" lines predate the review
edits to base_dataset.py and test_caching.py.

diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py
index dcd194bc4..778a6f362 100644
--- a/pyhealth/datasets/base_dataset.py
+++ b/pyhealth/datasets/base_dataset.py
@@ -519,7 +519,21 @@ def global_event_df(self) -> pl.LazyFrame:
 
         if self._global_event_df is None:
             ret_path = self.cache_dir / "global_event_df.parquet"
-            if not ret_path.exists():
+            # A usable cache is a directory holding at least one parquet file.
+            cache_valid = ret_path.is_dir() and any(ret_path.glob("*.parquet"))
+            if not cache_valid:
+                if ret_path.exists():
+                    logger.warning(
+                        f"Incomplete parquet cache at {ret_path} "
+                        "(expected a directory containing parquet files). "
+                        "Removing and rebuilding."
+                    )
+                    # shutil.rmtree() only accepts real directories; a stale
+                    # plain file or symlink has to be unlinked instead.
+                    if ret_path.is_dir() and not ret_path.is_symlink():
+                        shutil.rmtree(ret_path)
+                    else:
+                        ret_path.unlink()
                 logger.info(f"No cached event dataframe found. Creating: {ret_path}")
                 self._event_transform(ret_path)
             else:
diff --git a/pyhealth/datasets/configs/mimic3.yaml b/pyhealth/datasets/configs/mimic3.yaml
index 944719850..76b355e55 100644
--- a/pyhealth/datasets/configs/mimic3.yaml
+++ b/pyhealth/datasets/configs/mimic3.yaml
@@ -134,6 +134,7 @@ tables:
     timestamp:
       - "charttime"
     attributes:
+      - "row_id"
       - "hadm_id"
       - "text"
       - "category"
diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py
index 4b7f1d512..a346b98e1 100644
--- a/tests/core/test_caching.py
+++ b/tests/core/test_caching.py
@@ -107,6 +107,10 @@ def setUp(self):
         self.dataset = MockDataset(cache_dir=self.temp_dir.name)
         self.task = MockTask()
 
+    def tearDown(self):
+        """Clean up the temporary directory used as the dataset cache."""
+        self.temp_dir.cleanup()
+
     def test_set_task_signature(self):
         """Test that set_task has the correct method signature."""
         import inspect
@@ -296,5 +300,58 @@ def test_datasets_with_diff_dev_values_get_diff_caches(self):
 
         self.assertNotEqual(dataset1.cache_dir, dataset2.cache_dir)
 
+    def test_incomplete_parquet_cache_triggers_rebuild(self):
+        """A global_event_df.parquet dir that exists but holds no *.parquet files
+        is detected as incomplete, logged as a warning, torn down, and rebuilt."""
+        ret_path = self.dataset.cache_dir / "global_event_df.parquet"
+
+        # Simulate a partial/interrupted write: directory exists but no parquet files
+        ret_path.mkdir(parents=True, exist_ok=True)
+        stale_file = ret_path / "incomplete.tmp"
+        stale_file.touch()
+
+        self.assertTrue(ret_path.is_dir())
+        self.assertFalse(any(ret_path.glob("*.parquet")))
+
+        # global_event_df should detect the incomplete cache, emit a warning,
+        # remove the directory, and rebuild it from scratch.
+        with self.assertLogs("pyhealth.datasets.base_dataset", level="WARNING") as log:
+            _ = self.dataset.global_event_df
+
+        self.assertTrue(
+            any("Incomplete parquet cache" in msg for msg in log.output),
+            msg=f"Expected 'Incomplete parquet cache' warning; got: {log.output}",
+        )
+
+        # After rebuild the directory must exist and contain valid parquet files
+        self.assertTrue(ret_path.is_dir())
+        self.assertTrue(
+            any(ret_path.glob("*.parquet")),
+            msg="Expected *.parquet files to be present after cache rebuild",
+        )
+
+        # The stale file should be gone (the whole directory was removed then recreated)
+        self.assertFalse(stale_file.exists())
+
+    def test_file_at_cache_path_triggers_rebuild(self):
+        """A plain file sitting at the cache path is removed and rebuilt."""
+        ret_path = self.dataset.cache_dir / "global_event_df.parquet"
+
+        # shutil.rmtree() rejects regular files, so the cleanup must unlink instead.
+        ret_path.parent.mkdir(parents=True, exist_ok=True)
+        ret_path.touch()
+        self.assertTrue(ret_path.is_file())
+
+        with self.assertLogs("pyhealth.datasets.base_dataset", level="WARNING") as log:
+            _ = self.dataset.global_event_df
+
+        self.assertTrue(
+            any("Incomplete parquet cache" in msg for msg in log.output),
+            msg=f"Expected 'Incomplete parquet cache' warning; got: {log.output}",
+        )
+        self.assertTrue(ret_path.is_dir())
+        self.assertTrue(any(ret_path.glob("*.parquet")))
+
+
 if __name__ == "__main__":
     unittest.main()