diff --git a/cookbook/summarization_pipeline_hf.py b/cookbook/summarization_pipeline_hf.py new file mode 100644 index 00000000..6e92bf5b --- /dev/null +++ b/cookbook/summarization_pipeline_hf.py @@ -0,0 +1,73 @@ +import healthchain as hc + +from healthchain.pipeline import SummarizationPipeline +from healthchain.use_cases import ClinicalDecisionSupport +from healthchain.models import CdsFhirData, CDSRequest, CDSResponse +from healthchain.data_generators import CdsDataGenerator + +from langchain_huggingface.llms import HuggingFaceEndpoint +from langchain_huggingface import ChatHuggingFace + +from langchain_core.prompts import PromptTemplate +from langchain_core.output_parsers import StrOutputParser + +import getpass +import os + + +if not os.getenv("HUGGINGFACEHUB_API_TOKEN"): + os.environ["HUGGINGFACEHUB_API_TOKEN"] = getpass.getpass("Enter your token: ") + + +@hc.sandbox( + experiment_config={ + "storage_uri": "sqlite:///experiments.db", # Where to store experiment data + "project_name": "patient_summary", # Name for grouping experiments + } +) +class DischargeNoteSummarizer(ClinicalDecisionSupport): + def __init__(self): + # Initialize pipeline and data generator + chain = self._init_chain() + self.pipeline = SummarizationPipeline.load( + chain, source="langchain", template_path="templates/cds_card_template.json" + ) + self.data_generator = CdsDataGenerator() + + def _init_chain(self): + hf = HuggingFaceEndpoint( + repo_id="HuggingFaceH4/zephyr-7b-beta", + task="text-generation", + max_new_tokens=512, + do_sample=False, + repetition_penalty=1.03, + ) + model = ChatHuggingFace(llm=hf) + template = """ + You are a bed planner for a hospital. Provide a concise, objective summary of the input text in short bullet points separated by new lines, + focusing on key actions such as appointments and medication dispense instructions, without using second or third person pronouns.\n'''{text}''' + """ + prompt = PromptTemplate.from_template(template) + chain = prompt | model | StrOutputParser() + + return chain + + @hc.ehr(workflow="encounter-discharge") + def load_data_in_client(self) -> CdsFhirData: + # Generate synthetic FHIR data for testing + data = self.data_generator.generate( + free_text_path="data/discharge_notes.csv", column_name="text" + ) + return data + + @hc.api + def my_service(self, request: CDSRequest) -> CDSResponse: + # Process the request through our pipeline + result = self.pipeline(request) + return result + + +if __name__ == "__main__": + # Start the sandbox server + summarizer = DischargeNoteSummarizer() + summarizer.start_sandbox() diff --git a/docs/reference/tracking/experiment_tracking.md b/docs/reference/tracking/experiment_tracking.md new file mode 100644 index 00000000..462e9b86 --- /dev/null +++ b/docs/reference/tracking/experiment_tracking.md @@ -0,0 +1,204 @@ +# ExperimentTracker Documentation + +ExperimentTracker is a simple yet powerful tool for tracking your machine learning experiments. It automatically records experiment metadata, timing, and status with minimal configuration required. 
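
Under the hood, the `@sandbox` decorator simply instantiates an `ExperimentTracker` and calls `start_experiment()`/`end_experiment()` around the sandbox run (see the Quick Start below for the usual route), so you can also drive the tracker by hand. A minimal sketch against the interface added in this PR; the workload in the middle is up to you:

```python
from healthchain.tracking.experiment import ExperimentTracker, ExperimentStatus

tracker = ExperimentTracker(
    storage_uri="sqlite:///experiments.db",  # default backend
    project_name="my_project",
)

# start_experiment() returns the new experiment's UUID string
experiment_id = tracker.start_experiment(
    name="manual_run",
    tags={"environment": "dev"},
)

try:
    ...  # run your pipeline or evaluation here
    tracker.end_experiment(ExperimentStatus.COMPLETED)
except Exception:
    tracker.end_experiment(ExperimentStatus.FAILED)
    raise
```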
## Quick Start

The easiest way to use ExperimentTracker is through the `@sandbox` decorator:

```python
from healthchain import sandbox
from healthchain.base import BaseUseCase

@sandbox(
    experiment_config={
        "storage_uri": "sqlite:///experiments.db",  # Where to store experiment data
        "project_name": "my_project",  # Name for grouping experiments
    }
)
class MyExperiment(BaseUseCase):
    def __init__(self):
        # ExperimentTracker is automatically initialized
        # You can access it via self.experiment_tracker
        pass
```

That's it! The system will automatically:
- Create a unique ID for each experiment run
- Track when experiments start and end
- Record the status (completed or failed)
- Tag each run with the sandbox class and workflow

## Viewing Your Experiments

### Database Schema

ExperimentTracker uses SQLAlchemy to store experiment data in two tables:

**experiments**:
- `id`: Unique identifier (UUID)
- `name`: Experiment name
- `start_time`: Start timestamp
- `end_time`: End timestamp
- `status`: Current status (RUNNING, COMPLETED, FAILED)
- `tags`: JSON field for custom tags
- `pipeline_config`: JSON field for pipeline configuration (optional)

**pipeline_components**:
- `id`: Component ID
- `experiment_id`: Reference to parent experiment
- `name`: Component name
- `type`: Component type
- `stage`: Processing stage
- `position`: Order in pipeline

### Using the Python API

```python
# `tracker` is an ExperimentTracker instance,
# e.g. self.experiment_tracker inside a sandbox class

# Get details for a specific experiment
experiment = tracker.get_experiment(experiment_id)
print(f"Status: {experiment.status}")
print(f"Duration: {experiment.end_time - experiment.start_time}")

# List all experiments
experiments = tracker.list_experiments()

# Filter experiments by tags
prod_experiments = tracker.list_experiments(
    filters={"tags": {"environment": "production"}}
)
```

### Querying the Database Directly

The experiment data is stored in a local SQLite database that you can query directly in a Python script or Jupyter notebook:

```python
import sqlite3

# Connect to the database
conn = sqlite3.connect('experiments.db')
cursor = conn.cursor()

# View recent experiments
cursor.execute("""
    SELECT id, name, start_time, status, tags
    FROM experiments
    ORDER BY start_time DESC
    LIMIT 5;
""")
recent_experiments = cursor.fetchall()

# View experiments with a specific tag
cursor.execute("""
    SELECT id, name, start_time, status
    FROM experiments
    WHERE json_extract(tags, '$.environment') = 'production';
""")
prod_experiments = cursor.fetchall()

# Get component details for an experiment
cursor.execute("""
    SELECT name, type, stage
    FROM pipeline_components
    WHERE experiment_id = ?;
""", (experiment_id,))
components = cursor.fetchall()

conn.close()
```

## Configuration Options

The `experiment_config` dictionary supports two options:

- `storage_uri`: Where to store experiment data (default: `"sqlite:///experiments.db"`)
- `project_name`: Name for grouping related experiments (default: `"healthchain"`)
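
Because `storage_uri` is passed straight to SQLAlchemy's `create_engine`, any SQLAlchemy-compatible database URL should work in principle. The sketch below assumes a PostgreSQL backend; the server, credentials, and an installed `psycopg2` driver are assumptions, not something this package ships:

```python
from healthchain import sandbox
from healthchain.base import BaseUseCase

@sandbox(
    experiment_config={
        # Hypothetical PostgreSQL URL; only SQLite is exercised in this repo
        "storage_uri": "postgresql+psycopg2://user:password@localhost:5432/experiments",
        "project_name": "my_project",
    }
)
class MyExperiment(BaseUseCase):
    def __init__(self):
        pass
```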
## Example: Real-World Usage

Here's a complete example showing how ExperimentTracker is used in practice:

```python
import healthchain as hc

from healthchain.pipeline import SummarizationPipeline
from healthchain.use_cases import ClinicalDecisionSupport
from healthchain.models import CdsFhirData, CDSRequest, CDSResponse
from healthchain.data_generators import CdsDataGenerator

from langchain_huggingface.llms import HuggingFaceEndpoint
from langchain_huggingface import ChatHuggingFace

from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

import getpass
import os


if not os.getenv("HUGGINGFACEHUB_API_TOKEN"):
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = getpass.getpass("Enter your token: ")


@hc.sandbox(
    experiment_config={
        "storage_uri": "sqlite:///experiments.db",  # Where to store experiment data
        "project_name": "patient_summary",  # Name for grouping experiments
    }
)
class DischargeNoteSummarizer(ClinicalDecisionSupport):
    def __init__(self):
        # Initialize pipeline and data generator
        chain = self._init_chain()
        self.pipeline = SummarizationPipeline.load(
            chain, source="langchain", template_path="templates/cds_card_template.json"
        )
        self.data_generator = CdsDataGenerator()

    def _init_chain(self):
        hf = HuggingFaceEndpoint(
            repo_id="HuggingFaceH4/zephyr-7b-beta",
            task="text-generation",
            max_new_tokens=512,
            do_sample=False,
            repetition_penalty=1.03,
        )
        model = ChatHuggingFace(llm=hf)
        template = """
        You are a bed planner for a hospital. Provide a concise, objective summary of the input text in short bullet points separated by new lines,
        focusing on key actions such as appointments and medication dispense instructions, without using second or third person pronouns.\n'''{text}'''
        """
        prompt = PromptTemplate.from_template(template)
        chain = prompt | model | StrOutputParser()

        return chain

    @hc.ehr(workflow="encounter-discharge")
    def load_data_in_client(self) -> CdsFhirData:
        # Generate synthetic FHIR data for testing
        data = self.data_generator.generate(
            free_text_path="data/discharge_notes.csv", column_name="text"
        )
        return data

    @hc.api
    def my_service(self, request: CDSRequest) -> CDSResponse:
        # Process the request through our pipeline
        result = self.pipeline(request)
        return result


if __name__ == "__main__":
    # Start the sandbox server
    summarizer = DischargeNoteSummarizer()
    summarizer.start_sandbox()
```

### Performance Considerations

- SQLite (default) works well for single-user scenarios
- Large-scale deployments may want to implement custom storage backends
diff --git a/healthchain/decorators.py b/healthchain/decorators.py
index 46cfb58d..0a11dc43 100644
--- a/healthchain/decorators.py
+++ b/healthchain/decorators.py
@@ -5,7 +5,6 @@
 import json
 import uuid
 import requests
-
 from time import sleep
 from pathlib import Path
 from datetime import datetime
@@ -14,6 +13,7 @@
 from healthchain.workflows import UseCaseType
 from healthchain.apimethod import APIMethod
+from healthchain.tracking.experiment import ExperimentTracker, ExperimentStatus

 from .base import BaseUseCase
 from .service import Service
@@ -124,15 +124,19 @@ def sandbox(arg: Optional[Any] = None, **kwargs: Any) -> Callable:

     Parameters:
         arg: Optional argument which can be either a callable (class) directly or a configuration dict.
-        **kwargs: Arbitrary keyword arguments, mainly used to pass in 'service_config'.
-        'service_config' must be a dictionary of valid kwargs to pass into uvivorn.run()
+        **kwargs: Arbitrary keyword arguments for configuring the service and experiment tracking.
+            service_config: Dictionary of valid kwargs to pass into uvicorn.run()
+            experiment_config: Dictionary with storage_uri and project_name for experiment tracking

     Returns:
         If `arg` is callable, it applies the default decorator with no extra configuration.
         Otherwise, it uses the provided arguments to configure the service environment.
     Example:
-        @sandbox(service_config={"port": 9000})
+        @sandbox(
+            service_config={"port": 9000},
+            experiment_config={"storage_uri": "sqlite:///experiments.db", "project_name": "my_project"}
+        )
         class myCDS(ClinicalDecisionSupport):
             def __init__(self) -> None:
                 self.data_generator = None
@@ -143,30 +147,33 @@ def __init__(self) -> None:
         return sandbox_decorator()(cls)  # Apply default decorator with default settings
     else:
         # Arguments were provided, or no arguments but with parentheses
-        if "service_config" not in kwargs:
+        valid_configs = {"service_config", "experiment_config"}
+        invalid_args = set(kwargs.keys()) - valid_configs
+        if invalid_args:
             log.warning(
-                f"{list(kwargs.keys())} is not a valid argument and will not be used; use 'service_config'."
+                f"{list(invalid_args)} are not valid arguments and will not be used; "
+                f"use {list(valid_configs)}."
             )

+        service_config = arg if arg is not None else kwargs.get("service_config", {})
+        experiment_config = kwargs.get("experiment_config", {})

-        return sandbox_decorator(service_config)
+        return sandbox_decorator(service_config, experiment_config)


-def sandbox_decorator(service_config: Optional[Dict] = None) -> Callable:
+def sandbox_decorator(
+    service_config: Optional[Dict] = None, experiment_config: Optional[Dict] = None
+) -> Callable:
     """
-    A decorator function that sets up a sandbox environment. It modifies the class initialization
-    to incorporate service and client management based on provided configurations. It will:
-
-    - Initialise the use case strategy class
-    - Set up a service instance
-    - Trigger .send_request() function from the configured client
+    Decorator that configures a sandbox class with service and experiment tracking.

-    Parameters:
-        service_config: A dictionary containing configurations for the service.
-
-    Returns:
-        A wrapper function that modifies the class to which it is applied.
+    Args:
+        service_config: Optional configuration for the service
+        experiment_config: Optional configuration for experiment tracking:
+            storage_uri: Where to store experiment data
+            project_name: Name of the project for grouping experiments
     """
+
     if service_config is None:
         service_config = {}

@@ -179,9 +186,24 @@
     def wrapper(cls: Type) -> Type:
         original_init = cls.__init__

         def new_init(self, *args: Any, **kwargs: Any) -> None:
-            # initialse parent class, which should be a strategy use case
+            # Initialize parent class
             super(cls, self).__init__(*args, **kwargs, service_config=service_config)
-            original_init(self, *args, **kwargs)  # Call the original __init__
+
+            # Initialize experiment tracker; defaults mirror ExperimentTracker's own
+            if experiment_config:
+                self.experiment_tracker = ExperimentTracker(
+                    storage_uri=experiment_config.get(
+                        "storage_uri", "sqlite:///experiments.db"
+                    ),
+                    project_name=experiment_config.get("project_name", "healthchain"),
+                )
+            else:
+                self.experiment_tracker = None
+
+            # Call original init
+            original_init(self, *args, **kwargs)

         service_route_count = 0
         client_count = 0
@@ -230,6 +258,21 @@ def start_sandbox(
             self.sandbox_id = uuid.uuid4()

+            # Start experiment tracking
+            pipeline = getattr(self, "pipeline", None)
+            if self.experiment_tracker:
+                self.experiment_tracker.start_experiment(
+                    name=f"{self.__class__.__name__}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
+                    pipeline=pipeline,
+                    tags={
+                        "sandbox_class": self.__class__.__name__,
+                        "workflow": self._client.workflow.value
+                        if self._client
+                        else None,
+                    },
+                )
+
+            # Configure logging
             if logging_config:
                 logging.config.dictConfig(logging_config)
             else:
@@ -259,19 +302,16 @@
                     service_id=service_id,
                 )

-            # Send async request from client
+            # Send request and get response
             log.info(
-                f"Sending {len(self._client.request_data)} requests generated by {self._client.__class__.__name__} to {self.url.route}"
+                f"Sending {len(self._client.request_data)} requests to {self.url.route}"
+            )
+            self.responses = asyncio.run(
+                self._client.send_request(url=self.url.service)
             )
-
-            try:
-                self.responses = asyncio.run(
-                    self._client.send_request(url=self.url.service)
-                )
-            except Exception as e:
-                log.error(f"Couldn't start client: {e}")

             if save_data:
+                # Save request/response data
                 save_dir = Path(save_dir)
                 request_path = ensure_directory_exists(save_dir / "requests")
                 if self.type == UseCaseType.clindoc:
@@ -317,9 +357,16 @@
         def stop_sandbox(self) -> None:
             log.info("Shutting down server...")
             requests.get(self.url.base + "/shutdown")

+            # End experiment successfully
+            if self.experiment_tracker:
+                self.experiment_tracker.end_experiment(ExperimentStatus.COMPLETED)
+
         cls.start_sandbox = start_sandbox
         cls.stop_sandbox = stop_sandbox
-
         return cls

     return wrapper
diff --git a/healthchain/tracking/database.py b/healthchain/tracking/database.py
new file mode 100644
index 00000000..4a525a88
--- /dev/null
+++ b/healthchain/tracking/database.py
@@ -0,0 +1,31 @@
from sqlalchemy import Column, Integer, String, DateTime, JSON, ForeignKey
from sqlalchemy.orm import declarative_base, relationship

Base = declarative_base()


class Experiment(Base):
    __tablename__ = "experiments"

    id = Column(String, primary_key=True)
    name = Column(String)
    start_time = Column(DateTime)
    end_time = Column(DateTime, nullable=True)
    status = Column(String)
    tags = Column(JSON)
    pipeline_config = Column(JSON, nullable=True)
    components = relationship("PipelineComponent", back_populates="experiment")


class PipelineComponent(Base):
    __tablename__ = "pipeline_components"

    id = Column(Integer, primary_key=True)
    experiment_id = Column(String, ForeignKey("experiments.id"))
    name = Column(String)
    type = Column(String)
    stage = Column(String)
    position = Column(Integer)

    experiment = relationship("Experiment", back_populates="components")
diff --git a/healthchain/tracking/experiment.py b/healthchain/tracking/experiment.py
new file mode 100644
index 00000000..b78e5b6e
--- /dev/null
+++ b/healthchain/tracking/experiment.py
@@ -0,0 +1,182 @@
from dataclasses import dataclass
from enum import Enum
import json
from typing import Any, Dict, List, Optional, Set
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from datetime import datetime
import uuid
import dill
from pathlib import Path

from .database import Base, Experiment, PipelineComponent


class ExperimentStatus(Enum):
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


@dataclass
class ComponentMetadata:
    name: str
    type: str
    stage: str
    config: Dict[str, Any]
    input_nodes: Set[str]
    output_nodes: Set[str]


@dataclass
class PipelineMetadata:
    name: str
    components: Dict[str, ComponentMetadata]
    input_components: List[str]
    output_components: List[str]
    stages: List[str]


@dataclass
class ExperimentMetadata:
    id: str
    name: str
    start_time: datetime
    end_time: Optional[datetime]
    status: ExperimentStatus
    sandbox_class: Optional[str]
    workflow: Optional[str]
    pipeline: Optional[PipelineMetadata]
    input_schema: Dict[str, Any]
    output_schema: Dict[str, Any]
    tags: Dict[str, str]


class ExperimentTracker:
    def __init__(
        self,
        storage_uri: str = "sqlite:///experiments.db",
        project_name: str = "healthchain",
    ):
        self.engine = create_engine(storage_uri)
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine)
        self.project_name = project_name
        self.current_experiment_id = None

        # Create directory for pipeline serialization
        self.pipeline_dir = Path("./output/pipelines")
        self.pipeline_dir.mkdir(parents=True, exist_ok=True)

    def start_experiment(
        self, name: str, pipeline=None, tags: Optional[Dict[str, str]] = None
    ) -> str:
        experiment_id = str(uuid.uuid4())

        # Create new experiment record
        experiment = Experiment(
            id=experiment_id,
            name=name,
            start_time=datetime.now(),
            status=ExperimentStatus.RUNNING.value,
            tags=tags or {},
        )

        # Save pipeline configuration if provided
        if pipeline is not None:
            experiment.pipeline_config = self._extract_pipeline_metadata(pipeline)

            # Record each pipeline component
            for i, component in enumerate(pipeline._components):
                pc = PipelineComponent(
                    experiment_id=experiment_id,
                    name=component.name,
                    type=component.func.__class__.__name__,
                    stage=component.stage,
                    position=i,
                )
                experiment.components.append(pc)

            # Serialize the pipeline so it can be restored via load_pipeline();
            # skip silently if the pipeline contains objects dill cannot handle
            try:
                with open(self.pipeline_dir / f"{experiment_id}.pkl", "wb") as f:
                    dill.dump(pipeline, f)
            except Exception:
                pass

        # Save to database
        session = self.Session()
        session.add(experiment)
        session.commit()
        session.close()

        self.current_experiment_id = experiment_id
        return experiment_id

    def end_experiment(self, status: ExperimentStatus):
        if self.current_experiment_id:
            session = self.Session()
            experiment = session.query(Experiment).get(self.current_experiment_id)
            experiment.end_time = datetime.now()
            experiment.status = status.value
            session.commit()
            session.close()

    def load_pipeline(self, experiment_id: str):
        """Load a serialized pipeline from an experiment"""
        pipeline_path = self.pipeline_dir / f"{experiment_id}.pkl"
        if pipeline_path.exists():
            with open(pipeline_path, "rb") as f:
                return dill.load(f)
        return None

    def get_experiment(self, experiment_id: str) -> Optional[Experiment]:
        """Retrieve experiment details from the database"""
        session = self.Session()
        experiment = session.query(Experiment).get(experiment_id)
        session.close()
        return experiment

    def list_experiments(self, filters: Optional[Dict[str, Any]] = None) -> List[Experiment]:
        """List all experiments with optional filtering"""
        session = self.Session()
        query = session.query(Experiment)

        if filters:
            for key, value in filters.items():
                if hasattr(Experiment, key):
                    query = query.filter(getattr(Experiment, key) == value)

        experiments = query.all()
        session.close()
        return experiments

    def _extract_pipeline_metadata(self, pipeline) -> Dict[str, Any]:
        components = {}
        for component in pipeline._components:
            if not component.name.startswith("_"):
                config = {}
                if hasattr(component, "get_config"):
                    config = component.get_config()
                elif hasattr(component, "__dict__"):
                    config = {
                        k: v
                        for k, v in component.__dict__.items()
                        if not k.startswith("_") and self._is_json_serializable(v)
                    }

                components[component.name] = {
                    "name": component.__class__.__name__,
                    "type": f"{component.__class__.__module__}.{component.__class__.__name__}",
                    "stage": "unknown",
                    "config": config,
                }

        return {
            "name": pipeline.__class__.__name__,
            "components": components,
            "input_components": [],
            "output_components": [],
            "stages": [],
        }

    def _is_json_serializable(self, obj):
        try:
            json.dumps(obj)
            return True
        except (TypeError, OverflowError):
            return False