Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions pyhealth/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,32 +53,42 @@ def path_exists(path: str) -> bool:
return Path(path).exists()


def scan_csv_gz_or_csv_tsv(path: str) -> pl.LazyFrame:
    """Scan a CSV.gz, CSV, TSV.gz, or TSV file and return a LazyFrame.

    If the given path does not exist, falls back to the alternative
    compression variant of the same file (adding or removing ``.gz``).

    Args:
        path (str): URL or local path to a .csv, .csv.gz, .tsv, or .tsv.gz file.

    Returns:
        pl.LazyFrame: The LazyFrame for the CSV.gz, CSV, TSV.gz, or TSV file.

    Raises:
        FileNotFoundError: If the path lacks a recognized extension, or if
            neither the path nor its .gz/non-.gz alternative exists.
    """
    def scan_file(file_path: str) -> pl.LazyFrame:
        # Use endswith rather than a substring test so a '.tsv' appearing
        # mid-path (e.g. a directory named 'my.tsv.dir') cannot select the
        # wrong separator. infer_schema=False keeps every column as string,
        # matching downstream expectations.
        separator = '\t' if file_path.endswith((".tsv", ".tsv.gz")) else ','
        return pl.scan_csv(file_path, separator=separator, infer_schema=False)

    if path_exists(path):
        return scan_file(path)

    # Try the alternative compression variant of the same extension.
    if path.endswith((".csv.gz", ".tsv.gz")):
        alt_path = path[:-3]  # Remove .gz -> try uncompressed
    elif path.endswith((".csv", ".tsv")):
        alt_path = f"{path}.gz"  # Add .gz -> try compressed
    else:
        raise FileNotFoundError(f"Path does not have expected extension: {path}")

    if path_exists(alt_path):
        logger.info(f"Original path does not exist. Using alternative: {alt_path}")
        return scan_file(alt_path)

    raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}")


