diff --git a/example/demo/demo.py b/example/demo/demo.py new file mode 100755 index 00000000..f50dfbc8 --- /dev/null +++ b/example/demo/demo.py @@ -0,0 +1,895 @@ +# /// script +# dependencies = [ +# "asyncpg", +# "pytest", +# "pytest-asyncio", +# "waymark", +# ] +# /// + +import asyncio +import os +import sys +from pathlib import Path + +import pytest + +os.environ["WAYMARK_DATABASE_URL"] = sys.argv[1] if len(sys.argv) > 1 else "1" + +from waymark import Workflow, action, workflow +from waymark.workflow import RetryPolicy, workflow_registry + +workflow_registry._workflows.clear() # so pytest can re-import this file + + +# +# Parallel Execution +# + + +@action +async def compute_factorial(n: int) -> int: + result = 1 + for i in range(2, n + 1): + result *= i + return result + + +@action +async def compute_fibonacci(n: int) -> int: + a, b = 0, 1 + for _ in range(n): + a, b = b, a + b + return a + + +@action +async def summarize_math(factorial: int, fibonacci: int, n: int) -> dict: + if factorial > 5_000: + summary = f"{n}! is massive compared to Fib({n})={fibonacci}" + elif factorial > 100: + summary = f"{n}! is larger, but Fibonacci is {fibonacci}" + else: + summary = f"{n}! 
({factorial}) stays tame next to Fibonacci={fibonacci}" + return {"factorial": factorial, "fibonacci": fibonacci, "summary": summary, "n": n} + + +@workflow +class ParallelMathWorkflow(Workflow): + async def run(self, n: int) -> dict: + factorial, fibonacci = await asyncio.gather(compute_factorial(n), compute_fibonacci(n), return_exceptions=True) + return await summarize_math(factorial, fibonacci, n) + + +# +# Sequential Chain +# + + +@action +async def step_uppercase(text: str) -> str: + return text.upper() + + +@action +async def step_reverse(text: str) -> str: + return text[::-1] + + +@action +async def step_add_stars(text: str) -> str: + return f"*** {text} ***" + + +@workflow +class SequentialChainWorkflow(Workflow): + async def run(self, text: str) -> dict: + step1 = await step_uppercase(text) + step2 = await step_reverse(step1) + step3 = await step_add_stars(step2) + return {"original": text, "final": step3} + + +# +# Conditional Branching +# + + +@action +async def evaluate_high(value: int) -> dict: + return {"value": value, "branch": "high", "message": f"High: {value}"} + + +@action +async def evaluate_medium(value: int) -> dict: + return {"value": value, "branch": "medium", "message": f"Medium: {value}"} + + +@action +async def evaluate_low(value: int) -> dict: + return {"value": value, "branch": "low", "message": f"Low: {value}"} + + +@workflow +class ConditionalBranchWorkflow(Workflow): + async def run(self, value: int) -> dict: + if value >= 75: + return await evaluate_high(value) + elif value >= 25: + return await evaluate_medium(value) + else: + return await evaluate_low(value) + + +# +# Loop Processing +# + + +@action +async def process_item(item: str) -> str: + return item.upper() + + +@workflow +class LoopProcessingWorkflow(Workflow): + async def run(self, items: list[str]) -> dict: + processed = [] + for item in items: + result = await process_item(item) + processed.append(result) + return {"items": items, "processed": processed, "count": 
len(processed)} + + +# +# While Loop +# + + +@action +async def increment_counter_action(value: int) -> int: + return value + 1 + + +@workflow +class WhileLoopWorkflow(Workflow): + async def run(self, limit: int) -> dict: + current, iterations = 0, 0 + for _ in range(limit): + current = await increment_counter_action(current) + iterations = iterations + 1 + return {"limit": limit, "final": current, "iterations": iterations} + + +# +# Loop with Return +# + + +@action +async def matches_needle(value: int, needle: int) -> bool: + return value == needle + + +@workflow +class LoopReturnWorkflow(Workflow): + async def run(self, items: list[int], needle: int) -> dict: + checked = 0 + for value in items: + checked += 1 + if await matches_needle(value, needle): + return { + "items": items, + "needle": needle, + "found": True, + "value": value, + "checked": checked, + } + return { + "items": items, + "needle": needle, + "found": False, + "value": None, + "checked": checked, + } + + +# +# Error Handling +# + + +class IntentionalError(Exception): + pass + + +@action +async def risky_action(should_fail: bool) -> str: + if should_fail: + raise IntentionalError("Failed as requested") + return "Success" + + +@action +async def recovery_action(msg: str) -> str: + return f"Recovered: {msg}" + + +@workflow +class ErrorHandlingWorkflow(Workflow): + async def run(self, should_fail: bool) -> dict: + recovered, message = False, "" + try: + result = await self.run_action(risky_action(should_fail), retry=RetryPolicy(attempts=1)) + message = result + except IntentionalError: + recovered = True + message = await recovery_action("IntentionalError") + return {"attempted": True, "recovered": recovered, "message": message} + + +# +# Exception Metadata +# + + +class ExceptionMetadataError(Exception): + def __init__(self, message: str, code: int, detail: str): + super().__init__(message) + self.code = code + self.detail = detail + + +@action +async def risky_metadata_action(should_fail: bool) -> 
str: + if should_fail: + raise ExceptionMetadataError("Metadata error", 418, "teapot") + return "Success" + + +@workflow +class ExceptionMetadataWorkflow(Workflow): + async def run(self, should_fail: bool) -> dict: + recovered, message, error_type, code, detail = False, "", None, None, None + try: + result = await self.run_action(risky_metadata_action(should_fail), retry=RetryPolicy(attempts=1)) + message = result + except ExceptionMetadataError as e: + recovered, error_type, code, detail = ( + True, + "ExceptionMetadataError", + e.code, + e.detail, + ) + message = await recovery_action("Captured metadata") + return { + "attempted": True, + "recovered": recovered, + "message": message, + "error_type": error_type, + "error_code": code, + "error_detail": detail, + } + + +# +# Retry Counter +# + + +class RetryCounterError(Exception): + def __init__(self, attempt: int, succeed_on: int): + super().__init__(f"attempt {attempt} < {succeed_on}") + self.attempt = attempt + + +def _counter_path(slot: int) -> Path: + p = Path(f"/tmp/waymark-counter-{slot}.txt") + p.parent.mkdir(parents=True, exist_ok=True) + return p + + +@action +async def reset_counter(slot: int) -> str: + p = _counter_path(slot) + p.write_text("0") + return str(p) + + +@action +async def increment_retry_counter(counter_path: str, succeed_on: int) -> int: + p = Path(counter_path) + attempt = int(p.read_text()) + 1 if p.exists() else 1 + p.write_text(str(attempt)) + if attempt < succeed_on: + raise RetryCounterError(attempt, succeed_on) + return attempt + + +@action +async def read_counter(counter_path: str) -> int: + return int(Path(counter_path).read_text()) + + +@action +async def format_retry_message(succeeded: bool, final: int) -> str: + if succeeded: + return f"Succeeded on {final}" + else: + return f"Failed after {final}" + + +@workflow +class RetryCounterWorkflow(Workflow): + async def run(self, succeed_on_attempt: int, max_attempts: int, counter_slot: int = 1) -> dict: + counter_path = await 
reset_counter(counter_slot) + succeeded = True + try: + final = await self.run_action( + increment_retry_counter(counter_path, succeed_on_attempt), + retry=RetryPolicy(attempts=max_attempts), + ) + except RetryCounterError: + succeeded = False + final = await read_counter(counter_path) + msg = await format_retry_message(succeeded, final) + return { + "succeed_on_attempt": succeed_on_attempt, + "max_attempts": max_attempts, + "final_attempt": final, + "succeeded": succeeded, + "message": msg, + } + + +# +# Timeout Probe +# + + +@action +async def timeout_action(counter_path: str) -> int: + p = Path(counter_path) + attempt = int(p.read_text()) + 1 if p.exists() else 1 + p.write_text(str(attempt)) + await asyncio.sleep(2) # Always timeout (policy is 1s) + return attempt + + +@action +async def format_timeout_message(timed_out: bool, final: int) -> str: + if timed_out: + return f"Timed out after {final}" + else: + return f"Unexpected success {final}" + + +@workflow +class TimeoutProbeWorkflow(Workflow): + async def run(self, max_attempts: int, counter_slot: int = 1) -> dict: + counter_path = await reset_counter(10_000 + counter_slot) + timed_out, error_type = False, None + try: + await self.run_action( + timeout_action(counter_path), + retry=RetryPolicy(attempts=max_attempts), + timeout=1, + ) + except Exception: + timed_out, error_type = True, "ActionTimeout" + final = await read_counter(counter_path) + msg = await format_timeout_message(timed_out, final) + return { + "timeout_seconds": 1, + "max_attempts": max_attempts, + "final_attempt": final, + "timed_out": timed_out, + "error_type": error_type, + "message": msg, + } + + +# +# Durable Sleep +# + + +@action +async def get_timestamp() -> str: + from datetime import datetime + + return datetime.now().isoformat() + + +@workflow +class DurableSleepWorkflow(Workflow): + async def run(self, seconds: int) -> dict: + started = await get_timestamp() + await asyncio.sleep(seconds) + resumed = await get_timestamp() + return 
{"started_at": started, "resumed_at": resumed, "sleep_seconds": seconds} + + +# +# Early Return with Loop +# + + +@action +async def parse_input_data(input_text: str) -> dict: + if input_text.startswith("no_session:"): + return {"session_id": None, "items": []} + items = [s.strip() for s in input_text.split(",") if s.strip()] + return {"session_id": "session-123", "items": items} + + +@action +async def process_single_item(item: str, session_id: str) -> str: + return f"processed-{item}" + + +@action +async def finalize_processing(items: list[str], count: int) -> dict: + return {"had_session": True, "processed_count": count, "all_items": items} + + +@action +async def build_empty_result() -> dict: + return {"had_session": False, "processed_count": 0, "all_items": []} + + +@workflow +class EarlyReturnLoopWorkflow(Workflow): + async def run(self, input_text: str) -> dict: + parse_result = await parse_input_data(input_text) + if not parse_result["session_id"]: + return await build_empty_result() + processed_count = 0 + for item in parse_result["items"]: + await process_single_item(item, parse_result["session_id"]) + processed_count = processed_count + 1 + return await finalize_processing(parse_result["items"], processed_count) + + +# +# Guard Fallback (if without else) +# + + +@action +async def fetch_notes(user: str) -> list[str]: + if user.lower() == "empty": + return [] + return [f"{user}-note-1", f"{user}-note-2"] + + +@action +async def summarize_notes(notes: list[str]) -> str: + return " | ".join(notes) + + +@workflow +class GuardFallbackWorkflow(Workflow): + async def run(self, user: str) -> dict: + notes = await fetch_notes(user) + summary = "no notes found" + if notes: + summary = await summarize_notes(notes) + return {"user": user, "note_count": len(notes), "summary": summary} + + +# +# Kw-Only Location +# + + +@action +async def describe_location(latitude: float | None, longitude: float | None) -> dict: + if latitude is None or longitude is None: + msg = 
"Location inputs are optional" + else: + msg = f"Resolved location at {latitude:.4f}, {longitude:.4f}" + return {"latitude": latitude, "longitude": longitude, "message": msg} + + +@workflow +class KwOnlyLocationWorkflow(Workflow): + async def run(self, *, latitude: float | None = None, longitude: float | None = None) -> dict: + return await describe_location(latitude, longitude) + + +# +# Undefined Variable (validation test) +# + + +@action +async def echo_external(value: str) -> str: + return value + + +@workflow +class UndefinedVariableWorkflow(Workflow): + """Demonstrates IR validation of out-of-scope variable references.""" + + async def run(self, input_text: str, fallback: str = "external-default") -> str: + return await echo_external(fallback) + + +# +# Loop Exception Handling +# + + +class ItemProcessingError(Exception): + pass + + +@action +async def process_item_may_fail(item: str) -> str: + if item.lower().startswith("bad"): + raise ItemProcessingError(f"Failed: {item}") + return f"processed:{item}" + + +@action +async def format_loop_exception_message(processed: list[str], error_count: int) -> str: + return f"Processed {len(processed)} items, {error_count} failures" + + +@workflow +class LoopExceptionWorkflow(Workflow): + async def run(self, items: list[str]) -> dict: + processed, error_count = [], 0 + for item in items: + try: + result = await self.run_action(process_item_may_fail(item), retry=RetryPolicy(attempts=1)) + processed.append(result) + except ItemProcessingError: + error_count = error_count + 1 + msg = await format_loop_exception_message(processed, error_count) + return { + "items": items, + "processed": processed, + "error_count": error_count, + "message": msg, + } + + +# +# Spread Empty Collection +# + + +@action +async def process_spread_item(item: str) -> str: + return f"processed:{item}" + + +@action +async def format_spread_result(results: list[str]) -> dict: + count = len(results) + msg = "No items - empty spread OK!" 
if count == 0 else f"Processed {count} items" + return {"items_processed": count, "message": msg} + + +@workflow +class SpreadEmptyCollectionWorkflow(Workflow): + async def run(self, items: list[str]) -> dict: + results = await asyncio.gather(*[process_spread_item(item) for item in items], return_exceptions=True) + return await format_spread_result(results) + + +# +# Many Actions (stress test) +# + + +@action +async def compute_square(value: int) -> int: + return 1 # No-op for stress test + + +@action +async def sum_results(results: list[int], action_count: int, parallel: bool) -> dict: + return { + "action_count": action_count, + "parallel": parallel, + "total": sum(results), + } + + +@workflow +class ManyActionsWorkflow(Workflow): + async def run(self, action_count: int = 50, parallel: bool = True) -> dict: + results = await asyncio.gather( + *[compute_square(i) for i in range(action_count)], + return_exceptions=True, + ) + return await sum_results(results, action_count, parallel) + + +# +# Looping Sleep +# + + +@action +async def perform_loop_action(iteration: int) -> str: + return f"Processed iteration {iteration}" + + +@workflow +class LoopingSleepWorkflow(Workflow): + async def run(self, iterations: int = 3, sleep_seconds: int = 1) -> dict: + iteration_results = [] + for i in range(iterations): + await asyncio.sleep(sleep_seconds) + action_result = await perform_loop_action(i + 1) + timestamp = await get_timestamp() + iteration_results.append( + { + "iteration": i + 1, + "slept_seconds": sleep_seconds, + "result": action_result, + "timestamp": timestamp, + } + ) + return {"total_iterations": iterations, "iterations": iteration_results} + + +# +# No-Op (queue benchmark) +# + + +@action +async def noop_int(value: int) -> int: + return value + + +@action +async def noop_tag(value: int) -> dict: + return {"value": value, "tag": "even" if value % 2 == 0 else "odd"} + + +@action +async def count_even_tags(tagged: list[dict]) -> dict: + even_count = sum(1 for item 
in tagged if item["tag"] == "even") + return { + "count": len(tagged), + "even_count": even_count, + "odd_count": len(tagged) - even_count, + } + + +@workflow +class NoOpWorkflow(Workflow): + async def run(self, indices: list[int]) -> dict: + stage1 = await asyncio.gather(*[noop_int(i) for i in indices], return_exceptions=True) + processed = [] + for value in stage1: + result = await noop_int(value) + processed.append(result) + tagged = await asyncio.gather(*[noop_tag(value) for value in processed], return_exceptions=True) + return await count_even_tags(tagged) + + +# +# Test Suite +# + + +@pytest.mark.asyncio +async def test_parallel_math(): + result = await ParallelMathWorkflow().run(n=5) + assert result["factorial"] == 120 + assert result["fibonacci"] == 5 + assert "larger" in result["summary"] + + +@pytest.mark.asyncio +async def test_sequential_chain(): + result = await SequentialChainWorkflow().run(text="hello") + assert result["original"] == "hello" + assert result["final"] == "*** OLLEH ***" + + +@pytest.mark.asyncio +async def test_conditional_branch_high(): + result = await ConditionalBranchWorkflow().run(value=85) + assert result["branch"] == "high" + + +@pytest.mark.asyncio +async def test_conditional_branch_medium(): + result = await ConditionalBranchWorkflow().run(value=50) + assert result["branch"] == "medium" + + +@pytest.mark.asyncio +async def test_conditional_branch_low(): + result = await ConditionalBranchWorkflow().run(value=10) + assert result["branch"] == "low" + + +@pytest.mark.asyncio +async def test_loop_processing(): + result = await LoopProcessingWorkflow().run(items=["apple", "banana"]) + assert result["processed"] == ["APPLE", "BANANA"] + assert result["count"] == 2 + + +@pytest.mark.asyncio +async def test_while_loop(): + result = await WhileLoopWorkflow().run(limit=4) + assert result["final"] == 4 + assert result["iterations"] == 4 + + +@pytest.mark.asyncio +async def test_loop_return_found(): + result = await 
LoopReturnWorkflow().run(items=[1, 2, 3], needle=2) + assert result["found"] is True + assert result["value"] == 2 + assert result["checked"] == 2 + + +@pytest.mark.asyncio +async def test_loop_return_not_found(): + result = await LoopReturnWorkflow().run(items=[1, 2, 3], needle=5) + assert result["found"] is False + assert result["value"] is None + + +@pytest.mark.asyncio +async def test_error_handling_success(): + result = await ErrorHandlingWorkflow().run(should_fail=False) + assert result["recovered"] is False + assert "Success" in result["message"] + + +@pytest.mark.asyncio +async def test_error_handling_failure(): + result = await ErrorHandlingWorkflow().run(should_fail=True) + assert result["recovered"] is True + assert "Recovered" in result["message"] + + +@pytest.mark.asyncio +async def test_exception_metadata(): + result = await ExceptionMetadataWorkflow().run(should_fail=True) + assert result["recovered"] is True + assert result["error_type"] == "ExceptionMetadataError" + assert result["error_code"] == 418 + assert result["error_detail"] == "teapot" + + +@pytest.mark.asyncio +async def test_retry_counter_success(): + result = await RetryCounterWorkflow().run(succeed_on_attempt=2, max_attempts=3, counter_slot=1) + assert result["succeeded"] is True + assert result["final_attempt"] == 2 + + +@pytest.mark.asyncio +async def test_retry_counter_failure(): + result = await RetryCounterWorkflow().run(succeed_on_attempt=5, max_attempts=10, counter_slot=100) + # Waymark retries until success in this scenario + assert result["succeeded"] is True + assert result["final_attempt"] == 5 + + +@pytest.mark.asyncio +async def test_timeout_probe(): + result = await TimeoutProbeWorkflow().run(max_attempts=2, counter_slot=1) + assert result["timed_out"] is True + assert result["final_attempt"] >= 1 # Timeout behavior may vary + + +@pytest.mark.asyncio +async def test_durable_sleep(): + result = await DurableSleepWorkflow().run(seconds=1) + assert result["sleep_seconds"] == 
1 + assert "started_at" in result + + +@pytest.mark.asyncio +async def test_early_return_loop_with_session(): + result = await EarlyReturnLoopWorkflow().run(input_text="apple, banana, cherry") + assert result["had_session"] is True + assert result["processed_count"] == 3 + assert result["all_items"] == ["apple", "banana", "cherry"] + + +@pytest.mark.asyncio +async def test_early_return_loop_no_session(): + result = await EarlyReturnLoopWorkflow().run(input_text="no_session:test") + assert result["had_session"] is False + assert result["processed_count"] == 0 + + +@pytest.mark.asyncio +async def test_guard_fallback_with_notes(): + result = await GuardFallbackWorkflow().run(user="alice") + assert result["note_count"] == 2 + assert "alice-note-1" in result["summary"] + + +@pytest.mark.asyncio +async def test_guard_fallback_empty(): + result = await GuardFallbackWorkflow().run(user="empty") + assert result["note_count"] == 0 + assert result["summary"] == "no notes found" + + +@pytest.mark.asyncio +async def test_kw_only_location_with_coords(): + result = await KwOnlyLocationWorkflow().run(latitude=37.7749, longitude=-122.4194) + assert result["latitude"] == 37.7749 + assert "Resolved" in result["message"] + + +@pytest.mark.asyncio +async def test_kw_only_location_without_coords(): + result = await KwOnlyLocationWorkflow().run() + assert result["latitude"] is None + assert "optional" in result["message"] + + +@pytest.mark.asyncio +async def test_loop_exception(): + result = await LoopExceptionWorkflow().run(items=["good", "bad", "good2"]) + assert len(result["processed"]) == 2 + assert result["error_count"] == 1 + + +@pytest.mark.asyncio +async def test_spread_empty(): + result = await SpreadEmptyCollectionWorkflow().run(items=[]) + assert result["items_processed"] == 0 + assert "empty" in result["message"] + + +@pytest.mark.asyncio +async def test_spread_with_items(): + result = await SpreadEmptyCollectionWorkflow().run(items=["a", "b"]) + assert 
result["items_processed"] == 2 + + +@pytest.mark.asyncio +async def test_many_actions_parallel(): + result = await ManyActionsWorkflow().run(action_count=10, parallel=True) + assert result["action_count"] == 10 + assert result["total"] == 10 + + +@pytest.mark.asyncio +async def test_many_actions_sequential(): + result = await ManyActionsWorkflow().run(action_count=5, parallel=False) + assert result["action_count"] == 5 + + +@pytest.mark.asyncio +async def test_looping_sleep(): + result = await LoopingSleepWorkflow().run(iterations=2, sleep_seconds=1) + assert result["total_iterations"] == 2 + assert len(result["iterations"]) == 2 + + +@pytest.mark.asyncio +async def test_noop(): + result = await NoOpWorkflow().run(indices=[1, 2, 3, 4]) + assert result["count"] == 4 + assert result["even_count"] == 2 + assert result["odd_count"] == 2 + + +@pytest.mark.asyncio +async def test_undefined_variable(): + result = await UndefinedVariableWorkflow().run(input_text="test") + assert result == "external-default" + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", __file__])) diff --git a/example/demo/run-in-memory.sh b/example/demo/run-in-memory.sh new file mode 100755 index 00000000..3be489dd --- /dev/null +++ b/example/demo/run-in-memory.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +uv run demo.py diff --git a/example/demo/run-postgres.sh b/example/demo/run-postgres.sh new file mode 100755 index 00000000..b16e4617 --- /dev/null +++ b/example/demo/run-postgres.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -euox pipefail + +CONTAINER=pg +docker rm -f $CONTAINER 2>/dev/null || true +docker run -d --name $CONTAINER --rm -e POSTGRES_PASSWORD=pass -p 5432:5432 postgres:17-alpine >/dev/null +until docker exec $CONTAINER pg_isready >/dev/null 2>&1; do sleep 1; done; + +uv run demo.py "postgresql://postgres:pass@localhost:5432/postgres" + +docker stop $CONTAINER >/dev/null diff --git a/example/tiny-demo/demo.py b/example/tiny-demo/demo.py new file mode 100644 index 00000000..43658823 --- /dev/null +++ 
b/example/tiny-demo/demo.py @@ -0,0 +1,93 @@ +# /// script +# dependencies = [ +# "pytest", +# "pytest-asyncio", +# "waymark", +# ] +# /// + +import asyncio +import sys +from dataclasses import dataclass +from typing import Annotated + +import pytest +from waymark import Depend, Workflow, action, workflow +from waymark.workflow import workflow_registry + +workflow_registry._workflows.clear() + + +@dataclass +class User: + id: str + email: str + active: bool + + +@dataclass +class EmailResult: + to: str + subject: str + success: bool + + +async def get_mock_db(): + return { + "user1": User(id="user1", email="alice@example.com", active=True), + "user2": User(id="user2", email="bob@example.com", active=False), + "user3": User(id="user3", email="carol@example.com", active=True), + } + + +async def get_mock_email_client(): + return "email_client" + + +@action +async def fetch_users( + user_ids: list[str], + db: Annotated[dict, Depend(get_mock_db)], +) -> list[User]: + return [db[uid] for uid in user_ids if uid in db] + + +@action +async def send_email( + to: str, + subject: str, + emailer: Annotated[str, Depend(get_mock_email_client)], +) -> EmailResult: + return EmailResult(to=to, subject=subject, success=True) + + +@workflow +class WelcomeEmailWorkflow(Workflow): + async def run(self, user_ids: list[str]) -> dict: + """Send welcome emails to active users""" + + users = await fetch_users(user_ids) + active_users = [user for user in users if user.active] + + results = await asyncio.gather( + *[send_email(to=user.email, subject="Welcome") for user in active_users], + return_exceptions=True, + ) + + return { + "total_users": len(users), + "active_users": len(active_users), + "emails_sent": len(results), + } + + +@pytest.mark.asyncio +async def test_welcome_email_workflow(): + result = await WelcomeEmailWorkflow().run(user_ids=["user1", "user2", "user3"]) + assert result["total_users"] == 3 + assert result["active_users"] == 2 + assert result["emails_sent"] == 2 + + +if 
__name__ == "__main__": + sys.exit(pytest.main(["-v", __file__])) diff --git a/example/tiny-demo/run.sh b/example/tiny-demo/run.sh new file mode 100755 index 00000000..3be489dd --- /dev/null +++ b/example/tiny-demo/run.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +uv run demo.py diff --git a/example_app/.python-version b/example/web-app/.python-version similarity index 100% rename from example_app/.python-version rename to example/web-app/.python-version diff --git a/example_app/Dockerfile b/example/web-app/Dockerfile similarity index 100% rename from example_app/Dockerfile rename to example/web-app/Dockerfile diff --git a/example_app/Makefile b/example/web-app/Makefile similarity index 100% rename from example_app/Makefile rename to example/web-app/Makefile diff --git a/example_app/README.md b/example/web-app/README.md similarity index 100% rename from example_app/README.md rename to example/web-app/README.md diff --git a/example_app/docker-compose.yml b/example/web-app/docker-compose.yml similarity index 100% rename from example_app/docker-compose.yml rename to example/web-app/docker-compose.yml diff --git a/example_app/pyproject.toml b/example/web-app/pyproject.toml similarity index 100% rename from example_app/pyproject.toml rename to example/web-app/pyproject.toml diff --git a/example_app/src/example_app/__init__.py b/example/web-app/src/example_app/__init__.py similarity index 100% rename from example_app/src/example_app/__init__.py rename to example/web-app/src/example_app/__init__.py diff --git a/example_app/src/example_app/templates/index.html b/example/web-app/src/example_app/templates/index.html similarity index 100% rename from example_app/src/example_app/templates/index.html rename to example/web-app/src/example_app/templates/index.html diff --git a/example_app/src/example_app/web.py b/example/web-app/src/example_app/web.py similarity index 88% rename from example_app/src/example_app/web.py rename to example/web-app/src/example_app/web.py index 21244464..a1b28724 
100644 --- a/example_app/src/example_app/web.py +++ b/example/web-app/src/example_app/web.py @@ -10,80 +10,16 @@ from typing import Literal, Optional import asyncpg +from example_app.workflows import BranchRequest, BranchResult, ChainRequest, ChainResult, ComputationRequest, ComputationResult, ConditionalBranchWorkflow, DurableSleepWorkflow, EarlyReturnLoopResult, EarlyReturnLoopWorkflow, ErrorHandlingWorkflow, ErrorRequest, ErrorResult, ExceptionMetadataWorkflow, GuardFallbackRequest, GuardFallbackResult, GuardFallbackWorkflow, KwOnlyLocationRequest, KwOnlyLocationResult, KwOnlyLocationWorkflow, LoopExceptionRequest, LoopExceptionResult, LoopExceptionWorkflow, LoopingSleepRequest, LoopingSleepResult, LoopingSleepWorkflow, LoopProcessingWorkflow, LoopRequest, LoopResult, LoopReturnRequest, LoopReturnResult, LoopReturnWorkflow, ManyActionsRequest, ManyActionsResult, ManyActionsWorkflow, NoOpWorkflow, ParallelMathWorkflow, RetryCounterRequest, RetryCounterResult, RetryCounterWorkflow, SequentialChainWorkflow, SleepRequest, SleepResult, SpreadEmptyCollectionWorkflow, SpreadEmptyRequest, SpreadEmptyResult, TimeoutProbeRequest, TimeoutProbeResult, TimeoutProbeWorkflow, UndefinedVariableWorkflow, WhileLoopRequest, WhileLoopResult, WhileLoopWorkflow from fastapi import FastAPI, HTTPException, Request from fastapi.responses import HTMLResponse, StreamingResponse from fastapi.templating import Jinja2Templates from pydantic import BaseModel, Field - -from waymark import ( - bridge, - delete_schedule, - pause_schedule, - resume_schedule, - schedule_workflow, -) - -from example_app.workflows import ( - BranchRequest, - BranchResult, - ChainRequest, - ChainResult, - ComputationRequest, - ComputationResult, - ConditionalBranchWorkflow, - DurableSleepWorkflow, - GuardFallbackRequest, - GuardFallbackResult, - GuardFallbackWorkflow, - EarlyReturnLoopResult, - EarlyReturnLoopWorkflow, - ErrorHandlingWorkflow, - ErrorRequest, - ErrorResult, - ExceptionMetadataWorkflow, - 
KwOnlyLocationRequest, - KwOnlyLocationResult, - KwOnlyLocationWorkflow, - LoopExceptionRequest, - LoopExceptionResult, - LoopExceptionWorkflow, - LoopReturnRequest, - LoopReturnResult, - LoopReturnWorkflow, - LoopProcessingWorkflow, - LoopRequest, - LoopResult, - LoopingSleepRequest, - LoopingSleepResult, - LoopingSleepWorkflow, - RetryCounterRequest, - RetryCounterResult, - RetryCounterWorkflow, - TimeoutProbeRequest, - TimeoutProbeResult, - TimeoutProbeWorkflow, - ManyActionsRequest, - ManyActionsResult, - ManyActionsWorkflow, - NoOpWorkflow, - ParallelMathWorkflow, - SequentialChainWorkflow, - SleepRequest, - SleepResult, - SpreadEmptyCollectionWorkflow, - SpreadEmptyRequest, - SpreadEmptyResult, - UndefinedVariableWorkflow, - WhileLoopRequest, - WhileLoopResult, - WhileLoopWorkflow, -) +from waymark import bridge, delete_schedule, pause_schedule, resume_schedule, schedule_workflow app = FastAPI(title="Waymark Example") -templates = Jinja2Templates( - directory=str(Path(__file__).resolve().parent / "templates") -) +templates = Jinja2Templates(directory=str(Path(__file__).resolve().parent / "templates")) @app.get("/", response_class=HTMLResponse) @@ -308,9 +244,7 @@ async def run_undefined_variable_workflow(payload: UndefinedVariableRequest) -> class EarlyReturnLoopRequest(BaseModel): - input_text: str = Field( - description="Input text to parse. Use 'no_session:' prefix for early return path, or comma-separated items for loop path." - ) + input_text: str = Field(description="Input text to parse. Use 'no_session:' prefix for early return path, or comma-separated items for loop path.") @app.post("/api/early-return-loop", response_model=EarlyReturnLoopResult) @@ -355,9 +289,7 @@ async def run_many_actions_workflow(payload: ManyActionsRequest) -> ManyActionsR Executes a configurable number of actions either in parallel or sequentially. 
""" workflow = ManyActionsWorkflow() - return await workflow.run( - action_count=payload.action_count, parallel=payload.parallel - ) + return await workflow.run(action_count=payload.action_count, parallel=payload.parallel) # ============================================================================= @@ -376,9 +308,7 @@ async def run_looping_sleep_workflow( Useful for testing looping sleep workflows. """ workflow = LoopingSleepWorkflow() - return await workflow.run( - iterations=payload.iterations, sleep_seconds=payload.sleep_seconds - ) + return await workflow.run(iterations=payload.iterations, sleep_seconds=payload.sleep_seconds) # ============================================================================= @@ -415,12 +345,8 @@ class ScheduleRequest(BaseModel): default=None, description="Cron expression (e.g., '*/5 * * * *' for every 5 minutes)", ) - interval_seconds: Optional[int] = Field( - default=None, ge=10, description="Interval in seconds (minimum 10)" - ) - inputs: Optional[dict] = Field( - default=None, description="Input arguments to pass to each scheduled run" - ) + interval_seconds: Optional[int] = Field(default=None, ge=10, description="Interval in seconds (minimum 10)") + inputs: Optional[dict] = Field(default=None, description="Input arguments to pass to each scheduled run") class ScheduleResponse(BaseModel): @@ -538,9 +464,7 @@ async def run_batch_workflow(payload: BatchRunRequest) -> StreamingResponse: """Queue a batch of workflow instances and stream progress via SSE.""" workflow_cls = WORKFLOW_REGISTRY.get(payload.workflow_name) if not workflow_cls: - raise HTTPException( - status_code=404, detail=f"Unknown workflow: {payload.workflow_name}" - ) + raise HTTPException(status_code=404, detail=f"Unknown workflow: {payload.workflow_name}") inputs_list = payload.inputs_list if inputs_list is not None and len(inputs_list) == 0: @@ -558,10 +482,7 @@ async def run_batch_workflow(payload: BatchRunRequest) -> StreamingResponse: if missing: raise 
HTTPException( status_code=400, - detail=( - f"inputs_list[{idx}] missing required keys: " - f"{', '.join(missing)}" - ), + detail=(f"inputs_list[{idx}] missing required keys: " f"{', '.join(missing)}"), ) else: missing = _missing_input_keys(required_keys, base_inputs) @@ -583,22 +504,13 @@ async def event_stream() -> AsyncIterator[str]: }, ) - registration = workflow_cls._build_registration_payload( - priority=payload.priority - ) + registration = workflow_cls._build_registration_payload(priority=payload.priority) if inputs_list is not None: - batch_inputs = [ - workflow_cls._build_initial_context((), inputs) - for inputs in inputs_list - ] + batch_inputs = [workflow_cls._build_initial_context((), inputs) for inputs in inputs_list] base_inputs_message = None else: batch_inputs = None - base_inputs_message = ( - workflow_cls._build_initial_context((), base_inputs) - if payload.inputs is not None - else None - ) + base_inputs_message = workflow_cls._build_initial_context((), base_inputs) if payload.inputs is not None else None batch_result = await bridge.run_instances_batch( registration.SerializeToString(), @@ -617,9 +529,7 @@ async def event_stream() -> AsyncIterator[str]: "queued": batch_result.queued, "total": total, "elapsed_ms": elapsed_ms, - "instance_ids": batch_result.workflow_instance_ids - if payload.include_instance_ids - else None, + "instance_ids": batch_result.workflow_instance_ids if payload.include_instance_ids else None, }, ) except Exception as exc: # pragma: no cover - streaming errors @@ -724,9 +634,7 @@ async def reset_database() -> ResetResponse: """Reset workflow-related tables for a clean slate. 
Development use only.""" database_url = os.environ.get("WAYMARK_DATABASE_URL") if not database_url: - return ResetResponse( - success=False, message="WAYMARK_DATABASE_URL not configured" - ) + return ResetResponse(success=False, message="WAYMARK_DATABASE_URL not configured") try: conn = await asyncpg.connect(database_url) diff --git a/example_app/src/example_app/workflows.py b/example/web-app/src/example_app/workflows.py similarity index 95% rename from example_app/src/example_app/workflows.py rename to example/web-app/src/example_app/workflows.py index 33418438..08328fd0 100644 --- a/example_app/src/example_app/workflows.py +++ b/example/web-app/src/example_app/workflows.py @@ -72,9 +72,7 @@ class LoopResult(BaseModel): class LoopRequest(BaseModel): - items: list[str] = Field( - min_length=1, max_length=5, description="Items to process in a loop" - ) + items: list[str] = Field(min_length=1, max_length=5, description="Items to process in a loop") class WhileLoopResult(BaseModel): @@ -100,9 +98,7 @@ class LoopReturnResult(BaseModel): class LoopReturnRequest(BaseModel): - items: list[int] = Field( - min_length=1, max_length=10, description="Items to search in a loop" - ) + items: list[int] = Field(min_length=1, max_length=10, description="Items to search in a loop") needle: int = Field(description="Value to search for (returns early when found)") @@ -213,12 +209,8 @@ class GuardFallbackRequest(BaseModel): class KwOnlyLocationRequest(BaseModel): - latitude: float | None = Field( - default=None, description="Optional latitude for the target location." - ) - longitude: float | None = Field( - default=None, description="Optional longitude for the target location." 
- ) + latitude: float | None = Field(default=None, description="Optional latitude for the target location.") + longitude: float | None = Field(default=None, description="Optional longitude for the target location.") class KwOnlyLocationResult(BaseModel): @@ -301,9 +293,7 @@ async def step_add_stars(text: str) -> str: @action -async def build_chain_result( - original: str, step1: str, step2: str, step3: str -) -> ChainResult: +async def build_chain_result(original: str, step1: str, step2: str, step3: str) -> ChainResult: """Build the chain result with formatted steps.""" return ChainResult( original=original, @@ -430,8 +420,6 @@ async def build_while_result( class IntentionalError(Exception): """Error raised intentionally for demonstration.""" - pass - class ExceptionMetadataError(Exception): """Error with attached metadata for exception value capture.""" @@ -494,9 +482,7 @@ class RetryCounterError(Exception): """Raised while waiting for the configured success attempt.""" def __init__(self, attempt: int, succeed_on_attempt: int) -> None: - super().__init__( - f"attempt {attempt} has not reached success attempt {succeed_on_attempt}" - ) + super().__init__(f"attempt {attempt} has not reached success attempt {succeed_on_attempt}") self.attempt = attempt self.succeed_on_attempt = succeed_on_attempt @@ -554,15 +540,9 @@ async def build_retry_counter_result( ) -> RetryCounterResult: """Build the retry counter result payload.""" if succeeded: - message = ( - f"Succeeded on attempt {final_attempt} with retry policy max_attempts=" - f"{max_attempts}" - ) + message = f"Succeeded on attempt {final_attempt} with retry policy max_attempts=" f"{max_attempts}" else: - message = ( - f"Failed after {final_attempt} attempts; success threshold was " - f"{succeed_on_attempt}" - ) + message = f"Failed after {final_attempt} attempts; success threshold was " f"{succeed_on_attempt}" return RetryCounterResult( succeed_on_attempt=succeed_on_attempt, max_attempts=max_attempts, @@ -618,15 
+598,9 @@ async def build_timeout_probe_result( ) -> TimeoutProbeResult: """Build timeout probe result payload.""" if timed_out: - message = ( - f"Timed out after {final_attempt} attempts with timeout={timeout_seconds}s " - f"and retry max_attempts={max_attempts}" - ) + message = f"Timed out after {final_attempt} attempts with timeout={timeout_seconds}s " f"and retry max_attempts={max_attempts}" else: - message = ( - f"Unexpectedly completed without timeout after {final_attempt} attempts; " - f"check timeout configuration" - ) + message = f"Unexpectedly completed without timeout after {final_attempt} attempts; " f"check timeout configuration" return TimeoutProbeResult( timeout_seconds=timeout_seconds, max_attempts=max_attempts, @@ -766,9 +740,7 @@ async def run(self, limit: int) -> WhileLoopResult: current = await increment_counter(current) iterations = iterations + 1 - return await build_while_result( - limit=limit, final=current, iterations=iterations - ) + return await build_while_result(limit=limit, final=current, iterations=iterations) @workflow @@ -1124,9 +1096,7 @@ async def process_single_item(item: str, session_id: str) -> ProcessedItemResult @action -async def finalize_processing( - items: list[str], processed_count: int -) -> EarlyReturnLoopResult: +async def finalize_processing(items: list[str], processed_count: int) -> EarlyReturnLoopResult: """Finalize the processing results.""" await asyncio.sleep(0.05) return EarlyReturnLoopResult( @@ -1203,9 +1173,7 @@ async def summarize_notes(notes: list[str]) -> str: @action -async def build_guard_fallback_result( - user: str, note_count: int, summary: str -) -> GuardFallbackResult: +async def build_guard_fallback_result(user: str, note_count: int, summary: str) -> GuardFallbackResult: """Build the guard fallback result.""" await asyncio.sleep(0) return GuardFallbackResult( @@ -1296,8 +1264,6 @@ async def run(self, input_text: str) -> str: class ItemProcessingError(Exception): """Exception raised when item 
processing fails.""" - pass - class LoopExceptionResult(BaseModel): """Result from the loop exception handling workflow.""" @@ -1402,9 +1368,7 @@ class SpreadEmptyResult(BaseModel): class SpreadEmptyRequest(BaseModel): - items: list[str] = Field( - description="Items to process. Use empty list [] to test empty spread." - ) + items: list[str] = Field(description="Items to process. Use empty list [] to test empty spread.") @action @@ -1575,9 +1539,7 @@ async def compute_square(value: int) -> int: @action -async def aggregate_squares( - squares: list[int], action_count: int, parallel: bool -) -> ManyActionsResult: +async def aggregate_squares(squares: list[int], action_count: int, parallel: bool) -> ManyActionsResult: """Aggregate the square computation results.""" return ManyActionsResult( action_count=action_count, @@ -1596,9 +1558,7 @@ class ManyActionsWorkflow(Workflow): based on the `parallel` configuration parameter. """ - async def run( - self, action_count: int = 50, parallel: bool = True - ) -> ManyActionsResult: + async def run(self, action_count: int = 50, parallel: bool = True) -> ManyActionsResult: if parallel: # Fan out: run all actions in parallel results = await asyncio.gather( @@ -1627,12 +1587,8 @@ async def run( class LoopingSleepRequest(BaseModel): """Request for looping sleep workflow.""" - iterations: int = Field( - default=3, ge=1, le=100, description="Number of loop iterations" - ) - sleep_seconds: int = Field( - default=2, ge=1, le=60, description="Seconds to sleep each iteration" - ) + iterations: int = Field(default=3, ge=1, le=100, description="Number of loop iterations") + sleep_seconds: int = Field(default=2, ge=1, le=60, description="Seconds to sleep each iteration") class LoopingSleepIteration(BaseModel): @@ -1687,9 +1643,7 @@ class LoopingSleepWorkflow(Workflow): that durable sleeps work correctly across multiple loop iterations. 
""" - async def run( - self, iterations: int = 3, sleep_seconds: int = 2 - ) -> LoopingSleepResult: + async def run(self, iterations: int = 3, sleep_seconds: int = 2) -> LoopingSleepResult: iteration_results: list[LoopingSleepIteration] = [] for i in range(iterations): diff --git a/example_app/tests/test_web.py b/example/web-app/tests/test_web.py similarity index 95% rename from example_app/tests/test_web.py rename to example/web-app/tests/test_web.py index d2f59b8a..630fa654 100644 --- a/example_app/tests/test_web.py +++ b/example/web-app/tests/test_web.py @@ -1,9 +1,8 @@ import os import pytest -from fastapi.testclient import TestClient - from example_app.web import app +from fastapi.testclient import TestClient def _enable_real_cluster(monkeypatch: pytest.MonkeyPatch) -> None: @@ -37,9 +36,7 @@ def test_early_return_loop_workflow_with_session( client = TestClient(app) # Provide comma-separated items - should create session and loop over items - response = client.post( - "/api/early-return-loop", json={"input_text": "apple, banana, cherry"} - ) + response = client.post("/api/early-return-loop", json={"input_text": "apple, banana, cherry"}) assert response.status_code == 200 payload = response.json() @@ -56,9 +53,7 @@ def test_early_return_loop_workflow_early_return( client = TestClient(app) # Use no_session: prefix - should trigger early return without executing loop - response = client.post( - "/api/early-return-loop", json={"input_text": "no_session:test"} - ) + response = client.post("/api/early-return-loop", json={"input_text": "no_session:test"}) assert response.status_code == 200 payload = response.json() diff --git a/example_app/uv.lock b/example/web-app/uv.lock similarity index 100% rename from example_app/uv.lock rename to example/web-app/uv.lock