Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
895 changes: 895 additions & 0 deletions example/demo/demo.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions example/demo/run-in-memory.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash
# Run the demo against the default in-memory backend (no database URL argument).
# Fail fast on errors, unset variables, and failed pipeline stages, matching
# the sibling run-postgres.sh script's behavior.
set -euo pipefail

uv run demo.py
11 changes: 11 additions & 0 deletions example/demo/run-postgres.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/bin/bash
# Spin up a throwaway Postgres 17 container, run the demo against it, tear it down.
set -euox pipefail

CONTAINER=pg
# Remove any leftover container from a previous (crashed) run.
docker rm -f $CONTAINER 2>/dev/null || true
# Bug fix: the demo connects as demo:demo to database "demo", so the container
# must actually be provisioned with that role and database (the image defaults
# to user "postgres"). --rm deletes the container automatically on stop.
docker run -d --name $CONTAINER --rm \
  -e POSTGRES_USER=demo -e POSTGRES_PASSWORD=demo -e POSTGRES_DB=demo \
  -p 5432:5432 postgres:17-alpine >/dev/null
# Block until the server accepts connections (pg_isready needs no credentials).
until docker exec $CONTAINER pg_isready >/dev/null 2>&1; do sleep 1; done;

# Bug fix: the URL's port must match the published host port — 5432, not 5433.
uv run demo.py "postgresql://demo:demo@localhost:5432/demo"

docker stop $CONTAINER >/dev/null
93 changes: 93 additions & 0 deletions example/tiny-demo/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# /// script
# dependencies = [
# "pytest",
# "pytest-asyncio",
# "waymark",
# ]
# ///

import asyncio
import sys
from dataclasses import dataclass
from typing import Annotated

import pytest
from waymark import Depend, Workflow, action, workflow
from waymark.workflow import workflow_registry

# Clear any previously registered workflows (private registry attribute) so this
# script can be re-imported/re-run in the same interpreter without duplicate
# registrations. NOTE(review): relies on waymark internals — confirm there is no
# public reset API.
workflow_registry._workflows.clear()


@dataclass
class User:
    """Mock user record served by the demo's in-memory database."""

    id: str  # unique key in the mock db dict (e.g. "user1")
    email: str  # destination address for the welcome email
    active: bool  # inactive users are filtered out before sending


@dataclass
class EmailResult:
    """Outcome of a (mock) email send."""

    to: str  # recipient address
    subject: str  # subject line that was sent
    success: bool  # always True in this demo (send_email never fails)


async def get_mock_db():
    """Dependency provider: an in-memory user "database" keyed by user id."""
    records = [
        User(id="user1", email="alice@example.com", active=True),
        User(id="user2", email="bob@example.com", active=False),
        User(id="user3", email="carol@example.com", active=True),
    ]
    return {record.id: record for record in records}


async def get_mock_email_client():
    """Dependency provider: a stand-in handle for a real email client."""
    client_handle = "email_client"
    return client_handle


@action
async def fetch_users(
    user_ids: list[str],
    db: Annotated[dict, Depend(get_mock_db)],
) -> list[User]:
    """Look up the requested users in the injected db, skipping unknown ids."""
    found: list[User] = []
    for uid in user_ids:
        if uid in db:
            found.append(db[uid])
    return found


@action
async def send_email(
    to: str,
    subject: str,
    emailer: Annotated[str, Depend(get_mock_email_client)],
) -> EmailResult:
    """Pretend to send an email via the injected client; always succeeds."""
    outcome = EmailResult(to=to, subject=subject, success=True)
    return outcome


@workflow
class WelcomeEmailWorkflow(Workflow):
    """Fetches users by id and sends a welcome email to each active one."""

    async def run(self, user_ids: list[str]) -> dict:
        """Send welcome emails to active users.

        Args:
            user_ids: ids to look up; unknown ids are skipped by fetch_users.

        Returns:
            Summary dict with total_users, active_users, and emails_sent counts.
        """
        users = await fetch_users(user_ids)
        active_users = [user for user in users if user.active]

        # Send concurrently; return_exceptions=True keeps one failed send from
        # cancelling the rest — failures come back as exception objects.
        results = await asyncio.gather(
            *[send_email(to=user.email, subject="Welcome") for user in active_users],
            return_exceptions=True,
        )

        # Bug fix: len(results) counted exception objects as "sent". Count only
        # actual results, so failures are not reported as delivered emails.
        emails_sent = sum(1 for r in results if not isinstance(r, BaseException))

        return {
            "total_users": len(users),
            "active_users": len(active_users),
            "emails_sent": emails_sent,
        }