class BaseDataset(ABC):
"""Abstract base class for all PyHealth datasets.

Expand Down Expand Up @@ -198,7 +208,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame:
csv_path = clean_path(csv_path)

logger.info(f"Scanning table: {table_name} from {csv_path}")
df = scan_csv_gz_or_csv(csv_path)
df = scan_csv_gz_or_csv_tsv(csv_path)

# Convert column names to lowercase before calling preprocess_func
col_names = df.collect_schema().names()
Expand All @@ -219,7 +229,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame:
other_csv_path = f"{self.root}/{join_cfg.file_path}"
other_csv_path = clean_path(other_csv_path)
logger.info(f"Joining with table: {other_csv_path}")
join_df = scan_csv_gz_or_csv(other_csv_path)
join_df = scan_csv_gz_or_csv_tsv(other_csv_path)
join_df = join_df.with_columns(
[
pl.col(col).alias(col.lower())
Expand Down
259 changes: 259 additions & 0 deletions tests/core/test_tsv_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
import os
import tempfile
import unittest
from pathlib import Path

import polars as pl
import yaml

from pyhealth.datasets.base_dataset import BaseDataset


class TestTSVLoad(unittest.TestCase):
    """Test TSV loading functionality with BaseDataset.

    Builds a small synthetic dataset of three TSV tables (patients,
    diagnoses, procedures) plus a matching YAML config in a temporary
    directory, then exercises the BaseDataset loading paths against it.
    """

    def setUp(self):
        """Set up temporary directory and create pseudo dataset."""
        self.temp_dir = tempfile.mkdtemp()
        self._create_pseudo_dataset()
        self._create_config_file()

    def tearDown(self):
        """Clean up temporary directory."""
        import shutil

        # ignore_errors also covers the case where the directory was
        # already removed, so no existence check is needed.
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def _create_pseudo_dataset(self):
        """Create pseudo TSV dataset files with fixed sample data.

        Writes patients.tsv, diagnoses.tsv, and procedures.tsv into the
        temp directory using a tab separator.
        """
        # Create patients table (one row per patient, no timestamp column)
        patients_data = {
            "patient_id": ["P001", "P002", "P003", "P004", "P005"],
            "gender": ["M", "F", "M", "F", "M"],
            "age": [45, 32, 67, 28, 53],
            "admission_date": [
                "2023-01-15",
                "2023-02-20",
                "2023-03-10",
                "2023-01-25",
                "2023-04-05",
            ],
        }
        patients_df = pl.DataFrame(patients_data)
        patients_path = Path(self.temp_dir) / "patients.tsv"
        patients_df.write_csv(patients_path, separator="\t")

        # Create diagnoses table (P001 has two diagnoses, so 6 rows total)
        diagnoses_data = {
            "patient_id": ["P001", "P001", "P002", "P003", "P004", "P005"],
            "diagnosis_code": ["A01.1", "B15.9", "C78.0", "D50.0", "E11.9", "F32.9"],
            "diagnosis_desc": [
                "Typhoid fever",
                "Hepatitis A",
                "Lung cancer",
                "Iron deficiency",
                "Type 2 diabetes",
                "Depression",
            ],
            "timestamp": [
                "2023-01-15 10:00",
                "2023-01-16 14:30",
                "2023-02-20 09:15",
                "2023-03-10 11:45",
                "2023-01-25 16:20",
                "2023-04-05 08:30",
            ],
        }
        diagnoses_df = pl.DataFrame(diagnoses_data)
        diagnoses_path = Path(self.temp_dir) / "diagnoses.tsv"
        diagnoses_df.write_csv(diagnoses_path, separator="\t")

        # Create procedures table (one procedure per patient)
        procedures_data = {
            "patient_id": ["P001", "P002", "P003", "P004", "P005"],
            "procedure_code": ["99213", "99214", "99215", "99213", "99214"],
            "procedure_desc": [
                "Office visit",
                "Extended visit",
                "Complex visit",
                "Office visit",
                "Extended visit",
            ],
            "timestamp": [
                "2023-01-15 11:00",
                "2023-02-20 10:30",
                "2023-03-10 12:00",
                "2023-01-25 17:00",
                "2023-04-05 09:00",
            ],
        }
        procedures_df = pl.DataFrame(procedures_data)
        procedures_path = Path(self.temp_dir) / "procedures.tsv"
        procedures_df.write_csv(procedures_path, separator="\t")

        self.patients_file = str(patients_path)
        self.diagnoses_file = str(diagnoses_path)
        self.procedures_file = str(procedures_path)

    def _create_config_file(self):
        """Create YAML configuration file describing the pseudo dataset.

        Table file_paths point at the .tsv files so BaseDataset must go
        through the TSV branch of the scan helper.
        """
        config_data = {
            "version": "1.0",
            "tables": {
                "patients": {
                    "file_path": "patients.tsv",
                    "patient_id": "patient_id",
                    "timestamp": None,
                    "attributes": ["gender", "age", "admission_date"],
                },
                "diagnoses": {
                    "file_path": "diagnoses.tsv",
                    "patient_id": "patient_id",
                    "timestamp": "timestamp",
                    "timestamp_format": "%Y-%m-%d %H:%M",
                    "attributes": ["diagnosis_code", "diagnosis_desc"],
                },
                "procedures": {
                    "file_path": "procedures.tsv",
                    "patient_id": "patient_id",
                    "timestamp": "timestamp",
                    "timestamp_format": "%Y-%m-%d %H:%M",
                    "attributes": ["procedure_code", "procedure_desc"],
                },
            },
        }

        self.config_path = Path(self.temp_dir) / "test_config.yaml"
        with open(self.config_path, "w") as f:
            yaml.dump(config_data, f, default_flow_style=False)

    def test_tsv_load(self):
        """Test loading TSV dataset with BaseDataset and using stats() function."""
        # Test loading the dataset with different table combinations
        tables_to_test = [
            ["patients"],
            ["diagnoses"],
            ["procedures"],
            ["patients", "diagnoses"],
            ["diagnoses", "procedures"],
            ["patients", "diagnoses", "procedures"],
        ]

        for tables in tables_to_test:
            with self.subTest(tables=tables):
                # Create BaseDataset instance
                dataset = BaseDataset(
                    root=self.temp_dir,
                    tables=tables,
                    dataset_name="TestTSVDataset",
                    config_path=str(self.config_path),
                    dev=False,
                )

                # Verify the dataset was loaded
                self.assertIsNotNone(dataset.global_event_df)
                self.assertIsNotNone(dataset.config)

                # Test that we can collect the dataframe
                collected_df = dataset.collected_global_event_df
                self.assertIsInstance(collected_df, pl.DataFrame)
                self.assertGreater(
                    collected_df.height, 0, "Dataset should have at least one row"
                )

                # Verify patient_id column exists
                self.assertIn("patient_id", collected_df.columns)

                # stats() must run without raising; letting an exception
                # propagate fails the subtest with a full traceback
                # (clearer than catching and calling self.fail).
                dataset.stats()

    def test_tsv_load_dev_mode(self):
        """Test loading TSV dataset in dev mode."""
        # Create dataset in dev mode
        dataset = BaseDataset(
            root=self.temp_dir,
            tables=["patients", "diagnoses", "procedures"],
            dataset_name="TestTSVDatasetDev",
            config_path=str(self.config_path),
            dev=True,
        )

        # Verify dev mode is enabled
        self.assertTrue(dataset.dev)

        # stats() must run without raising in dev mode as well.
        dataset.stats()

    def test_tsv_file_detection(self):
        """Test that TSV files are correctly detected and loaded."""
        dataset = BaseDataset(
            root=self.temp_dir,
            tables=["patients"],
            dataset_name="TestTSVDetection",
            config_path=str(self.config_path),
            dev=False,
        )

        collected_df = dataset.collected_global_event_df

        # Verify we have the expected number of patients
        self.assertEqual(collected_df["patient_id"].n_unique(), 5)

        # Verify we have the expected columns from the patients table
        # Note: attribute columns are prefixed with table name (e.g., "patients/gender")
        expected_base_columns = ["patient_id", "event_type", "timestamp"]
        expected_patient_columns = [
            "patients/gender",
            "patients/age",
            "patients/admission_date",
        ]

        for col in expected_base_columns:
            self.assertIn(col, collected_df.columns)

        for col in expected_patient_columns:
            self.assertIn(col, collected_df.columns)

    def test_multiple_tsv_tables(self):
        """Test loading and joining multiple TSV tables."""
        dataset = BaseDataset(
            root=self.temp_dir,
            tables=["diagnoses", "procedures"],
            dataset_name="TestMultipleTSV",
            config_path=str(self.config_path),
            dev=False,
        )

        collected_df = dataset.collected_global_event_df

        # Should have data from both tables
        self.assertGreater(collected_df.height, 5)  # More than just patients table

        # Should have timestamp column since both diagnoses and procedures have timestamps
        self.assertIn("timestamp", collected_df.columns)

        # Should have both diagnosis and procedure data
        # Note: columns from different tables are prefixed with table names
        all_columns = set(collected_df.columns)

        # Check for diagnosis-specific columns (prefixed with table name)
        diagnosis_columns = {"diagnoses/diagnosis_code", "diagnoses/diagnosis_desc"}
        procedure_columns = {"procedures/procedure_code", "procedures/procedure_desc"}

        # At least some of these should be present in the concatenated result
        self.assertTrue(
            len(diagnosis_columns.intersection(all_columns)) > 0
            or len(procedure_columns.intersection(all_columns)) > 0,
            f"Expected some diagnosis or procedure columns in {all_columns}",
        )


# Allow running this test module directly (outside a pytest/unittest runner).
if __name__ == "__main__":
    unittest.main()