@pytest.mark.asyncio
async def test_welcome_email_workflow():
    """End-to-end check: 3 users fetched, 2 active, 2 welcome emails sent."""
    summary = await WelcomeEmailWorkflow().run(user_ids=["user1", "user2", "user3"])
    expected = {"total_users": 3, "active_users": 2, "emails_sent": 2}
    for key, value in expected.items():
        assert summary[key] == value


if __name__ == "__main__":
    # Running this file directly (e.g. `uv run demo.py`) executes its own pytest
    # tests verbosely and exits with pytest's status code.
    sys.exit(pytest.main(["-v", __file__]))
3 changes: 3 additions & 0 deletions example/tiny-demo/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash
# Run the tiny demo (its embedded pytest suite) with uv.
# Fail fast on errors, unset variables, and failed pipeline stages.
set -euo pipefail

uv run demo.py
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -10,80 +10,16 @@
from typing import Literal, Optional

import asyncpg
from example_app.workflows import BranchRequest, BranchResult, ChainRequest, ChainResult, ComputationRequest, ComputationResult, ConditionalBranchWorkflow, DurableSleepWorkflow, EarlyReturnLoopResult, EarlyReturnLoopWorkflow, ErrorHandlingWorkflow, ErrorRequest, ErrorResult, ExceptionMetadataWorkflow, GuardFallbackRequest, GuardFallbackResult, GuardFallbackWorkflow, KwOnlyLocationRequest, KwOnlyLocationResult, KwOnlyLocationWorkflow, LoopExceptionRequest, LoopExceptionResult, LoopExceptionWorkflow, LoopingSleepRequest, LoopingSleepResult, LoopingSleepWorkflow, LoopProcessingWorkflow, LoopRequest, LoopResult, LoopReturnRequest, LoopReturnResult, LoopReturnWorkflow, ManyActionsRequest, ManyActionsResult, ManyActionsWorkflow, NoOpWorkflow, ParallelMathWorkflow, RetryCounterRequest, RetryCounterResult, RetryCounterWorkflow, SequentialChainWorkflow, SleepRequest, SleepResult, SpreadEmptyCollectionWorkflow, SpreadEmptyRequest, SpreadEmptyResult, TimeoutProbeRequest, TimeoutProbeResult, TimeoutProbeWorkflow, UndefinedVariableWorkflow, WhileLoopRequest, WhileLoopResult, WhileLoopWorkflow
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field

from waymark import (
bridge,
delete_schedule,
pause_schedule,
resume_schedule,
schedule_workflow,
)

from example_app.workflows import (
BranchRequest,
BranchResult,
ChainRequest,
ChainResult,
ComputationRequest,
ComputationResult,
ConditionalBranchWorkflow,
DurableSleepWorkflow,
GuardFallbackRequest,
GuardFallbackResult,
GuardFallbackWorkflow,
EarlyReturnLoopResult,
EarlyReturnLoopWorkflow,
ErrorHandlingWorkflow,
ErrorRequest,
ErrorResult,
ExceptionMetadataWorkflow,
KwOnlyLocationRequest,
KwOnlyLocationResult,
KwOnlyLocationWorkflow,
LoopExceptionRequest,
LoopExceptionResult,
LoopExceptionWorkflow,
LoopReturnRequest,
LoopReturnResult,
LoopReturnWorkflow,
LoopProcessingWorkflow,
LoopRequest,
LoopResult,
LoopingSleepRequest,
LoopingSleepResult,
LoopingSleepWorkflow,
RetryCounterRequest,
RetryCounterResult,
RetryCounterWorkflow,
TimeoutProbeRequest,
TimeoutProbeResult,
TimeoutProbeWorkflow,
ManyActionsRequest,
ManyActionsResult,
ManyActionsWorkflow,
NoOpWorkflow,
ParallelMathWorkflow,
SequentialChainWorkflow,
SleepRequest,
SleepResult,
SpreadEmptyCollectionWorkflow,
SpreadEmptyRequest,
SpreadEmptyResult,
UndefinedVariableWorkflow,
WhileLoopRequest,
WhileLoopResult,
WhileLoopWorkflow,
)
from waymark import bridge, delete_schedule, pause_schedule, resume_schedule, schedule_workflow

app = FastAPI(title="Waymark Example")

templates = Jinja2Templates(
directory=str(Path(__file__).resolve().parent / "templates")
)
templates = Jinja2Templates(directory=str(Path(__file__).resolve().parent / "templates"))


@app.get("/", response_class=HTMLResponse)
Expand Down Expand Up @@ -308,9 +244,7 @@ async def run_undefined_variable_workflow(payload: UndefinedVariableRequest) ->


class EarlyReturnLoopRequest(BaseModel):
input_text: str = Field(
description="Input text to parse. Use 'no_session:' prefix for early return path, or comma-separated items for loop path."
)
input_text: str = Field(description="Input text to parse. Use 'no_session:' prefix for early return path, or comma-separated items for loop path.")


@app.post("/api/early-return-loop", response_model=EarlyReturnLoopResult)
Expand Down Expand Up @@ -355,9 +289,7 @@ async def run_many_actions_workflow(payload: ManyActionsRequest) -> ManyActionsR
Executes a configurable number of actions either in parallel or sequentially.
"""
workflow = ManyActionsWorkflow()
return await workflow.run(
action_count=payload.action_count, parallel=payload.parallel
)
return await workflow.run(action_count=payload.action_count, parallel=payload.parallel)


# =============================================================================
Expand All @@ -376,9 +308,7 @@ async def run_looping_sleep_workflow(
Useful for testing looping sleep workflows.
"""
workflow = LoopingSleepWorkflow()
return await workflow.run(
iterations=payload.iterations, sleep_seconds=payload.sleep_seconds
)
return await workflow.run(iterations=payload.iterations, sleep_seconds=payload.sleep_seconds)


# =============================================================================
Expand Down Expand Up @@ -415,12 +345,8 @@ class ScheduleRequest(BaseModel):
default=None,
description="Cron expression (e.g., '*/5 * * * *' for every 5 minutes)",
)
interval_seconds: Optional[int] = Field(
default=None, ge=10, description="Interval in seconds (minimum 10)"
)
inputs: Optional[dict] = Field(
default=None, description="Input arguments to pass to each scheduled run"
)
interval_seconds: Optional[int] = Field(default=None, ge=10, description="Interval in seconds (minimum 10)")
inputs: Optional[dict] = Field(default=None, description="Input arguments to pass to each scheduled run")


class ScheduleResponse(BaseModel):
Expand Down Expand Up @@ -538,9 +464,7 @@ async def run_batch_workflow(payload: BatchRunRequest) -> StreamingResponse:
"""Queue a batch of workflow instances and stream progress via SSE."""
workflow_cls = WORKFLOW_REGISTRY.get(payload.workflow_name)
if not workflow_cls:
raise HTTPException(
status_code=404, detail=f"Unknown workflow: {payload.workflow_name}"
)
raise HTTPException(status_code=404, detail=f"Unknown workflow: {payload.workflow_name}")

inputs_list = payload.inputs_list
if inputs_list is not None and len(inputs_list) == 0:
Expand All @@ -558,10 +482,7 @@ async def run_batch_workflow(payload: BatchRunRequest) -> StreamingResponse:
if missing:
raise HTTPException(
status_code=400,
detail=(
f"inputs_list[{idx}] missing required keys: "
f"{', '.join(missing)}"
),
detail=(f"inputs_list[{idx}] missing required keys: " f"{', '.join(missing)}"),
)
else:
missing = _missing_input_keys(required_keys, base_inputs)
Expand All @@ -583,22 +504,13 @@ async def event_stream() -> AsyncIterator[str]:
},
)

registration = workflow_cls._build_registration_payload(
priority=payload.priority
)
registration = workflow_cls._build_registration_payload(priority=payload.priority)
if inputs_list is not None:
batch_inputs = [
workflow_cls._build_initial_context((), inputs)
for inputs in inputs_list
]
batch_inputs = [workflow_cls._build_initial_context((), inputs) for inputs in inputs_list]
base_inputs_message = None
else:
batch_inputs = None
base_inputs_message = (
workflow_cls._build_initial_context((), base_inputs)
if payload.inputs is not None
else None
)
base_inputs_message = workflow_cls._build_initial_context((), base_inputs) if payload.inputs is not None else None

batch_result = await bridge.run_instances_batch(
registration.SerializeToString(),
Expand All @@ -617,9 +529,7 @@ async def event_stream() -> AsyncIterator[str]:
"queued": batch_result.queued,
"total": total,
"elapsed_ms": elapsed_ms,
"instance_ids": batch_result.workflow_instance_ids
if payload.include_instance_ids
else None,
"instance_ids": batch_result.workflow_instance_ids if payload.include_instance_ids else None,
},
)
except Exception as exc: # pragma: no cover - streaming errors
Expand Down Expand Up @@ -724,9 +634,7 @@ async def reset_database() -> ResetResponse:
"""Reset workflow-related tables for a clean slate. Development use only."""
database_url = os.environ.get("WAYMARK_DATABASE_URL")
if not database_url:
return ResetResponse(
success=False, message="WAYMARK_DATABASE_URL not configured"
)
return ResetResponse(success=False, message="WAYMARK_DATABASE_URL not configured")

try:
conn = await asyncpg.connect(database_url)
Expand Down
Loading