From 26d3966d343aa45115465a8015baff8fb519296c Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 15:23:03 +0200 Subject: [PATCH 01/23] split test suites by execution level --- .github/workflows/ci.yml | 57 +- pyproject.toml | 5 + tests/conftest.py | 149 ++ tests/e2e/__init__.py | 1 + tests/{ => e2e}/test_batch_evaluation.py | 2 +- tests/{ => e2e}/test_core_sdk.py | 4 +- tests/{ => e2e}/test_datasets.py | 2 +- tests/{ => e2e}/test_decorators.py | 2 +- tests/{ => e2e}/test_experiments.py | 2 +- tests/e2e/test_media.py | 58 + tests/e2e/test_prompt.py | 697 ++++++++ tests/live_provider/__init__.py | 1 + tests/{ => live_provider}/test_langchain.py | 2 +- .../test_langchain_integration.py | 4 +- tests/{ => live_provider}/test_openai.py | 2 +- tests/live_provider/test_prompt.py | 87 + tests/support/__init__.py | 1 + tests/{ => support}/api_wrapper.py | 0 tests/{ => support}/utils.py | 0 tests/test_prompt.py | 1524 ----------------- tests/unit/__init__.py | 1 + .../test_additional_headers_simple.py | 0 tests/{ => unit}/test_error_logging.py | 0 tests/{ => unit}/test_error_parsing.py | 0 tests/{ => unit}/test_initialization.py | 0 tests/{ => unit}/test_json.py | 0 tests/unit/test_langchain.py | 168 ++ tests/{ => unit}/test_logger.py | 0 tests/{ => unit}/test_media.py | 62 - tests/{ => unit}/test_media_manager.py | 0 tests/unit/test_openai.py | 238 +++ .../test_openai_prompt_extraction.py | 0 tests/{ => unit}/test_otel.py | 0 tests/{ => unit}/test_parse_usage_model.py | 0 tests/unit/test_prompt.py | 673 ++++++++ tests/{ => unit}/test_prompt_atexit.py | 0 tests/{ => unit}/test_prompt_compilation.py | 0 tests/{ => unit}/test_propagate_attributes.py | 2 +- tests/{ => unit}/test_resource_manager.py | 6 +- tests/{ => unit}/test_serializer.py | 0 tests/{ => unit}/test_span_filter.py | 0 tests/{ => unit}/test_utils.py | 0 tests/{ => unit}/test_version.py | 0 43 files changed, 2137 insertions(+), 1613 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/e2e/__init__.py rename tests/{ => e2e}/test_batch_evaluation.py (99%) rename tests/{ => e2e}/test_core_sdk.py (99%) rename tests/{ => e2e}/test_datasets.py (99%) rename tests/{ => e2e}/test_decorators.py (99%) rename tests/{ => e2e}/test_experiments.py (99%) create mode 100644 tests/e2e/test_media.py create mode 100644 tests/e2e/test_prompt.py create mode 100644 tests/live_provider/__init__.py rename tests/{ => live_provider}/test_langchain.py (99%) rename tests/{ => live_provider}/test_langchain_integration.py (99%) rename tests/{ => live_provider}/test_openai.py (99%) create mode 100644 tests/live_provider/test_prompt.py create mode 100644 tests/support/__init__.py rename tests/{ => support}/api_wrapper.py (100%) rename tests/{ => support}/utils.py (100%) delete mode 100644 tests/test_prompt.py create mode 100644 tests/unit/__init__.py rename tests/{ => unit}/test_additional_headers_simple.py (100%) rename tests/{ => unit}/test_error_logging.py (100%) rename tests/{ => unit}/test_error_parsing.py (100%) rename tests/{ => unit}/test_initialization.py (100%) rename tests/{ => unit}/test_json.py (100%) create mode 100644 tests/unit/test_langchain.py rename tests/{ => unit}/test_logger.py (100%) rename tests/{ => unit}/test_media.py (69%) rename tests/{ => unit}/test_media_manager.py (100%) create mode 100644 tests/unit/test_openai.py rename tests/{ => unit}/test_openai_prompt_extraction.py (100%) rename tests/{ => unit}/test_otel.py (100%) rename tests/{ => unit}/test_parse_usage_model.py (100%) create mode 100644 tests/unit/test_prompt.py rename tests/{ => unit}/test_prompt_atexit.py (100%) rename tests/{ => unit}/test_prompt_compilation.py (100%) rename tests/{ => unit}/test_propagate_attributes.py (99%) rename tests/{ => unit}/test_resource_manager.py (92%) rename tests/{ => unit}/test_serializer.py (100%) rename tests/{ => unit}/test_span_filter.py (100%) rename tests/{ => unit}/test_utils.py (100%) rename tests/{ => unit}/test_version.py (100%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 68c3dc977..201bb35a3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -52,17 +52,14 @@ jobs: - name: Run mypy type checking run: uv run --frozen mypy langfuse --no-error-summary - ci: + unit-tests: runs-on: ubuntu-latest timeout-minutes: 30 env: LANGFUSE_BASE_URL: "http://localhost:3000" - LANGFUSE_PUBLIC_KEY: "pk-lf-1234567890" - LANGFUSE_SECRET_KEY: "sk-lf-1234567890" - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - # SERPAPI_API_KEY: ${{ secrets.SERPAPI_API_KEY }} - HUGGINGFACEHUB_API_TOKEN: ${{ secrets.HUGGINGFACEHUB_API_TOKEN }} - ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + LANGFUSE_PUBLIC_KEY: "pk-lf-test" + LANGFUSE_SECRET_KEY: "sk-lf-test" + OPENAI_API_KEY: "test-openai-key" strategy: fail-fast: false matrix: @@ -73,7 +70,40 @@ jobs: - "3.13" - "3.14" - name: Test on Python version ${{ matrix.python-version }} + name: Unit tests on Python ${{ matrix.python-version }} + steps: + - uses: actions/checkout@v3 + - name: Install uv and set Python version + uses: astral-sh/setup-uv@v7 + with: + version: "0.11.2" + python-version: ${{ matrix.python-version }} + enable-cache: true + + - name: Check Python version + run: python --version + + - name: Install the project dependencies + run: uv sync --locked + + - name: Run the automated tests + run: | + python --version + uv run --frozen pytest -n auto --dist loadfile -s -v --log-cli-level=INFO tests/unit + + e2e-tests: + runs-on: ubuntu-latest + timeout-minutes: 30 + env: + LANGFUSE_BASE_URL: "http://localhost:3000" + LANGFUSE_PUBLIC_KEY: "pk-lf-1234567890" + LANGFUSE_SECRET_KEY: "sk-lf-1234567890" + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + # SERPAPI_API_KEY: ${{ secrets.SERPAPI_API_KEY }} + HUGGINGFACEHUB_API_TOKEN: ${{ secrets.HUGGINGFACEHUB_API_TOKEN }} + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + + name: E2E tests on Python 3.13 steps: - uses: actions/checkout@v3 - uses: pnpm/action-setup@v3 @@ -115,7 +145,7 @@ jobs: echo "::group::Seed db" cp .env.dev.example .env - pnpm run db:migrate + pnpm run db:migrate pnpm run db:seed echo "::endgroup::" rm -rf .env @@ -134,7 +164,6 @@ jobs: echo "::endgroup::" - # Add this step to check the health of the container - name: Health check for langfuse server run: | echo "Checking if the langfuse server is up..." @@ -158,7 +187,7 @@ jobs: uses: astral-sh/setup-uv@v7 with: version: "0.11.2" - python-version: ${{ matrix.python-version }} + python-version: "3.13" enable-cache: true - name: Check Python version @@ -167,15 +196,15 @@ jobs: - name: Install the project dependencies run: uv sync --locked - - name: Run the automated tests + - name: Run the end-to-end tests run: | python --version - uv run --frozen pytest -n auto --dist loadfile -s -v --log-cli-level=INFO + uv run --frozen pytest -s -v --log-cli-level=INFO tests/e2e all-tests-passed: # This allows us to have a branch protection rule for tests and deploys with matrix runs-on: ubuntu-latest - needs: [ci, linting, type-checking] + needs: [unit-tests, e2e-tests, linting, type-checking] if: always() steps: - name: Successful deploy diff --git a/pyproject.toml b/pyproject.toml index 1467ef436..f70a62c63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,11 @@ module-root = "" [tool.pytest.ini_options] log_cli = true +markers = [ + "unit: deterministic tests that run without a Langfuse server", + "e2e: tests that require a real Langfuse server or persisted backend behaviour", + "live_provider: tests that call live model providers and are kept out of default CI", +] [tool.mypy] python_version = "3.12" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..c1e8e3b87 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,149 @@ +import json +from pathlib import Path +from typing import Any, Iterable, Sequence + +import pytest +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import ReadableSpan, TracerProvider +from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult + +from langfuse._client.client import Langfuse +from langfuse._client.resource_manager import LangfuseResourceManager + + +class InMemorySpanExporter(SpanExporter): + """Simple in-memory exporter to collect spans for deterministic tests.""" + + def __init__(self) -> None: + self._finished_spans: list[ReadableSpan] = [] + self._stopped = False + + def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: + if self._stopped: + return SpanExportResult.FAILURE + + self._finished_spans.extend(spans) + return SpanExportResult.SUCCESS + + def shutdown(self) -> None: + self._stopped = True + + def get_finished_spans(self) -> list[ReadableSpan]: + return list(self._finished_spans) + + def clear(self) -> None: + self._finished_spans.clear() + + +def pytest_collection_modifyitems(items: list[pytest.Item]) -> None: + for item in items: + test_group = Path(str(item.fspath)).parent.name + + if test_group == "unit": + item.add_marker(pytest.mark.unit) + continue + + if test_group == "e2e": + item.add_marker(pytest.mark.e2e) + continue + + if test_group == "live_provider": + item.add_marker(pytest.mark.e2e) + item.add_marker(pytest.mark.live_provider) + + +@pytest.fixture(autouse=True) +def reset_langfuse_state() -> Iterable[None]: + LangfuseResourceManager.reset() + yield + LangfuseResourceManager.reset() + + +@pytest.fixture +def memory_exporter() -> Iterable[InMemorySpanExporter]: + exporter = InMemorySpanExporter() + yield exporter + exporter.shutdown() + + +@pytest.fixture +def langfuse_memory_client( + monkeypatch: pytest.MonkeyPatch, memory_exporter: InMemorySpanExporter +) -> Iterable[Langfuse]: + monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "test-public-key") + monkeypatch.setenv("LANGFUSE_SECRET_KEY", "test-secret-key") + monkeypatch.setenv("LANGFUSE_BASE_URL", "http://test-host") + + tracer_provider = TracerProvider(resource=Resource.create({"service.name": "test"})) + + def mock_init(self: Any, **kwargs: Any) -> None: + from opentelemetry.sdk.trace.export import BatchSpanProcessor + + from langfuse._client.span_filter import is_default_export_span + + self.public_key = kwargs.get("public_key", "test-public-key") + blocked_scopes = kwargs.get("blocked_instrumentation_scopes") + self.blocked_instrumentation_scopes = ( + blocked_scopes if blocked_scopes is not None else [] + ) + self._should_export_span = ( + kwargs.get("should_export_span") or is_default_export_span + ) + BatchSpanProcessor.__init__( + self, + span_exporter=memory_exporter, + max_export_batch_size=512, + schedule_delay_millis=1, + ) + + monkeypatch.setattr( + "langfuse._client.span_processor.LangfuseSpanProcessor.__init__", + mock_init, + ) + + client = Langfuse( + public_key="test-public-key", + secret_key="test-secret-key", + base_url="http://test-host", + tracing_enabled=True, + tracer_provider=tracer_provider, + ) + + yield client + client.flush() + + +@pytest.fixture +def get_span(memory_exporter: InMemorySpanExporter): + def _get_span(name: str) -> ReadableSpan: + for span in memory_exporter.get_finished_spans(): + if span.name == name: + return span + + raise AssertionError( + f"Span {name!r} not found in {[span.name for span in memory_exporter.get_finished_spans()]}" + ) + + return _get_span + + +@pytest.fixture +def find_spans(memory_exporter: InMemorySpanExporter): + def _find_spans(name: str) -> list[ReadableSpan]: + return [ + span for span in memory_exporter.get_finished_spans() if span.name == name + ] + + return _find_spans + + +@pytest.fixture +def json_attr(): + def _json_attr(span: ReadableSpan, attribute: str) -> Any: + value = span.attributes[attribute] + if not isinstance(value, str): + return value + + return json.loads(value) + + return _json_attr diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/e2e/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/test_batch_evaluation.py b/tests/e2e/test_batch_evaluation.py similarity index 99% rename from tests/test_batch_evaluation.py rename to tests/e2e/test_batch_evaluation.py index 27c49acdd..48accef7c 100644 --- a/tests/test_batch_evaluation.py +++ b/tests/e2e/test_batch_evaluation.py @@ -18,7 +18,7 @@ EvaluatorStats, ) from langfuse.experiment import Evaluation -from tests.utils import create_uuid +from tests.support.utils import create_uuid # ============================================================================ # FIXTURES & SETUP diff --git a/tests/test_core_sdk.py b/tests/e2e/test_core_sdk.py similarity index 99% rename from tests/test_core_sdk.py rename to tests/e2e/test_core_sdk.py index 91064de23..da00b1748 100644 --- a/tests/test_core_sdk.py +++ b/tests/e2e/test_core_sdk.py @@ -9,8 +9,8 @@ from langfuse import Langfuse, propagate_attributes from langfuse._client.resource_manager import LangfuseResourceManager from langfuse._utils import _get_timestamp -from tests.api_wrapper import LangfuseAPI -from tests.utils import ( +from tests.support.api_wrapper import LangfuseAPI +from tests.support.utils import ( create_uuid, get_api, ) diff --git a/tests/test_datasets.py b/tests/e2e/test_datasets.py similarity index 99% rename from tests/test_datasets.py rename to tests/e2e/test_datasets.py index 7cbbed817..e3d24e129 100644 --- a/tests/test_datasets.py +++ b/tests/e2e/test_datasets.py @@ -3,7 +3,7 @@ from langfuse import Langfuse from langfuse.api import DatasetStatus -from tests.utils import create_uuid +from tests.support.utils import create_uuid def test_create_and_get_dataset(): diff --git a/tests/test_decorators.py b/tests/e2e/test_decorators.py similarity index 99% rename from tests/test_decorators.py rename to tests/e2e/test_decorators.py index c6ed42594..754d41343 100644 --- a/tests/test_decorators.py +++ b/tests/e2e/test_decorators.py @@ -16,7 +16,7 @@ from langfuse._client.resource_manager import LangfuseResourceManager from langfuse.langchain import CallbackHandler from langfuse.media import LangfuseMedia -from tests.utils import get_api +from tests.support.utils import get_api mock_metadata = {"key": "metadata"} mock_deep_metadata = {"key": "mock_deep_metadata"} diff --git a/tests/test_experiments.py b/tests/e2e/test_experiments.py similarity index 99% rename from tests/test_experiments.py rename to tests/e2e/test_experiments.py index 3ba8b4afa..cd17e80e0 100644 --- a/tests/test_experiments.py +++ b/tests/e2e/test_experiments.py @@ -12,7 +12,7 @@ ExperimentItem, ExperimentItemResult, ) -from tests.utils import create_uuid, get_api +from tests.support.utils import create_uuid, get_api @pytest.fixture diff --git a/tests/e2e/test_media.py b/tests/e2e/test_media.py new file mode 100644 index 000000000..b262c0a86 --- /dev/null +++ b/tests/e2e/test_media.py @@ -0,0 +1,58 @@ +import base64 +import re +from uuid import uuid4 + +from langfuse._client.client import Langfuse +from langfuse.media import LangfuseMedia +from tests.support.utils import get_api + + +def test_replace_media_reference_string_in_object(): + audio_file = "static/joke_prompt.wav" + with open(audio_file, "rb") as f: + mock_audio_bytes = f.read() + + langfuse = Langfuse() + + mock_trace_name = f"test-trace-with-audio-{uuid4()}" + base64_audio = base64.b64encode(mock_audio_bytes).decode() + + span = langfuse.start_observation( + name=mock_trace_name, + metadata={ + "context": { + "nested": LangfuseMedia( + base64_data_uri=f"data:audio/wav;base64,{base64_audio}" + ) + } + }, + ).end() + + langfuse.flush() + + fetched_trace = get_api().trace.get(span.trace_id) + media_ref = fetched_trace.observations[0].metadata["context"]["nested"] + assert re.match( + r"^@@@langfuseMedia:type=audio/wav\|id=.+\|source=base64_data_uri@@@$", + media_ref, + ) + + resolved_obs = langfuse.resolve_media_references( + obj=fetched_trace.observations[0], resolve_with="base64_data_uri" + ) + + expected_base64 = f"data:audio/wav;base64,{base64_audio}" + assert resolved_obs["metadata"]["context"]["nested"] == expected_base64 + + span2 = langfuse.start_observation( + name=f"2-{mock_trace_name}", + metadata={"context": {"nested": resolved_obs["metadata"]["context"]["nested"]}}, + ).end() + + langfuse.flush() + + fetched_trace2 = get_api().trace.get(span2.trace_id) + assert ( + fetched_trace2.observations[0].metadata["context"]["nested"] + == fetched_trace.observations[0].metadata["context"]["nested"] + ) diff --git a/tests/e2e/test_prompt.py b/tests/e2e/test_prompt.py new file mode 100644 index 000000000..6e113cb41 --- /dev/null +++ b/tests/e2e/test_prompt.py @@ -0,0 +1,697 @@ +import pytest + +from langfuse._client.client import Langfuse +from tests.support.utils import create_uuid, get_api + + +def test_create_prompt(): + langfuse = Langfuse() + prompt_name = create_uuid() + prompt_client = langfuse.create_prompt( + name=prompt_name, + prompt="test prompt", + labels=["production"], + commit_message="initial commit", + ) + + second_prompt_client = langfuse.get_prompt(prompt_name) + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.prompt == second_prompt_client.prompt + assert prompt_client.config == second_prompt_client.config + assert prompt_client.commit_message == second_prompt_client.commit_message + assert prompt_client.config == {} + + +def test_create_prompt_with_special_chars_in_name(): + langfuse = Langfuse() + prompt_name = create_uuid() + "special chars !@#$%^&*() +" + prompt_client = langfuse.create_prompt( + name=prompt_name, + prompt="test prompt", + labels=["production"], + tags=["test"], + ) + + second_prompt_client = langfuse.get_prompt(prompt_name) + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.prompt == second_prompt_client.prompt + assert prompt_client.tags == second_prompt_client.tags + assert prompt_client.config == second_prompt_client.config + assert prompt_client.config == {} + + +def test_create_prompt_with_placeholders(): + """Test creating a prompt with placeholder messages.""" + langfuse = Langfuse() + prompt_name = create_uuid() + prompt_client = langfuse.create_prompt( + name=prompt_name, + prompt=[ + {"role": "system", "content": "System message"}, + {"type": "placeholder", "name": "context"}, + {"role": "user", "content": "User message"}, + ], + type="chat", + ) + + # Verify the full prompt structure with placeholders + assert len(prompt_client.prompt) == 3 + + # First message - system + assert prompt_client.prompt[0]["type"] == "message" + assert prompt_client.prompt[0]["role"] == "system" + assert prompt_client.prompt[0]["content"] == "System message" + # Placeholder + assert prompt_client.prompt[1]["type"] == "placeholder" + assert prompt_client.prompt[1]["name"] == "context" + # Third message - user + assert prompt_client.prompt[2]["type"] == "message" + assert prompt_client.prompt[2]["role"] == "user" + assert prompt_client.prompt[2]["content"] == "User message" + + +def test_get_prompt_with_placeholders(): + """Test retrieving a prompt with placeholders.""" + langfuse = Langfuse() + prompt_name = create_uuid() + + langfuse.create_prompt( + name=prompt_name, + prompt=[ + {"role": "system", "content": "You are {{name}}"}, + {"type": "placeholder", "name": "history"}, + {"role": "user", "content": "{{question}}"}, + ], + type="chat", + ) + + prompt_client = langfuse.get_prompt(prompt_name, type="chat", version=1) + + # Verify placeholder structure is preserved + assert len(prompt_client.prompt) == 3 + + # First message - system with variable + assert prompt_client.prompt[0]["type"] == "message" + assert prompt_client.prompt[0]["role"] == "system" + assert prompt_client.prompt[0]["content"] == "You are {{name}}" + # Placeholder + assert prompt_client.prompt[1]["type"] == "placeholder" + assert prompt_client.prompt[1]["name"] == "history" + # Third message - user with variable + assert prompt_client.prompt[2]["type"] == "message" + assert prompt_client.prompt[2]["role"] == "user" + assert prompt_client.prompt[2]["content"] == "{{question}}" + + +def test_warning_on_unresolved_placeholders(): + """Test that a warning is emitted when compiling with unresolved placeholders.""" + from unittest.mock import patch + + langfuse = Langfuse() + prompt_name = create_uuid() + + langfuse.create_prompt( + name=prompt_name, + prompt=[ + {"role": "system", "content": "You are {{name}}"}, + {"type": "placeholder", "name": "history"}, + {"role": "user", "content": "{{question}}"}, + ], + type="chat", + ) + + prompt_client = langfuse.get_prompt(prompt_name, type="chat", version=1) + + # Test that warning is emitted when compiling with unresolved placeholders + with patch("langfuse.logger.langfuse_logger.warning") as mock_warning: + # Compile without providing the 'history' placeholder + result = prompt_client.compile(name="Assistant", question="What is 2+2?") + + # Verify the warning was called with the expected message + mock_warning.assert_called_once() + warning_message = mock_warning.call_args[0][0] + assert "Placeholders ['history'] have not been resolved" in warning_message + + # Verify the result only contains the resolved messages + assert len(result) == 3 + assert result[0]["content"] == "You are Assistant" + assert result[1]["name"] == "history" + assert result[2]["content"] == "What is 2+2?" + + +def test_compiling_chat_prompt(): + langfuse = Langfuse() + prompt_name = create_uuid() + + prompt_client = langfuse.create_prompt( + name=prompt_name, + prompt=[ + { + "role": "system", + "content": "test prompt 1 with {{state}} {{target}} {{state}}", + }, + {"role": "user", "content": "test prompt 2 with {{state}}"}, + ], + labels=["production"], + type="chat", + ) + + second_prompt_client = langfuse.get_prompt(prompt_name, type="chat") + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.prompt == second_prompt_client.prompt + assert prompt_client.labels == ["production", "latest"] + + assert second_prompt_client.compile(target="world", state="great") == [ + {"role": "system", "content": "test prompt 1 with great world great"}, + {"role": "user", "content": "test prompt 2 with great"}, + ] + + +def test_compiling_prompt(): + langfuse = Langfuse() + prompt_name = "test_compiling_prompt" + + prompt_client = langfuse.create_prompt( + name=prompt_name, + prompt='Hello, {{target}}! I hope you are {{state}}. {{undefined_variable}}. And here is some JSON that should not be compiled: {{ "key": "value" }} \ + Here is a custom var for users using str.format instead of the mustache-style double curly braces: {custom_var}', + labels=["production"], + ) + + second_prompt_client = langfuse.get_prompt(prompt_name) + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.prompt == second_prompt_client.prompt + assert prompt_client.labels == ["production", "latest"] + + compiled = second_prompt_client.compile(target="world", state="great") + + assert ( + compiled + == 'Hello, world! I hope you are great. {{undefined_variable}}. And here is some JSON that should not be compiled: {{ "key": "value" }} \ + Here is a custom var for users using str.format instead of the mustache-style double curly braces: {custom_var}' + ) + + +def test_compiling_prompt_without_character_escaping(): + langfuse = Langfuse() + prompt_name = "test_compiling_prompt_without_character_escaping" + + prompt_client = langfuse.create_prompt( + name=prompt_name, prompt="Hello, {{ some_json }}", labels=["production"] + ) + + second_prompt_client = langfuse.get_prompt(prompt_name) + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.prompt == second_prompt_client.prompt + assert prompt_client.labels == ["production", "latest"] + + some_json = '{"key": "value"}' + compiled = second_prompt_client.compile(some_json=some_json) + + assert compiled == 'Hello, {"key": "value"}' + + +def test_compiling_prompt_with_content_as_variable_name(): + langfuse = Langfuse() + prompt_name = "test_compiling_prompt_with_content_as_variable_name" + + prompt_client = langfuse.create_prompt( + name=prompt_name, + prompt="Hello, {{ content }}!", + labels=["production"], + ) + + second_prompt_client = langfuse.get_prompt(prompt_name) + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.prompt == second_prompt_client.prompt + assert prompt_client.labels == ["production", "latest"] + + compiled = second_prompt_client.compile(content="Jane") + + assert compiled == "Hello, Jane!" + + +def test_create_prompt_with_null_config(): + langfuse = Langfuse(debug=False) + + langfuse.create_prompt( + name="test_null_config", + prompt="Hello, world! I hope you are great", + labels=["production"], + config=None, + ) + + prompt = langfuse.get_prompt("test_null_config") + + assert prompt.config == {} + + +def test_create_prompt_with_tags(): + langfuse = Langfuse(debug=False) + prompt_name = create_uuid() + + langfuse.create_prompt( + name=prompt_name, + prompt="Hello, world! I hope you are great", + tags=["tag1", "tag2"], + ) + + prompt = langfuse.get_prompt(prompt_name, version=1) + + assert prompt.tags == ["tag1", "tag2"] + + +def test_create_prompt_with_empty_tags(): + langfuse = Langfuse(debug=False) + prompt_name = create_uuid() + + langfuse.create_prompt( + name=prompt_name, + prompt="Hello, world! I hope you are great", + tags=[], + ) + + prompt = langfuse.get_prompt(prompt_name, version=1) + + assert prompt.tags == [] + + +def test_create_prompt_with_previous_tags(): + langfuse = Langfuse(debug=False) + prompt_name = create_uuid() + + langfuse.create_prompt( + name=prompt_name, + prompt="Hello, world! I hope you are great", + ) + + prompt = langfuse.get_prompt(prompt_name, version=1) + + assert prompt.tags == [] + + langfuse.create_prompt( + name=prompt_name, + prompt="Hello, world! I hope you are great", + tags=["tag1", "tag2"], + ) + + prompt_v2 = langfuse.get_prompt(prompt_name, version=2) + + assert prompt_v2.tags == ["tag1", "tag2"] + + langfuse.create_prompt( + name=prompt_name, + prompt="Hello, world! I hope you are great", + ) + + prompt_v3 = langfuse.get_prompt(prompt_name, version=3) + + assert prompt_v3.tags == ["tag1", "tag2"] + + +def test_remove_prompt_tags(): + langfuse = Langfuse(debug=False) + prompt_name = create_uuid() + + langfuse.create_prompt( + name=prompt_name, + prompt="Hello, world! I hope you are great", + tags=["tag1", "tag2"], + ) + + langfuse.create_prompt( + name=prompt_name, + prompt="Hello, world! I hope you are great", + tags=[], + ) + + prompt_v1 = langfuse.get_prompt(prompt_name, version=1) + prompt_v2 = langfuse.get_prompt(prompt_name, version=2) + + assert prompt_v1.tags == [] + assert prompt_v2.tags == [] + + +def test_update_prompt_tags(): + langfuse = Langfuse(debug=False) + prompt_name = create_uuid() + + langfuse.create_prompt( + name=prompt_name, + prompt="Hello, world! I hope you are great", + tags=["tag1", "tag2"], + ) + + prompt_v1 = langfuse.get_prompt(prompt_name, version=1) + + assert prompt_v1.tags == ["tag1", "tag2"] + + langfuse.create_prompt( + name=prompt_name, + prompt="Hello, world! I hope you are great", + tags=["tag3", "tag4"], + ) + + prompt_v2 = langfuse.get_prompt(prompt_name, version=2) + + assert prompt_v2.tags == ["tag3", "tag4"] + + +def test_get_prompt_by_version_or_label(): + langfuse = Langfuse() + prompt_name = create_uuid() + + for i in range(3): + langfuse.create_prompt( + name=prompt_name, + prompt="test prompt " + str(i + 1), + labels=["production"] if i == 1 else [], + ) + + default_prompt_client = langfuse.get_prompt(prompt_name) + assert default_prompt_client.version == 2 + assert default_prompt_client.prompt == "test prompt 2" + assert default_prompt_client.labels == ["production"] + + first_prompt_client = langfuse.get_prompt(prompt_name, version=1) + assert first_prompt_client.version == 1 + assert first_prompt_client.prompt == "test prompt 1" + assert first_prompt_client.labels == [] + + second_prompt_client = langfuse.get_prompt(prompt_name, version=2) + assert second_prompt_client.version == 2 + assert second_prompt_client.prompt == "test prompt 2" + assert second_prompt_client.labels == ["production"] + + third_prompt_client = langfuse.get_prompt(prompt_name, label="latest") + assert third_prompt_client.version == 3 + assert third_prompt_client.prompt == "test prompt 3" + assert third_prompt_client.labels == ["latest"] + + +def test_prompt_end_to_end(): + langfuse = Langfuse(debug=False) + + langfuse.create_prompt( + name="test", + prompt="Hello, {{target}}! I hope you are {{state}}.", + labels=["production"], + config={"temperature": 0.5}, + ) + + prompt = langfuse.get_prompt("test") + + prompt_str = prompt.compile(target="world", state="great") + assert prompt_str == "Hello, world! I hope you are great." + assert prompt.config == {"temperature": 0.5} + + generation = langfuse.start_observation( + as_type="generation", + name="mygen", + input=prompt_str, + prompt=prompt, + ).end() + + # to check that these do not error + generation.update(prompt=prompt) + + langfuse.flush() + + api = get_api() + + trace = api.trace.get(generation.trace_id) + + assert len(trace.observations) == 1 + + generation = trace.observations[0] + assert generation.prompt_id is not None + + observation = api.legacy.observations_v1.get(generation.id) + + assert observation.prompt_id is not None + + +def test_do_not_return_fallback_if_fetch_success(): + langfuse = Langfuse() + prompt_name = create_uuid() + prompt_client = langfuse.create_prompt( + name=prompt_name, + prompt="test prompt", + labels=["production"], + ) + + second_prompt_client = langfuse.get_prompt(prompt_name, fallback="fallback") + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.prompt == second_prompt_client.prompt + assert prompt_client.config == second_prompt_client.config + assert prompt_client.config == {} + + +def test_fallback_text_prompt(): + langfuse = Langfuse() + + fallback_text_prompt = "this is a fallback text prompt with {{variable}}" + + # Should throw an error if prompt not found and no fallback provided + with pytest.raises(Exception): + langfuse.get_prompt("nonexistent_prompt") + + prompt = langfuse.get_prompt("nonexistent_prompt", fallback=fallback_text_prompt) + + assert prompt.prompt == fallback_text_prompt + assert ( + prompt.compile(variable="value") == "this is a fallback text prompt with value" + ) + + +def test_fallback_chat_prompt(): + langfuse = Langfuse() + fallback_chat_prompt = [ + {"role": "system", "content": "fallback system"}, + {"role": "user", "content": "fallback user name {{name}}"}, + ] + + # Should throw an error if prompt not found and no fallback provided + with pytest.raises(Exception): + langfuse.get_prompt("nonexistent_chat_prompt", type="chat") + + prompt = langfuse.get_prompt( + "nonexistent_chat_prompt", type="chat", fallback=fallback_chat_prompt + ) + + # Check that the prompt structure contains the fallback data (allowing for internal formatting) + assert len(prompt.prompt) == len(fallback_chat_prompt) + assert all(msg["type"] == "message" for msg in prompt.prompt) + assert prompt.prompt[0]["role"] == "system" + assert prompt.prompt[0]["content"] == "fallback system" + assert prompt.prompt[1]["role"] == "user" + assert prompt.prompt[1]["content"] == "fallback user name {{name}}" + assert prompt.compile(name="Jane") == [ + {"role": "system", "content": "fallback system"}, + {"role": "user", "content": "fallback user name Jane"}, + ] + + +def test_do_not_link_observation_if_fallback(): + langfuse = Langfuse() + + fallback_text_prompt = "this is a fallback text prompt with {{variable}}" + + # Should throw an error if prompt not found and no fallback provided + with pytest.raises(Exception): + langfuse.get_prompt("nonexistent_prompt") + + prompt = langfuse.get_prompt("nonexistent_prompt", fallback=fallback_text_prompt) + + generation = langfuse.start_observation( + as_type="generation", + name="mygen", + prompt=prompt, + input="this is a test input", + ).end() + langfuse.flush() + + api = get_api() + trace = api.trace.get(generation.trace_id) + + assert len(trace.observations) == 1 + assert trace.observations[0].prompt_id is None + + +def test_variable_names_on_content_with_variable_names(): + langfuse = Langfuse() + + prompt_client = langfuse.create_prompt( + name="test_variable_names_1", + prompt="test prompt with var names {{ var1 }} {{ var2 }}", + labels=["production"], + type="text", + ) + + second_prompt_client = langfuse.get_prompt("test_variable_names_1") + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.prompt == second_prompt_client.prompt + assert prompt_client.labels == ["production", "latest"] + + var_names = second_prompt_client.variables + + assert var_names == ["var1", "var2"] + + +def test_variable_names_on_content_with_no_variable_names(): + langfuse = Langfuse() + + prompt_client = langfuse.create_prompt( + name="test_variable_names_2", + prompt="test prompt with no var names", + labels=["production"], + type="text", + ) + + second_prompt_client = langfuse.get_prompt("test_variable_names_2") + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.prompt == second_prompt_client.prompt + assert prompt_client.labels == ["production", "latest"] + + var_names = second_prompt_client.variables + + assert var_names == [] + + +def test_variable_names_on_content_with_variable_names_chat_messages(): + langfuse = Langfuse() + + prompt_client = langfuse.create_prompt( + name="test_variable_names_3", + prompt=[ + { + "role": "system", + "content": "test prompt with template vars {{ var1 }} {{ var2 }}", + }, + {"role": "user", "content": "test prompt 2 with template vars {{ var3 }}"}, + ], + labels=["production"], + type="chat", + ) + + second_prompt_client = langfuse.get_prompt("test_variable_names_3") + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.prompt == second_prompt_client.prompt + assert prompt_client.labels == ["production", "latest"] + + var_names = second_prompt_client.variables + + assert var_names == ["var1", "var2", "var3"] + + +def test_variable_names_on_content_with_no_variable_names_chat_messages(): + langfuse = Langfuse() + prompt_name = "test_variable_names_on_content_with_no_variable_names_chat_messages" + + prompt_client = langfuse.create_prompt( + name=prompt_name, + prompt=[ + {"role": "system", "content": "test prompt with no template vars"}, + {"role": "user", "content": "test prompt 2 with no template vars"}, + ], + labels=["production"], + type="chat", + ) + + second_prompt_client = langfuse.get_prompt(prompt_name) + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.prompt == second_prompt_client.prompt + assert prompt_client.labels == ["production", "latest"] + + var_names = second_prompt_client.variables + + assert var_names == [] + + +def test_update_prompt(): + langfuse = Langfuse() + prompt_name = create_uuid() + + # Create initial prompt + langfuse.create_prompt( + name=prompt_name, + prompt="test prompt", + labels=["production"], + ) + + # Update prompt labels + updated_prompt = langfuse.update_prompt( + name=prompt_name, + version=1, + new_labels=["john", "doe"], + ) + + # Fetch prompt after update (should be invalidated) + fetched_prompt = langfuse.get_prompt(prompt_name) + + # Verify the fetched prompt matches the updated values + assert fetched_prompt.name == prompt_name + assert fetched_prompt.version == 1 + print(f"Fetched prompt labels: {fetched_prompt.labels}") + print(f"Updated prompt labels: {updated_prompt.labels}") + + # production was set by the first call, latest is managed and set by Langfuse + expected_labels = sorted(["latest", "doe", "production", "john"]) + assert sorted(fetched_prompt.labels) == expected_labels + assert sorted(updated_prompt.labels) == expected_labels + + +def test_update_prompt_in_folder(): + langfuse = Langfuse() + prompt_name = f"some-folder/{create_uuid()}" + + # Create initial prompt + langfuse.create_prompt( + name=prompt_name, + prompt="test prompt", + labels=["production"], + ) + + old_prompt_obj = langfuse.get_prompt(prompt_name) + + updated_prompt = langfuse.update_prompt( + name=old_prompt_obj.name, + version=old_prompt_obj.version, + new_labels=["john", "doe"], + ) + + # Fetch prompt after update (should be invalidated) + fetched_prompt = langfuse.get_prompt(prompt_name) + + # Verify the fetched prompt matches the updated values + assert fetched_prompt.name == prompt_name + assert fetched_prompt.version == 1 + print(f"Fetched prompt labels: {fetched_prompt.labels}") + print(f"Updated prompt labels: {updated_prompt.labels}") + + # production was set by the first call, latest is managed and set by Langfuse + expected_labels = sorted(["latest", "doe", "production", "john"]) + assert sorted(fetched_prompt.labels) == expected_labels + assert sorted(updated_prompt.labels) == expected_labels diff --git a/tests/live_provider/__init__.py b/tests/live_provider/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/live_provider/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/test_langchain.py b/tests/live_provider/test_langchain.py similarity index 99% rename from tests/test_langchain.py rename to tests/live_provider/test_langchain.py index 6c9d3eb4d..4c2495c2d 100644 --- a/tests/test_langchain.py +++ b/tests/live_provider/test_langchain.py @@ -18,7 +18,7 @@ from langfuse._client.client import Langfuse from langfuse.langchain import CallbackHandler -from tests.utils import create_uuid, encode_file_to_base64, get_api +from tests.support.utils import create_uuid, encode_file_to_base64, get_api def test_callback_generated_from_trace_chat(): diff --git a/tests/test_langchain_integration.py b/tests/live_provider/test_langchain_integration.py similarity index 99% rename from tests/test_langchain_integration.py rename to tests/live_provider/test_langchain_integration.py index c7e4a9418..edb5455c4 100644 --- a/tests/test_langchain_integration.py +++ b/tests/live_provider/test_langchain_integration.py @@ -7,9 +7,7 @@ from langfuse import Langfuse from langfuse.langchain import CallbackHandler -from tests.utils import get_api - -from .utils import create_uuid +from tests.support.utils import create_uuid, get_api def _is_streaming_response(response): diff --git a/tests/test_openai.py b/tests/live_provider/test_openai.py similarity index 99% rename from tests/test_openai.py rename to tests/live_provider/test_openai.py index 47f17a5c8..d7fee57f0 100644 --- a/tests/test_openai.py +++ b/tests/live_provider/test_openai.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from langfuse._client.client import Langfuse -from tests.utils import create_uuid, encode_file_to_base64, get_api +from tests.support.utils import create_uuid, encode_file_to_base64, get_api langfuse = Langfuse() diff --git a/tests/live_provider/test_prompt.py b/tests/live_provider/test_prompt.py new file mode 100644 index 000000000..a64f26f45 --- /dev/null +++ b/tests/live_provider/test_prompt.py @@ -0,0 +1,87 @@ +import openai + +from langfuse._client.client import Langfuse +from tests.support.utils import create_uuid + + +def test_create_chat_prompt(): + langfuse = Langfuse() + prompt_name = create_uuid() + + prompt_client = langfuse.create_prompt( + name=prompt_name, + prompt=[ + {"role": "system", "content": "test prompt 1 with {{animal}}"}, + {"role": "user", "content": "test prompt 2 with {{occupation}}"}, + ], + labels=["production"], + tags=["test"], + type="chat", + commit_message="initial commit", + ) + + second_prompt_client = langfuse.get_prompt(prompt_name, type="chat") + + completion = openai.OpenAI().chat.completions.create( + model="gpt-4", + messages=prompt_client.compile(animal="dog", occupation="doctor"), + ) + + assert len(completion.choices) > 0 + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.prompt == second_prompt_client.prompt + assert prompt_client.config == second_prompt_client.config + assert prompt_client.labels == ["production", "latest"] + assert prompt_client.tags == second_prompt_client.tags + assert prompt_client.commit_message == second_prompt_client.commit_message + assert prompt_client.config == {} + + +def test_create_chat_prompt_with_placeholders(): + langfuse = Langfuse() + prompt_name = create_uuid() + + prompt_client = langfuse.create_prompt( + name=prompt_name, + prompt=[ + {"role": "system", "content": "You are a {{role}} assistant"}, + {"type": "placeholder", "name": "history"}, + {"role": "user", "content": "Help me with {{task}}"}, + ], + labels=["production"], + tags=["test"], + type="chat", + commit_message="initial commit", + ) + + second_prompt_client = langfuse.get_prompt(prompt_name, type="chat") + messages = second_prompt_client.compile( + role="helpful", + task="coding", + history=[ + {"role": "user", "content": "Example: {{task}}"}, + {"role": "assistant", "content": "Example response"}, + ], + ) + + completion = openai.OpenAI().chat.completions.create( + model="gpt-4", + messages=messages, + ) + + assert len(completion.choices) > 0 + assert len(messages) == 4 + assert messages[0]["content"] == "You are a helpful assistant" + assert messages[1]["content"] == "Example: coding" + assert messages[2]["content"] == "Example response" + assert messages[3]["content"] == "Help me with coding" + + assert prompt_client.name == second_prompt_client.name + assert prompt_client.version == second_prompt_client.version + assert prompt_client.config == second_prompt_client.config + assert prompt_client.labels == ["production", "latest"] + assert prompt_client.tags == second_prompt_client.tags + assert prompt_client.commit_message == second_prompt_client.commit_message + assert prompt_client.config == {} diff --git a/tests/support/__init__.py b/tests/support/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/support/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/api_wrapper.py b/tests/support/api_wrapper.py similarity index 100% rename from tests/api_wrapper.py rename to tests/support/api_wrapper.py diff --git a/tests/utils.py b/tests/support/utils.py similarity index 100% rename from tests/utils.py rename to tests/support/utils.py diff --git a/tests/test_prompt.py b/tests/test_prompt.py deleted file mode 100644 index 3c4c5c013..000000000 --- a/tests/test_prompt.py +++ /dev/null @@ -1,1524 +0,0 @@ -from time import sleep -from unittest.mock import Mock, patch - -import openai -import pytest - -from langfuse._client.client import Langfuse -from langfuse._utils.prompt_cache import ( - DEFAULT_PROMPT_CACHE_TTL_SECONDS, - PromptCache, - PromptCacheItem, -) -from langfuse.api import NotFoundError, Prompt_Chat, Prompt_Text -from langfuse.model import ChatPromptClient, TextPromptClient -from tests.utils import create_uuid, get_api - - -def test_create_prompt(): - langfuse = Langfuse() - prompt_name = create_uuid() - prompt_client = langfuse.create_prompt( - name=prompt_name, - prompt="test prompt", - labels=["production"], - commit_message="initial commit", - ) - - second_prompt_client = langfuse.get_prompt(prompt_name) - - assert prompt_client.name == second_prompt_client.name - assert prompt_client.version == second_prompt_client.version - assert prompt_client.prompt == second_prompt_client.prompt - assert prompt_client.config == second_prompt_client.config - assert prompt_client.commit_message == second_prompt_client.commit_message - assert prompt_client.config == {} - - -def test_create_prompt_with_special_chars_in_name(): - langfuse = Langfuse() - prompt_name = create_uuid() + "special chars !@#$%^&*() +" - prompt_client = langfuse.create_prompt( - name=prompt_name, - prompt="test prompt", - labels=["production"], - tags=["test"], - ) - - second_prompt_client = langfuse.get_prompt(prompt_name) - - assert prompt_client.name == second_prompt_client.name - assert prompt_client.version == second_prompt_client.version - assert prompt_client.prompt == second_prompt_client.prompt - assert prompt_client.tags == second_prompt_client.tags - assert prompt_client.config == second_prompt_client.config - assert prompt_client.config == {} - - -def test_create_chat_prompt(): - langfuse = Langfuse() - prompt_name = create_uuid() - - prompt_client = langfuse.create_prompt( - name=prompt_name, - prompt=[ - {"role": "system", "content": "test prompt 1 with {{animal}}"}, - {"role": "user", "content": "test prompt 2 with {{occupation}}"}, - ], - labels=["production"], - tags=["test"], - type="chat", - commit_message="initial commit", - ) - - second_prompt_client = langfuse.get_prompt(prompt_name, type="chat") - - # Create a test generation - completion = openai.OpenAI().chat.completions.create( - model="gpt-4", - messages=prompt_client.compile(animal="dog", occupation="doctor"), - ) - - assert len(completion.choices) > 0 - - assert prompt_client.name == second_prompt_client.name - assert prompt_client.version == second_prompt_client.version - assert prompt_client.prompt == second_prompt_client.prompt - assert prompt_client.config == second_prompt_client.config - assert prompt_client.labels == ["production", "latest"] - assert prompt_client.tags == second_prompt_client.tags - assert prompt_client.commit_message == second_prompt_client.commit_message - assert prompt_client.config == {} - - -def test_create_chat_prompt_with_placeholders(): - langfuse = Langfuse() - prompt_name = create_uuid() - - prompt_client = langfuse.create_prompt( - name=prompt_name, - prompt=[ - {"role": "system", "content": "You are a {{role}} assistant"}, - {"type": "placeholder", "name": "history"}, - {"role": "user", "content": "Help me with {{task}}"}, - ], - labels=["production"], - tags=["test"], - type="chat", - commit_message="initial commit", - ) - - second_prompt_client = langfuse.get_prompt(prompt_name, type="chat") - messages = second_prompt_client.compile( - role="helpful", - task="coding", - history=[ - {"role": "user", "content": "Example: {{task}}"}, - {"role": "assistant", "content": "Example response"}, - ], - ) - - # Create a test generation using compiled messages - completion = openai.OpenAI().chat.completions.create( - model="gpt-4", - messages=messages, - ) - - assert len(completion.choices) > 0 - assert len(messages) == 4 - assert messages[0]["content"] == "You are a helpful assistant" - assert messages[1]["content"] == "Example: coding" - assert messages[2]["content"] == "Example response" - assert messages[3]["content"] == "Help me with coding" - - assert prompt_client.name == second_prompt_client.name - assert prompt_client.version == second_prompt_client.version - assert prompt_client.config == second_prompt_client.config - assert prompt_client.labels == ["production", "latest"] - assert prompt_client.tags == second_prompt_client.tags - assert prompt_client.commit_message == second_prompt_client.commit_message - assert prompt_client.config == {} - - -def test_create_prompt_with_placeholders(): - """Test creating a prompt with placeholder messages.""" - langfuse = Langfuse() - prompt_name = create_uuid() - prompt_client = langfuse.create_prompt( - name=prompt_name, - prompt=[ - {"role": "system", "content": "System message"}, - {"type": "placeholder", "name": "context"}, - {"role": "user", "content": "User message"}, - ], - type="chat", - ) - - # Verify the full prompt structure with placeholders - assert len(prompt_client.prompt) == 3 - - # First message - system - assert prompt_client.prompt[0]["type"] == "message" - assert prompt_client.prompt[0]["role"] == "system" - assert prompt_client.prompt[0]["content"] == "System message" - # Placeholder - assert prompt_client.prompt[1]["type"] == "placeholder" - assert prompt_client.prompt[1]["name"] == "context" - # Third message - user - assert prompt_client.prompt[2]["type"] == "message" - assert prompt_client.prompt[2]["role"] == "user" - assert prompt_client.prompt[2]["content"] == "User message" - - -def test_get_prompt_with_placeholders(): - """Test retrieving a prompt with placeholders.""" - langfuse = Langfuse() - prompt_name = create_uuid() - - langfuse.create_prompt( - name=prompt_name, - prompt=[ - {"role": "system", "content": "You are {{name}}"}, - {"type": "placeholder", "name": "history"}, - {"role": "user", "content": "{{question}}"}, - ], - type="chat", - ) - - prompt_client = langfuse.get_prompt(prompt_name, type="chat", version=1) - - # Verify placeholder structure is preserved - assert len(prompt_client.prompt) == 3 - - # First message - system with variable - assert prompt_client.prompt[0]["type"] == "message" - assert prompt_client.prompt[0]["role"] == "system" - assert prompt_client.prompt[0]["content"] == "You are {{name}}" - # Placeholder - assert prompt_client.prompt[1]["type"] == "placeholder" - assert prompt_client.prompt[1]["name"] == "history" - # Third message - user with variable - assert prompt_client.prompt[2]["type"] == "message" - assert prompt_client.prompt[2]["role"] == "user" - assert prompt_client.prompt[2]["content"] == "{{question}}" - - -@pytest.mark.parametrize( - ("variables", "placeholders", "expected_len", "expected_contents"), - [ - # 0. Variables only, no placeholders. Unresolved placeholders kept in output - ( - {"role": "helpful", "task": "coding"}, - {}, - 3, - [ - "You are a helpful assistant", - None, - "Help me with coding", - ], # None = placeholder - ), - # 1. No variables, no placeholders. Expect verbatim message+placeholder output - ( - {}, - {}, - 3, - ["You are a {{role}} assistant", None, "Help me with {{task}}"], - ), # None = placeholder - # 2. Placeholders only, empty variables. Expect output with placeholders filled in - ( - {}, - { - "examples": [ - {"role": "user", "content": "Example question"}, - {"role": "assistant", "content": "Example answer"}, - ], - }, - 4, - [ - "You are a {{role}} assistant", - "Example question", - "Example answer", - "Help me with {{task}}", - ], - ), - # 3. Both variables and placeholders. Expect fully compiled output - ( - {"role": "helpful", "task": "coding"}, - { - "examples": [ - {"role": "user", "content": "Show me {{task}}"}, - {"role": "assistant", "content": "Here's {{task}}"}, - ], - }, - 4, - [ - "You are a helpful assistant", - "Show me coding", - "Here's coding", - "Help me with coding", - ], - ), - # # Empty placeholder array - # This is expected to fail! If the user provides a placeholder, it should contain an array - # ( - # {"role": "helpful", "task": "coding"}, - # {"examples": []}, - # 2, - # ["You are a helpful assistant", "Help me with coding"], - # ), - # 4. Unused placeholder fill ins. Unresolved placeholders kept in output - ( - {"role": "helpful", "task": "coding"}, - {"unused": [{"role": "user", "content": "Won't appear"}]}, - 3, - [ - "You are a helpful assistant", - None, - "Help me with coding", - ], # None = placeholder - ), - # 5. Placeholder with non-list value (should log warning and append as string) - ( - {"role": "helpful", "task": "coding"}, - {"examples": "not a list"}, - 3, - [ - "You are a helpful assistant", - "not a list", # String value appended directly - "Help me with coding", - ], - ), - # 6. Placeholder with invalid message structure (should log warning and include both) - ( - {"role": "helpful", "task": "coding"}, - { - "examples": [ - "invalid message", - {"role": "user", "content": "valid message"}, - ] - }, - 4, - [ - "You are a helpful assistant", - "['invalid message', {'role': 'user', 'content': 'valid message'}]", # Invalid structure becomes string - "valid message", # Valid message processed normally - "Help me with coding", - ], - ), - ], -) -def test_compile_with_placeholders( - variables, placeholders, expected_len, expected_contents -) -> None: - """Test compile_with_placeholders with different variable/placeholder combinations.""" - from langfuse.api import Prompt_Chat - from langfuse.model import ChatPromptClient - - mock_prompt = Prompt_Chat( - name="test_prompt", - version=1, - type="chat", - config={}, - tags=[], - labels=[], - prompt=[ - {"role": "system", "content": "You are a {{role}} assistant"}, - {"type": "placeholder", "name": "examples"}, - {"role": "user", "content": "Help me with {{task}}"}, - ], - ) - - compile_kwargs = {**placeholders, **variables} - result = ChatPromptClient(mock_prompt).compile(**compile_kwargs) - - assert len(result) == expected_len - for i, expected_content in enumerate(expected_contents): - if expected_content is None: - # This should be an unresolved placeholder - assert "type" in result[i] and result[i]["type"] == "placeholder" - elif isinstance(result[i], str): - # This is a string value from invalid placeholder - assert result[i] == expected_content - else: - # This should be a regular message - assert "content" in result[i] - assert result[i]["content"] == expected_content - - -def test_warning_on_unresolved_placeholders(): - """Test that a warning is emitted when compiling with unresolved placeholders.""" - from unittest.mock import patch - - langfuse = Langfuse() - prompt_name = create_uuid() - - langfuse.create_prompt( - name=prompt_name, - prompt=[ - {"role": "system", "content": "You are {{name}}"}, - {"type": "placeholder", "name": "history"}, - {"role": "user", "content": "{{question}}"}, - ], - type="chat", - ) - - prompt_client = langfuse.get_prompt(prompt_name, type="chat", version=1) - - # Test that warning is emitted when compiling with unresolved placeholders - with patch("langfuse.logger.langfuse_logger.warning") as mock_warning: - # Compile without providing the 'history' placeholder - result = prompt_client.compile(name="Assistant", question="What is 2+2?") - - # Verify the warning was called with the expected message - mock_warning.assert_called_once() - warning_message = mock_warning.call_args[0][0] - assert "Placeholders ['history'] have not been resolved" in warning_message - - # Verify the result only contains the resolved messages - assert len(result) == 3 - assert result[0]["content"] == "You are Assistant" - assert result[1]["name"] == "history" - assert result[2]["content"] == "What is 2+2?" - - -def test_compiling_chat_prompt(): - langfuse = Langfuse() - prompt_name = create_uuid() - - prompt_client = langfuse.create_prompt( - name=prompt_name, - prompt=[ - { - "role": "system", - "content": "test prompt 1 with {{state}} {{target}} {{state}}", - }, - {"role": "user", "content": "test prompt 2 with {{state}}"}, - ], - labels=["production"], - type="chat", - ) - - second_prompt_client = langfuse.get_prompt(prompt_name, type="chat") - - assert prompt_client.name == second_prompt_client.name - assert prompt_client.version == second_prompt_client.version - assert prompt_client.prompt == second_prompt_client.prompt - assert prompt_client.labels == ["production", "latest"] - - assert second_prompt_client.compile(target="world", state="great") == [ - {"role": "system", "content": "test prompt 1 with great world great"}, - {"role": "user", "content": "test prompt 2 with great"}, - ] - - -def test_compiling_prompt(): - langfuse = Langfuse() - prompt_name = "test_compiling_prompt" - - prompt_client = langfuse.create_prompt( - name=prompt_name, - prompt='Hello, {{target}}! I hope you are {{state}}. {{undefined_variable}}. And here is some JSON that should not be compiled: {{ "key": "value" }} \ - Here is a custom var for users using str.format instead of the mustache-style double curly braces: {custom_var}', - labels=["production"], - ) - - second_prompt_client = langfuse.get_prompt(prompt_name) - - assert prompt_client.name == second_prompt_client.name - assert prompt_client.version == second_prompt_client.version - assert prompt_client.prompt == second_prompt_client.prompt - assert prompt_client.labels == ["production", "latest"] - - compiled = second_prompt_client.compile(target="world", state="great") - - assert ( - compiled - == 'Hello, world! I hope you are great. {{undefined_variable}}. And here is some JSON that should not be compiled: {{ "key": "value" }} \ - Here is a custom var for users using str.format instead of the mustache-style double curly braces: {custom_var}' - ) - - -def test_compiling_prompt_without_character_escaping(): - langfuse = Langfuse() - prompt_name = "test_compiling_prompt_without_character_escaping" - - prompt_client = langfuse.create_prompt( - name=prompt_name, prompt="Hello, {{ some_json }}", labels=["production"] - ) - - second_prompt_client = langfuse.get_prompt(prompt_name) - - assert prompt_client.name == second_prompt_client.name - assert prompt_client.version == second_prompt_client.version - assert prompt_client.prompt == second_prompt_client.prompt - assert prompt_client.labels == ["production", "latest"] - - some_json = '{"key": "value"}' - compiled = second_prompt_client.compile(some_json=some_json) - - assert compiled == 'Hello, {"key": "value"}' - - -def test_compiling_prompt_with_content_as_variable_name(): - langfuse = Langfuse() - prompt_name = "test_compiling_prompt_with_content_as_variable_name" - - prompt_client = langfuse.create_prompt( - name=prompt_name, - prompt="Hello, {{ content }}!", - labels=["production"], - ) - - second_prompt_client = langfuse.get_prompt(prompt_name) - - assert prompt_client.name == second_prompt_client.name - assert prompt_client.version == second_prompt_client.version - assert prompt_client.prompt == second_prompt_client.prompt - assert prompt_client.labels == ["production", "latest"] - - compiled = second_prompt_client.compile(content="Jane") - - assert compiled == "Hello, Jane!" - - -def test_create_prompt_with_null_config(): - langfuse = Langfuse(debug=False) - - langfuse.create_prompt( - name="test_null_config", - prompt="Hello, world! I hope you are great", - labels=["production"], - config=None, - ) - - prompt = langfuse.get_prompt("test_null_config") - - assert prompt.config == {} - - -def test_create_prompt_with_tags(): - langfuse = Langfuse(debug=False) - prompt_name = create_uuid() - - langfuse.create_prompt( - name=prompt_name, - prompt="Hello, world! I hope you are great", - tags=["tag1", "tag2"], - ) - - prompt = langfuse.get_prompt(prompt_name, version=1) - - assert prompt.tags == ["tag1", "tag2"] - - -def test_create_prompt_with_empty_tags(): - langfuse = Langfuse(debug=False) - prompt_name = create_uuid() - - langfuse.create_prompt( - name=prompt_name, - prompt="Hello, world! I hope you are great", - tags=[], - ) - - prompt = langfuse.get_prompt(prompt_name, version=1) - - assert prompt.tags == [] - - -def test_create_prompt_with_previous_tags(): - langfuse = Langfuse(debug=False) - prompt_name = create_uuid() - - langfuse.create_prompt( - name=prompt_name, - prompt="Hello, world! I hope you are great", - ) - - prompt = langfuse.get_prompt(prompt_name, version=1) - - assert prompt.tags == [] - - langfuse.create_prompt( - name=prompt_name, - prompt="Hello, world! I hope you are great", - tags=["tag1", "tag2"], - ) - - prompt_v2 = langfuse.get_prompt(prompt_name, version=2) - - assert prompt_v2.tags == ["tag1", "tag2"] - - langfuse.create_prompt( - name=prompt_name, - prompt="Hello, world! I hope you are great", - ) - - prompt_v3 = langfuse.get_prompt(prompt_name, version=3) - - assert prompt_v3.tags == ["tag1", "tag2"] - - -def test_remove_prompt_tags(): - langfuse = Langfuse(debug=False) - prompt_name = create_uuid() - - langfuse.create_prompt( - name=prompt_name, - prompt="Hello, world! I hope you are great", - tags=["tag1", "tag2"], - ) - - langfuse.create_prompt( - name=prompt_name, - prompt="Hello, world! I hope you are great", - tags=[], - ) - - prompt_v1 = langfuse.get_prompt(prompt_name, version=1) - prompt_v2 = langfuse.get_prompt(prompt_name, version=2) - - assert prompt_v1.tags == [] - assert prompt_v2.tags == [] - - -def test_update_prompt_tags(): - langfuse = Langfuse(debug=False) - prompt_name = create_uuid() - - langfuse.create_prompt( - name=prompt_name, - prompt="Hello, world! I hope you are great", - tags=["tag1", "tag2"], - ) - - prompt_v1 = langfuse.get_prompt(prompt_name, version=1) - - assert prompt_v1.tags == ["tag1", "tag2"] - - langfuse.create_prompt( - name=prompt_name, - prompt="Hello, world! I hope you are great", - tags=["tag3", "tag4"], - ) - - prompt_v2 = langfuse.get_prompt(prompt_name, version=2) - - assert prompt_v2.tags == ["tag3", "tag4"] - - -def test_get_prompt_by_version_or_label(): - langfuse = Langfuse() - prompt_name = create_uuid() - - for i in range(3): - langfuse.create_prompt( - name=prompt_name, - prompt="test prompt " + str(i + 1), - labels=["production"] if i == 1 else [], - ) - - default_prompt_client = langfuse.get_prompt(prompt_name) - assert default_prompt_client.version == 2 - assert default_prompt_client.prompt == "test prompt 2" - assert default_prompt_client.labels == ["production"] - - first_prompt_client = langfuse.get_prompt(prompt_name, version=1) - assert first_prompt_client.version == 1 - assert first_prompt_client.prompt == "test prompt 1" - assert first_prompt_client.labels == [] - - second_prompt_client = langfuse.get_prompt(prompt_name, version=2) - assert second_prompt_client.version == 2 - assert second_prompt_client.prompt == "test prompt 2" - assert second_prompt_client.labels == ["production"] - - third_prompt_client = langfuse.get_prompt(prompt_name, label="latest") - assert third_prompt_client.version == 3 - assert third_prompt_client.prompt == "test prompt 3" - assert third_prompt_client.labels == ["latest"] - - -def test_prompt_end_to_end(): - langfuse = Langfuse(debug=False) - - langfuse.create_prompt( - name="test", - prompt="Hello, {{target}}! I hope you are {{state}}.", - labels=["production"], - config={"temperature": 0.5}, - ) - - prompt = langfuse.get_prompt("test") - - prompt_str = prompt.compile(target="world", state="great") - assert prompt_str == "Hello, world! I hope you are great." - assert prompt.config == {"temperature": 0.5} - - generation = langfuse.start_observation( - as_type="generation", - name="mygen", - input=prompt_str, - prompt=prompt, - ).end() - - # to check that these do not error - generation.update(prompt=prompt) - - langfuse.flush() - - api = get_api() - - trace = api.trace.get(generation.trace_id) - - assert len(trace.observations) == 1 - - generation = trace.observations[0] - assert generation.prompt_id is not None - - observation = api.legacy.observations_v1.get(generation.id) - - assert observation.prompt_id is not None - - -@pytest.fixture -def langfuse(): - from langfuse._client.resource_manager import LangfuseResourceManager - - langfuse_instance = Langfuse() - langfuse_instance.api = Mock() - - if langfuse_instance._resources is None: - langfuse_instance._resources = Mock(spec=LangfuseResourceManager) - langfuse_instance._resources.prompt_cache = PromptCache() - - return langfuse_instance - - -# Fetching a new prompt when nothing in cache -def test_get_fresh_prompt(langfuse): - prompt_name = "test_get_fresh_prompt" - prompt = Prompt_Text( - name=prompt_name, - version=1, - prompt="Make me laugh", - type="text", - labels=[], - config={}, - tags=[], - ) - - mock_server_call = langfuse.api.prompts.get - mock_server_call.return_value = prompt - - result = langfuse.get_prompt(prompt_name, fallback="fallback") - mock_server_call.assert_called_once_with( - prompt_name, - version=None, - label=None, - request_options=None, - ) - - assert result == TextPromptClient(prompt) - - -# Should throw an error if prompt name is unspecified -def test_throw_if_name_unspecified(langfuse): - prompt_name = "" - - with pytest.raises(ValueError) as exc_info: - langfuse.get_prompt(prompt_name) - - assert "Prompt name cannot be empty" in str(exc_info.value) - - -# Should throw an error if nothing in cache and fetch fails -def test_throw_when_failing_fetch_and_no_cache(langfuse): - prompt_name = "failing_fetch_and_no_cache" - - mock_server_call = langfuse.api.prompts.get - mock_server_call.side_effect = Exception("Prompt not found") - - with pytest.raises(Exception) as exc_info: - langfuse.get_prompt(prompt_name) - - assert "Prompt not found" in str(exc_info.value) - - -def test_using_custom_prompt_timeouts(langfuse): - prompt_name = "test_using_custom_prompt_timeouts" - prompt = Prompt_Text( - name=prompt_name, - version=1, - prompt="Make me laugh", - type="text", - labels=[], - config={}, - tags=[], - ) - - mock_server_call = langfuse.api.prompts.get - mock_server_call.return_value = prompt - - result = langfuse.get_prompt( - prompt_name, fallback="fallback", fetch_timeout_seconds=1000 - ) - mock_server_call.assert_called_once_with( - prompt_name, - version=None, - label=None, - request_options={"timeout_in_seconds": 1000}, - ) - - assert result == TextPromptClient(prompt) - - -# Should throw an error if cache_ttl_seconds is passed as positional rather than keyword argument -def test_throw_if_cache_ttl_seconds_positional_argument(langfuse): - prompt_name = "test ttl seconds in positional arg" - ttl_seconds = 20 - - with pytest.raises(TypeError) as exc_info: - langfuse.get_prompt(prompt_name, ttl_seconds) - - assert "positional arguments" in str(exc_info.value) - - -# Should return cached prompt if not expired -def test_get_valid_cached_prompt(langfuse): - prompt_name = "test_get_valid_cached_prompt" - prompt = Prompt_Text( - name=prompt_name, - version=1, - prompt="Make me laugh", - type="text", - labels=[], - config={}, - tags=[], - ) - prompt_client = TextPromptClient(prompt) - - mock_server_call = langfuse.api.prompts.get - mock_server_call.return_value = prompt - - result_call_1 = langfuse.get_prompt(prompt_name, fallback="fallback") - assert mock_server_call.call_count == 1 - assert result_call_1 == prompt_client - - result_call_2 = langfuse.get_prompt(prompt_name) - assert mock_server_call.call_count == 1 - assert result_call_2 == prompt_client - - -# Should return cached chat prompt if not expired when fetching by label -def test_get_valid_cached_chat_prompt_by_label(langfuse): - prompt_name = "test_get_valid_cached_chat_prompt_by_label" - prompt = Prompt_Chat( - name=prompt_name, - version=1, - prompt=[{"role": "system", "content": "Make me laugh"}], - labels=["test"], - type="chat", - config={}, - tags=[], - ) - prompt_client = ChatPromptClient(prompt) - - mock_server_call = langfuse.api.prompts.get - mock_server_call.return_value = prompt - - result_call_1 = langfuse.get_prompt(prompt_name, label="test") - assert mock_server_call.call_count == 1 - assert result_call_1 == prompt_client - - result_call_2 = langfuse.get_prompt(prompt_name, label="test") - assert mock_server_call.call_count == 1 - assert result_call_2 == prompt_client - - -# Should return cached chat prompt if not expired when fetching by version -def test_get_valid_cached_chat_prompt_by_version(langfuse): - prompt_name = "test_get_valid_cached_chat_prompt_by_version" - prompt = Prompt_Chat( - name=prompt_name, - version=1, - prompt=[{"role": "system", "content": "Make me laugh"}], - labels=["test"], - type="chat", - config={}, - tags=[], - ) - prompt_client = ChatPromptClient(prompt) - - mock_server_call = langfuse.api.prompts.get - mock_server_call.return_value = prompt - - result_call_1 = langfuse.get_prompt(prompt_name, version=1) - assert mock_server_call.call_count == 1 - assert result_call_1 == prompt_client - - result_call_2 = langfuse.get_prompt(prompt_name, version=1) - assert mock_server_call.call_count == 1 - assert result_call_2 == prompt_client - - -# Should return cached chat prompt if fetching the default prompt or the 'production' labeled one -def test_get_valid_cached_production_chat_prompt(langfuse): - prompt_name = "test_get_valid_cached_production_chat_prompt" - prompt = Prompt_Chat( - name=prompt_name, - version=1, - prompt=[{"role": "system", "content": "Make me laugh"}], - labels=["test"], - type="chat", - config={}, - tags=[], - ) - prompt_client = ChatPromptClient(prompt) - - mock_server_call = langfuse.api.prompts.get - mock_server_call.return_value = prompt - - result_call_1 = langfuse.get_prompt(prompt_name) - assert mock_server_call.call_count == 1 - assert result_call_1 == prompt_client - - result_call_2 = langfuse.get_prompt(prompt_name, label="production") - assert mock_server_call.call_count == 1 - assert result_call_2 == prompt_client - - -# Should return cached chat prompt if not expired -def test_get_valid_cached_chat_prompt(langfuse): - prompt_name = "test_get_valid_cached_chat_prompt" - prompt = Prompt_Chat( - name=prompt_name, - version=1, - prompt=[{"role": "system", "content": "Make me laugh"}], - labels=[], - type="chat", - config={}, - tags=[], - ) - prompt_client = ChatPromptClient(prompt) - - mock_server_call = langfuse.api.prompts.get - mock_server_call.return_value = prompt - - result_call_1 = langfuse.get_prompt(prompt_name) - assert mock_server_call.call_count == 1 - assert result_call_1 == prompt_client - - result_call_2 = langfuse.get_prompt(prompt_name) - assert mock_server_call.call_count == 1 - assert result_call_2 == prompt_client - - -# Should refetch and return new prompt if cached one is expired according to custom TTL -@patch.object(PromptCacheItem, "get_epoch_seconds") -def test_get_fresh_prompt_when_expired_cache_custom_ttl(mock_time, langfuse: Langfuse): - mock_time.return_value = 0 - ttl_seconds = 20 - - prompt_name = "test_get_fresh_prompt_when_expired_cache_custom_ttl" - prompt = Prompt_Text( - name=prompt_name, - version=1, - prompt="Make me laugh", - config={"temperature": 0.9}, - labels=[], - type="text", - tags=[], - ) - prompt_client = TextPromptClient(prompt) - - mock_server_call = langfuse.api.prompts.get - mock_server_call.return_value = prompt - - result_call_1 = langfuse.get_prompt(prompt_name, cache_ttl_seconds=ttl_seconds) - assert mock_server_call.call_count == 1 - assert result_call_1 == prompt_client - - # Set time to just BEFORE cache expiry - mock_time.return_value = ttl_seconds - 1 - - result_call_2 = langfuse.get_prompt(prompt_name) - assert mock_server_call.call_count == 1 # No new call - assert result_call_2 == prompt_client - - # Set time to just AFTER cache expiry - mock_time.return_value = ttl_seconds + 1 - - result_call_3 = langfuse.get_prompt(prompt_name) - - while True: - if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0: - break - sleep(0.1) - - assert mock_server_call.call_count == 2 # New call - assert result_call_3 == prompt_client - - -# Should disable caching when cache_ttl_seconds is set to 0 -@patch.object(PromptCacheItem, "get_epoch_seconds") -def test_disable_caching_when_ttl_zero(mock_time, langfuse: Langfuse): - mock_time.return_value = 0 - prompt_name = "test_disable_caching_when_ttl_zero" - - # Initial prompt - prompt1 = Prompt_Text( - name=prompt_name, - version=1, - prompt="Make me laugh", - labels=[], - type="text", - config={}, - tags=[], - ) - - # Updated prompts - prompt2 = Prompt_Text( - name=prompt_name, - version=2, - prompt="Tell me a joke", - labels=[], - type="text", - config={}, - tags=[], - ) - prompt3 = Prompt_Text( - name=prompt_name, - version=3, - prompt="Share a funny story", - labels=[], - type="text", - config={}, - tags=[], - ) - - mock_server_call = langfuse.api.prompts.get - mock_server_call.side_effect = [prompt1, prompt2, prompt3] - - # First call - result1 = langfuse.get_prompt(prompt_name, cache_ttl_seconds=0) - assert mock_server_call.call_count == 1 - assert result1 == TextPromptClient(prompt1) - - # Second call - result2 = langfuse.get_prompt(prompt_name, cache_ttl_seconds=0) - assert mock_server_call.call_count == 2 - assert result2 == TextPromptClient(prompt2) - - # Third call - result3 = langfuse.get_prompt(prompt_name, cache_ttl_seconds=0) - assert mock_server_call.call_count == 3 - assert result3 == TextPromptClient(prompt3) - - # Verify that all results are different - assert result1 != result2 != result3 - - -# Should return stale prompt immediately if cached one is expired according to default TTL and add to refresh promise map -@patch.object(PromptCacheItem, "get_epoch_seconds") -def test_get_stale_prompt_when_expired_cache_default_ttl(mock_time, langfuse: Langfuse): - import logging - - logging.basicConfig(level=logging.DEBUG) - mock_time.return_value = 0 - - prompt_name = "test_get_stale_prompt_when_expired_cache_default_ttl" - prompt = Prompt_Text( - name=prompt_name, - version=1, - prompt="Make me laugh", - labels=[], - type="text", - config={}, - tags=[], - ) - prompt_client = TextPromptClient(prompt) - - mock_server_call = langfuse.api.prompts.get - mock_server_call.return_value = prompt - - result_call_1 = langfuse.get_prompt(prompt_name) - assert mock_server_call.call_count == 1 - assert result_call_1 == prompt_client - - # Update the version of the returned mocked prompt - updated_prompt = Prompt_Text( - name=prompt_name, - version=2, - prompt="Make me laugh", - labels=[], - type="text", - config={}, - tags=[], - ) - mock_server_call.return_value = updated_prompt - - # Set time to just AFTER cache expiry - mock_time.return_value = DEFAULT_PROMPT_CACHE_TTL_SECONDS + 1 - - stale_result = langfuse.get_prompt(prompt_name) - assert stale_result == prompt_client - - # Ensure that only one refresh is triggered despite multiple calls - # Cannot check for value as the prompt might have already been updated - langfuse.get_prompt(prompt_name) - langfuse.get_prompt(prompt_name) - langfuse.get_prompt(prompt_name) - langfuse.get_prompt(prompt_name) - - while True: - if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0: - break - sleep(0.1) - - assert mock_server_call.call_count == 2 # Only one new call to server - - # Check that the prompt has been updated after refresh - updated_result = langfuse.get_prompt(prompt_name) - assert updated_result.version == 2 - assert updated_result == TextPromptClient(updated_prompt) - - -# Should refetch and return new prompt if cached one is expired according to default TTL -@patch.object(PromptCacheItem, "get_epoch_seconds") -def test_get_fresh_prompt_when_expired_cache_default_ttl(mock_time, langfuse: Langfuse): - mock_time.return_value = 0 - - prompt_name = "test_get_fresh_prompt_when_expired_cache_default_ttl" - prompt = Prompt_Text( - name=prompt_name, - version=1, - prompt="Make me laugh", - labels=[], - type="text", - config={}, - tags=[], - ) - prompt_client = TextPromptClient(prompt) - - mock_server_call = langfuse.api.prompts.get - mock_server_call.return_value = prompt - - result_call_1 = langfuse.get_prompt(prompt_name) - assert mock_server_call.call_count == 1 - assert result_call_1 == prompt_client - - # Set time to just BEFORE cache expiry - mock_time.return_value = DEFAULT_PROMPT_CACHE_TTL_SECONDS - 1 - - result_call_2 = langfuse.get_prompt(prompt_name) - assert mock_server_call.call_count == 1 # No new call - assert result_call_2 == prompt_client - - # Set time to just AFTER cache expiry - mock_time.return_value = DEFAULT_PROMPT_CACHE_TTL_SECONDS + 1 - - result_call_3 = langfuse.get_prompt(prompt_name) - while True: - if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0: - break - sleep(0.1) - - assert mock_server_call.call_count == 2 # New call - assert result_call_3 == prompt_client - - -# Should return expired prompt if refetch fails -@patch.object(PromptCacheItem, "get_epoch_seconds") -def test_get_expired_prompt_when_failing_fetch(mock_time, langfuse: Langfuse): - mock_time.return_value = 0 - - prompt_name = "test_get_expired_prompt_when_failing_fetch" - prompt = Prompt_Text( - name=prompt_name, - version=1, - prompt="Make me laugh", - labels=[], - type="text", - config={}, - tags=[], - ) - prompt_client = TextPromptClient(prompt) - - mock_server_call = langfuse.api.prompts.get - mock_server_call.return_value = prompt - - result_call_1 = langfuse.get_prompt(prompt_name) - assert mock_server_call.call_count == 1 - assert result_call_1 == prompt_client - - # Set time to just AFTER cache expiry - mock_time.return_value = DEFAULT_PROMPT_CACHE_TTL_SECONDS + 1 - - mock_server_call.side_effect = Exception("Server error") - - result_call_2 = langfuse.get_prompt(prompt_name, max_retries=1) - while True: - if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0: - break - sleep(0.1) - - assert mock_server_call.call_count == 3 - assert result_call_2 == prompt_client - - -@patch.object(PromptCacheItem, "get_epoch_seconds") -def test_evict_prompt_cache_entry_when_refresh_returns_not_found( - mock_time, langfuse: Langfuse -) -> None: - mock_time.return_value = 0 - - prompt_name = "test_evict_prompt_cache_entry_when_refresh_returns_not_found" - ttl_seconds = 5 - fallback_prompt = "fallback text prompt" - - prompt = Prompt_Text( - name=prompt_name, - version=1, - prompt="Make me laugh", - labels=[], - type="text", - config={}, - tags=[], - ) - prompt_client = TextPromptClient(prompt) - cache_key = PromptCache.generate_cache_key(prompt_name, version=None, label=None) - - mock_server_call = langfuse.api.prompts.get - mock_server_call.return_value = prompt - - initial_result = langfuse.get_prompt( - prompt_name, - cache_ttl_seconds=ttl_seconds, - max_retries=0, - ) - assert initial_result == prompt_client - assert langfuse._resources.prompt_cache.get(cache_key) is not None - - # Expire cache entry and trigger background refresh - mock_time.return_value = ttl_seconds + 1 - - def raise_not_found(*_args: object, **_kwargs: object) -> None: - raise NotFoundError({"message": "Prompt not found"}) - - mock_server_call.side_effect = raise_not_found - - stale_result = langfuse.get_prompt( - prompt_name, - cache_ttl_seconds=ttl_seconds, - max_retries=0, - ) - assert stale_result == prompt_client - - while True: - if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0: - break - sleep(0.1) - - assert langfuse._resources.prompt_cache.get(cache_key) is None - - fallback_result = langfuse.get_prompt( - prompt_name, - cache_ttl_seconds=ttl_seconds, - fallback=fallback_prompt, - max_retries=0, - ) - assert fallback_result.is_fallback - assert fallback_result.prompt == fallback_prompt - - -# Should fetch new prompt if version changes -def test_get_fresh_prompt_when_version_changes(langfuse: Langfuse): - prompt_name = "test_get_fresh_prompt_when_version_changes" - prompt = Prompt_Text( - name=prompt_name, - version=1, - prompt="Make me laugh", - labels=[], - type="text", - config={}, - tags=[], - ) - prompt_client = TextPromptClient(prompt) - - mock_server_call = langfuse.api.prompts.get - mock_server_call.return_value = prompt - - result_call_1 = langfuse.get_prompt(prompt_name, version=1) - assert mock_server_call.call_count == 1 - assert result_call_1 == prompt_client - - version_changed_prompt = Prompt_Text( - name=prompt_name, - version=2, - labels=[], - prompt="Make me laugh", - type="text", - config={}, - tags=[], - ) - version_changed_prompt_client = TextPromptClient(version_changed_prompt) - mock_server_call.return_value = version_changed_prompt - - result_call_2 = langfuse.get_prompt(prompt_name, version=2) - assert mock_server_call.call_count == 2 - assert result_call_2 == version_changed_prompt_client - - -def test_do_not_return_fallback_if_fetch_success(): - langfuse = Langfuse() - prompt_name = create_uuid() - prompt_client = langfuse.create_prompt( - name=prompt_name, - prompt="test prompt", - labels=["production"], - ) - - second_prompt_client = langfuse.get_prompt(prompt_name, fallback="fallback") - - assert prompt_client.name == second_prompt_client.name - assert prompt_client.version == second_prompt_client.version - assert prompt_client.prompt == second_prompt_client.prompt - assert prompt_client.config == second_prompt_client.config - assert prompt_client.config == {} - - -def test_fallback_text_prompt(): - langfuse = Langfuse() - - fallback_text_prompt = "this is a fallback text prompt with {{variable}}" - - # Should throw an error if prompt not found and no fallback provided - with pytest.raises(Exception): - langfuse.get_prompt("nonexistent_prompt") - - prompt = langfuse.get_prompt("nonexistent_prompt", fallback=fallback_text_prompt) - - assert prompt.prompt == fallback_text_prompt - assert ( - prompt.compile(variable="value") == "this is a fallback text prompt with value" - ) - - -def test_fallback_chat_prompt(): - langfuse = Langfuse() - fallback_chat_prompt = [ - {"role": "system", "content": "fallback system"}, - {"role": "user", "content": "fallback user name {{name}}"}, - ] - - # Should throw an error if prompt not found and no fallback provided - with pytest.raises(Exception): - langfuse.get_prompt("nonexistent_chat_prompt", type="chat") - - prompt = langfuse.get_prompt( - "nonexistent_chat_prompt", type="chat", fallback=fallback_chat_prompt - ) - - # Check that the prompt structure contains the fallback data (allowing for internal formatting) - assert len(prompt.prompt) == len(fallback_chat_prompt) - assert all(msg["type"] == "message" for msg in prompt.prompt) - assert prompt.prompt[0]["role"] == "system" - assert prompt.prompt[0]["content"] == "fallback system" - assert prompt.prompt[1]["role"] == "user" - assert prompt.prompt[1]["content"] == "fallback user name {{name}}" - assert prompt.compile(name="Jane") == [ - {"role": "system", "content": "fallback system"}, - {"role": "user", "content": "fallback user name Jane"}, - ] - - -def test_do_not_link_observation_if_fallback(): - langfuse = Langfuse() - - fallback_text_prompt = "this is a fallback text prompt with {{variable}}" - - # Should throw an error if prompt not found and no fallback provided - with pytest.raises(Exception): - langfuse.get_prompt("nonexistent_prompt") - - prompt = langfuse.get_prompt("nonexistent_prompt", fallback=fallback_text_prompt) - - generation = langfuse.start_observation( - as_type="generation", - name="mygen", - prompt=prompt, - input="this is a test input", - ).end() - langfuse.flush() - - api = get_api() - trace = api.trace.get(generation.trace_id) - - assert len(trace.observations) == 1 - assert trace.observations[0].prompt_id is None - - -def test_variable_names_on_content_with_variable_names(): - langfuse = Langfuse() - - prompt_client = langfuse.create_prompt( - name="test_variable_names_1", - prompt="test prompt with var names {{ var1 }} {{ var2 }}", - labels=["production"], - type="text", - ) - - second_prompt_client = langfuse.get_prompt("test_variable_names_1") - - assert prompt_client.name == second_prompt_client.name - assert prompt_client.version == second_prompt_client.version - assert prompt_client.prompt == second_prompt_client.prompt - assert prompt_client.labels == ["production", "latest"] - - var_names = second_prompt_client.variables - - assert var_names == ["var1", "var2"] - - -def test_variable_names_on_content_with_no_variable_names(): - langfuse = Langfuse() - - prompt_client = langfuse.create_prompt( - name="test_variable_names_2", - prompt="test prompt with no var names", - labels=["production"], - type="text", - ) - - second_prompt_client = langfuse.get_prompt("test_variable_names_2") - - assert prompt_client.name == second_prompt_client.name - assert prompt_client.version == second_prompt_client.version - assert prompt_client.prompt == second_prompt_client.prompt - assert prompt_client.labels == ["production", "latest"] - - var_names = second_prompt_client.variables - - assert var_names == [] - - -def test_variable_names_on_content_with_variable_names_chat_messages(): - langfuse = Langfuse() - - prompt_client = langfuse.create_prompt( - name="test_variable_names_3", - prompt=[ - { - "role": "system", - "content": "test prompt with template vars {{ var1 }} {{ var2 }}", - }, - {"role": "user", "content": "test prompt 2 with template vars {{ var3 }}"}, - ], - labels=["production"], - type="chat", - ) - - second_prompt_client = langfuse.get_prompt("test_variable_names_3") - - assert prompt_client.name == second_prompt_client.name - assert prompt_client.version == second_prompt_client.version - assert prompt_client.prompt == second_prompt_client.prompt - assert prompt_client.labels == ["production", "latest"] - - var_names = second_prompt_client.variables - - assert var_names == ["var1", "var2", "var3"] - - -def test_variable_names_on_content_with_no_variable_names_chat_messages(): - langfuse = Langfuse() - prompt_name = "test_variable_names_on_content_with_no_variable_names_chat_messages" - - prompt_client = langfuse.create_prompt( - name=prompt_name, - prompt=[ - {"role": "system", "content": "test prompt with no template vars"}, - {"role": "user", "content": "test prompt 2 with no template vars"}, - ], - labels=["production"], - type="chat", - ) - - second_prompt_client = langfuse.get_prompt(prompt_name) - - assert prompt_client.name == second_prompt_client.name - assert prompt_client.version == second_prompt_client.version - assert prompt_client.prompt == second_prompt_client.prompt - assert prompt_client.labels == ["production", "latest"] - - var_names = second_prompt_client.variables - - assert var_names == [] - - -def test_update_prompt(): - langfuse = Langfuse() - prompt_name = create_uuid() - - # Create initial prompt - langfuse.create_prompt( - name=prompt_name, - prompt="test prompt", - labels=["production"], - ) - - # Update prompt labels - updated_prompt = langfuse.update_prompt( - name=prompt_name, - version=1, - new_labels=["john", "doe"], - ) - - # Fetch prompt after update (should be invalidated) - fetched_prompt = langfuse.get_prompt(prompt_name) - - # Verify the fetched prompt matches the updated values - assert fetched_prompt.name == prompt_name - assert fetched_prompt.version == 1 - print(f"Fetched prompt labels: {fetched_prompt.labels}") - print(f"Updated prompt labels: {updated_prompt.labels}") - - # production was set by the first call, latest is managed and set by Langfuse - expected_labels = sorted(["latest", "doe", "production", "john"]) - assert sorted(fetched_prompt.labels) == expected_labels - assert sorted(updated_prompt.labels) == expected_labels - - -def test_update_prompt_in_folder(): - langfuse = Langfuse() - prompt_name = f"some-folder/{create_uuid()}" - - # Create initial prompt - langfuse.create_prompt( - name=prompt_name, - prompt="test prompt", - labels=["production"], - ) - - old_prompt_obj = langfuse.get_prompt(prompt_name) - - updated_prompt = langfuse.update_prompt( - name=old_prompt_obj.name, - version=old_prompt_obj.version, - new_labels=["john", "doe"], - ) - - # Fetch prompt after update (should be invalidated) - fetched_prompt = langfuse.get_prompt(prompt_name) - - # Verify the fetched prompt matches the updated values - assert fetched_prompt.name == prompt_name - assert fetched_prompt.version == 1 - print(f"Fetched prompt labels: {fetched_prompt.labels}") - print(f"Updated prompt labels: {updated_prompt.labels}") - - # production was set by the first call, latest is managed and set by Langfuse - expected_labels = sorted(["latest", "doe", "production", "john"]) - assert sorted(fetched_prompt.labels) == expected_labels - assert sorted(updated_prompt.labels) == expected_labels diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/test_additional_headers_simple.py b/tests/unit/test_additional_headers_simple.py similarity index 100% rename from tests/test_additional_headers_simple.py rename to tests/unit/test_additional_headers_simple.py diff --git a/tests/test_error_logging.py b/tests/unit/test_error_logging.py similarity index 100% rename from tests/test_error_logging.py rename to tests/unit/test_error_logging.py diff --git a/tests/test_error_parsing.py b/tests/unit/test_error_parsing.py similarity index 100% rename from tests/test_error_parsing.py rename to tests/unit/test_error_parsing.py diff --git a/tests/test_initialization.py b/tests/unit/test_initialization.py similarity index 100% rename from tests/test_initialization.py rename to tests/unit/test_initialization.py diff --git a/tests/test_json.py b/tests/unit/test_json.py similarity index 100% rename from tests/test_json.py rename to tests/unit/test_json.py diff --git a/tests/unit/test_langchain.py b/tests/unit/test_langchain.py new file mode 100644 index 000000000..b4c1ba2ee --- /dev/null +++ b/tests/unit/test_langchain.py @@ -0,0 +1,168 @@ +from unittest.mock import patch + +import pytest +from langchain.messages import HumanMessage +from langchain_core.messages import AIMessage +from langchain_core.output_parsers import StrOutputParser +from langchain_core.outputs import ChatGeneration, ChatResult, Generation, LLMResult +from langchain_core.prompts import ChatPromptTemplate +from langchain_openai import ChatOpenAI, OpenAI + +from langfuse._client.attributes import LangfuseOtelSpanAttributes +from langfuse.langchain import CallbackHandler + + +def _assert_parent_child(parent_span, child_span) -> None: + assert child_span.parent is not None + assert child_span.parent.span_id == parent_span.context.span_id + + +def test_chat_model_callback_exports_generation_span( + langfuse_memory_client, get_span, json_attr +): + response = ChatResult( + generations=[ + ChatGeneration(message=AIMessage(content="bonjour"), text="bonjour") + ], + llm_output={ + "token_usage": { + "prompt_tokens": 4, + "completion_tokens": 2, + "total_tokens": 6, + }, + "model_name": "gpt-4o-mini", + }, + ) + + with patch.object(ChatOpenAI, "_generate", return_value=response): + handler = CallbackHandler() + + with langfuse_memory_client.start_as_current_observation(name="parent"): + ChatOpenAI(api_key="test", temperature=0).invoke( + [HumanMessage(content="hello")], + config={"callbacks": [handler]}, + ) + + langfuse_memory_client.flush() + parent_span = get_span("parent") + generation_span = get_span("ChatOpenAI") + + _assert_parent_child(parent_span, generation_span) + assert ( + generation_span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_TYPE] + == "generation" + ) + assert json_attr(generation_span, LangfuseOtelSpanAttributes.OBSERVATION_INPUT) == [ + {"role": "user", "content": "hello"} + ] + assert json_attr( + generation_span, LangfuseOtelSpanAttributes.OBSERVATION_OUTPUT + ) == { + "role": "assistant", + "content": "bonjour", + } + assert ( + generation_span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_MODEL] + == "gpt-4o-mini" + ) + assert json_attr( + generation_span, LangfuseOtelSpanAttributes.OBSERVATION_USAGE_DETAILS + ) == { + "prompt_tokens": 4, + "completion_tokens": 2, + "total_tokens": 6, + } + + +def test_llm_callback_exports_generation_span(langfuse_memory_client, get_span): + response = LLMResult( + generations=[[Generation(text="sockzilla")]], + llm_output={ + "token_usage": { + "prompt_tokens": 7, + "completion_tokens": 3, + "total_tokens": 10, + }, + "model_name": "gpt-4o-mini-instruct", + }, + ) + + with patch.object(OpenAI, "_generate", return_value=response): + handler = CallbackHandler() + + with langfuse_memory_client.start_as_current_observation(name="parent"): + OpenAI(api_key="test", temperature=0).invoke( + "name a sock company", + config={"callbacks": [handler], "run_name": "sock-name"}, + ) + + langfuse_memory_client.flush() + span = get_span("sock-name") + + assert span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_OUTPUT] == "sockzilla" + assert ( + span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_MODEL] + == "gpt-4o-mini-instruct" + ) + + +def test_lcel_chain_exports_intermediate_chain_spans( + langfuse_memory_client, get_span, find_spans +): + response = ChatResult( + generations=[ + ChatGeneration( + message=AIMessage(content="knock knock"), + text="knock knock", + ) + ], + llm_output={ + "token_usage": { + "prompt_tokens": 4, + "completion_tokens": 2, + "total_tokens": 6, + }, + "model_name": "gpt-4o-mini", + }, + ) + + with patch.object(ChatOpenAI, "_generate", return_value=response): + handler = CallbackHandler() + prompt = ChatPromptTemplate.from_template("tell me a joke about {topic}") + chain = prompt | ChatOpenAI(api_key="test", temperature=0) | StrOutputParser() + + with langfuse_memory_client.start_as_current_observation(name="parent"): + result = chain.invoke({"topic": "otters"}, config={"callbacks": [handler]}) + + assert result == "knock knock" + + langfuse_memory_client.flush() + sequence_span = get_span("RunnableSequence") + prompt_span = get_span("ChatPromptTemplate") + generation_span = get_span("ChatOpenAI") + parser_span = get_span("StrOutputParser") + + _assert_parent_child(sequence_span, prompt_span) + _assert_parent_child(sequence_span, generation_span) + _assert_parent_child(sequence_span, parser_span) + assert len(find_spans("ChatOpenAI")) == 1 + + +def test_chat_model_error_marks_generation_error(langfuse_memory_client, get_span): + with patch.object(ChatOpenAI, "_generate", side_effect=RuntimeError("boom")): + handler = CallbackHandler() + + with langfuse_memory_client.start_as_current_observation(name="parent"): + with pytest.raises(RuntimeError, match="boom"): + ChatOpenAI(api_key="test", temperature=0).invoke( + [HumanMessage(content="hello")], + config={"callbacks": [handler]}, + ) + + langfuse_memory_client.flush() + span = get_span("ChatOpenAI") + + assert span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_LEVEL] == "ERROR" + assert ( + "boom" in span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_STATUS_MESSAGE] + ) diff --git a/tests/test_logger.py b/tests/unit/test_logger.py similarity index 100% rename from tests/test_logger.py rename to tests/unit/test_logger.py diff --git a/tests/test_media.py b/tests/unit/test_media.py similarity index 69% rename from tests/test_media.py rename to tests/unit/test_media.py index 6c095ece1..63df03920 100644 --- a/tests/test_media.py +++ b/tests/unit/test_media.py @@ -1,14 +1,10 @@ import base64 -import re from types import SimpleNamespace from unittest.mock import Mock -from uuid import uuid4 import pytest -from langfuse._client.client import Langfuse from langfuse.media import LangfuseMedia -from tests.utils import get_api # Test data SAMPLE_JPEG_BYTES = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x00H\x00H\x00\x00" @@ -143,61 +139,3 @@ def test_resolve_media_references_uses_configured_httpx_client(): httpx_client.get.assert_called_once_with( "https://example.com/test.jpg", timeout=fetch_timeout_seconds ) - - -def test_replace_media_reference_string_in_object(): - # Create test audio file - audio_file = "static/joke_prompt.wav" - with open(audio_file, "rb") as f: - mock_audio_bytes = f.read() - - # Create Langfuse client and trace with media - langfuse = Langfuse() - - mock_trace_name = f"test-trace-with-audio-{uuid4()}" - base64_audio = base64.b64encode(mock_audio_bytes).decode() - - span = langfuse.start_observation( - name=mock_trace_name, - metadata={ - "context": { - "nested": LangfuseMedia( - base64_data_uri=f"data:audio/wav;base64,{base64_audio}" - ) - } - }, - ).end() - - langfuse.flush() - - # Verify media reference string format - fetched_trace = get_api().trace.get(span.trace_id) - media_ref = fetched_trace.observations[0].metadata["context"]["nested"] - assert re.match( - r"^@@@langfuseMedia:type=audio/wav\|id=.+\|source=base64_data_uri@@@$", - media_ref, - ) - - # Resolve media references back to base64 - resolved_obs = langfuse.resolve_media_references( - obj=fetched_trace.observations[0], resolve_with="base64_data_uri" - ) - - # Verify resolved base64 matches original - expected_base64 = f"data:audio/wav;base64,{base64_audio}" - assert resolved_obs["metadata"]["context"]["nested"] == expected_base64 - - # Create second trace reusing the media reference - span2 = langfuse.start_observation( - name=f"2-{mock_trace_name}", - metadata={"context": {"nested": resolved_obs["metadata"]["context"]["nested"]}}, - ).end() - - langfuse.flush() - - # Verify second trace has same media reference - fetched_trace2 = get_api().trace.get(span2.trace_id) - assert ( - fetched_trace2.observations[0].metadata["context"]["nested"] - == fetched_trace.observations[0].metadata["context"]["nested"] - ) diff --git a/tests/test_media_manager.py b/tests/unit/test_media_manager.py similarity index 100% rename from tests/test_media_manager.py rename to tests/unit/test_media_manager.py diff --git a/tests/unit/test_openai.py b/tests/unit/test_openai.py new file mode 100644 index 000000000..6ef51ff54 --- /dev/null +++ b/tests/unit/test_openai.py @@ -0,0 +1,238 @@ +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from langfuse._client.attributes import LangfuseOtelSpanAttributes +from langfuse.openai import openai as lf_openai + + +def test_chat_completion_exports_generation_span( + langfuse_memory_client, get_span, json_attr +): + openai_client = lf_openai.OpenAI(api_key="test") + response = SimpleNamespace( + model="gpt-4o-mini", + choices=[ + SimpleNamespace( + message=SimpleNamespace( + role="assistant", + content="2", + function_call=None, + tool_calls=None, + audio=None, + ) + ) + ], + usage=SimpleNamespace(prompt_tokens=3, completion_tokens=1, total_tokens=4), + ) + + with patch.object(openai_client.chat.completions, "_post", return_value=response): + result = openai_client.chat.completions.create( + name="unit-openai-chat", + model="gpt-4o-mini", + messages=[{"role": "user", "content": "1 + 1 = ?"}], + temperature=0, + metadata={"suite": "unit"}, + ) + + assert result is response + + langfuse_memory_client.flush() + span = get_span("unit-openai-chat") + + assert span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_TYPE] == "generation" + assert ( + span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_MODEL] == "gpt-4o-mini" + ) + assert span.attributes["langfuse.observation.metadata.suite"] == "unit" + assert json_attr(span, LangfuseOtelSpanAttributes.OBSERVATION_INPUT) == [ + {"role": "user", "content": "1 + 1 = ?"} + ] + assert json_attr(span, LangfuseOtelSpanAttributes.OBSERVATION_OUTPUT) == { + "role": "assistant", + "content": "2", + } + assert json_attr(span, LangfuseOtelSpanAttributes.OBSERVATION_MODEL_PARAMETERS) == { + "temperature": 0, + "max_tokens": "Infinity", + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + } + assert json_attr(span, LangfuseOtelSpanAttributes.OBSERVATION_USAGE_DETAILS) == { + "prompt_tokens": 3, + "completion_tokens": 1, + "total_tokens": 4, + } + + +def test_streaming_chat_completion_exports_ttft( + langfuse_memory_client, get_span, json_attr +): + openai_client = lf_openai.OpenAI(api_key="test") + usage = SimpleNamespace(prompt_tokens=3, completion_tokens=1, total_tokens=4) + + def fake_stream(): + yield SimpleNamespace( + model="gpt-4o-mini", + choices=[ + SimpleNamespace( + delta=SimpleNamespace( + role="assistant", + content="2", + function_call=None, + tool_calls=None, + ), + finish_reason=None, + ) + ], + usage=None, + ) + yield SimpleNamespace( + model="gpt-4o-mini", + choices=[ + SimpleNamespace( + delta=SimpleNamespace( + role=None, + content=None, + function_call=None, + tool_calls=None, + ), + finish_reason="stop", + ) + ], + usage=usage, + ) + + with patch.object( + openai_client.chat.completions, "_post", return_value=fake_stream() + ): + stream = openai_client.chat.completions.create( + name="unit-openai-stream", + model="gpt-4o-mini", + messages=[{"role": "user", "content": "1 + 1 = ?"}], + temperature=0, + stream=True, + ) + chunks = list(stream) + + assert len(chunks) == 2 + + langfuse_memory_client.flush() + span = get_span("unit-openai-stream") + + assert span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_OUTPUT] == "2" + assert ( + span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_COMPLETION_START_TIME] + is not None + ) + assert span.attributes["langfuse.observation.metadata.finish_reason"] == "stop" + assert json_attr(span, LangfuseOtelSpanAttributes.OBSERVATION_USAGE_DETAILS) == { + "prompt_tokens": 3, + "completion_tokens": 1, + "total_tokens": 4, + } + + +def test_chat_completion_error_marks_generation_error(langfuse_memory_client, get_span): + openai_client = lf_openai.OpenAI(api_key="test") + + with patch.object( + openai_client.chat.completions, + "_post", + side_effect=RuntimeError("boom"), + ): + with pytest.raises(RuntimeError, match="boom"): + openai_client.chat.completions.create( + name="unit-openai-error", + model="gpt-4o-mini", + messages=[{"role": "user", "content": "explode"}], + temperature=0, + ) + + langfuse_memory_client.flush() + span = get_span("unit-openai-error") + + assert span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_LEVEL] == "ERROR" + assert ( + "boom" in span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_STATUS_MESSAGE] + ) + assert LangfuseOtelSpanAttributes.OBSERVATION_OUTPUT not in span.attributes + + +@pytest.mark.asyncio +async def test_async_chat_completion_exports_generation_span( + langfuse_memory_client, get_span, json_attr +): + openai_client = lf_openai.AsyncOpenAI(api_key="test") + response = SimpleNamespace( + model="gpt-4o-mini", + choices=[ + SimpleNamespace( + message=SimpleNamespace( + role="assistant", + content="async result", + function_call=None, + tool_calls=None, + audio=None, + ) + ) + ], + usage=SimpleNamespace(prompt_tokens=5, completion_tokens=2, total_tokens=7), + ) + + with patch.object(openai_client.chat.completions, "_post", return_value=response): + result = await openai_client.chat.completions.create( + name="unit-openai-async", + model="gpt-4o-mini", + messages=[{"role": "user", "content": "hello"}], + temperature=0, + ) + + assert result is response + + langfuse_memory_client.flush() + span = get_span("unit-openai-async") + + assert json_attr(span, LangfuseOtelSpanAttributes.OBSERVATION_OUTPUT) == { + "role": "assistant", + "content": "async result", + } + assert json_attr(span, LangfuseOtelSpanAttributes.OBSERVATION_USAGE_DETAILS) == { + "prompt_tokens": 5, + "completion_tokens": 2, + "total_tokens": 7, + } + + +def test_embedding_exports_dimensions_and_count( + langfuse_memory_client, get_span, json_attr +): + openai_client = lf_openai.OpenAI(api_key="test") + response = SimpleNamespace( + model="text-embedding-3-small", + data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])], + usage=SimpleNamespace(prompt_tokens=2, total_tokens=2), + ) + + with patch.object(openai_client.embeddings, "_post", return_value=response): + result = openai_client.embeddings.create( + name="unit-openai-embedding", + model="text-embedding-3-small", + input="hello world", + ) + + assert result is response + + langfuse_memory_client.flush() + span = get_span("unit-openai-embedding") + + assert span.attributes[LangfuseOtelSpanAttributes.OBSERVATION_TYPE] == "embedding" + assert json_attr(span, LangfuseOtelSpanAttributes.OBSERVATION_OUTPUT) == { + "dimensions": 3, + "count": 1, + } + assert json_attr(span, LangfuseOtelSpanAttributes.OBSERVATION_USAGE_DETAILS) == { + "input": 2 + } diff --git a/tests/test_openai_prompt_extraction.py b/tests/unit/test_openai_prompt_extraction.py similarity index 100% rename from tests/test_openai_prompt_extraction.py rename to tests/unit/test_openai_prompt_extraction.py diff --git a/tests/test_otel.py b/tests/unit/test_otel.py similarity index 100% rename from tests/test_otel.py rename to tests/unit/test_otel.py diff --git a/tests/test_parse_usage_model.py b/tests/unit/test_parse_usage_model.py similarity index 100% rename from tests/test_parse_usage_model.py rename to tests/unit/test_parse_usage_model.py diff --git a/tests/unit/test_prompt.py b/tests/unit/test_prompt.py new file mode 100644 index 000000000..071cff837 --- /dev/null +++ b/tests/unit/test_prompt.py @@ -0,0 +1,673 @@ +from time import sleep +from unittest.mock import Mock, patch + +import pytest + +from langfuse._client.client import Langfuse +from langfuse._utils.prompt_cache import ( + DEFAULT_PROMPT_CACHE_TTL_SECONDS, + PromptCache, + PromptCacheItem, +) +from langfuse.api import NotFoundError, Prompt_Chat, Prompt_Text +from langfuse.model import ChatPromptClient, TextPromptClient + + +@pytest.mark.parametrize( + ("variables", "placeholders", "expected_len", "expected_contents"), + [ + ( + {"role": "helpful", "task": "coding"}, + {}, + 3, + ["You are a helpful assistant", None, "Help me with coding"], + ), + ( + {}, + {}, + 3, + ["You are a {{role}} assistant", None, "Help me with {{task}}"], + ), + ( + {}, + { + "examples": [ + {"role": "user", "content": "Example question"}, + {"role": "assistant", "content": "Example answer"}, + ], + }, + 4, + [ + "You are a {{role}} assistant", + "Example question", + "Example answer", + "Help me with {{task}}", + ], + ), + ( + {"role": "helpful", "task": "coding"}, + { + "examples": [ + {"role": "user", "content": "Show me {{task}}"}, + {"role": "assistant", "content": "Here's {{task}}"}, + ], + }, + 4, + [ + "You are a helpful assistant", + "Show me coding", + "Here's coding", + "Help me with coding", + ], + ), + ( + {"role": "helpful", "task": "coding"}, + {"unused": [{"role": "user", "content": "Won't appear"}]}, + 3, + ["You are a helpful assistant", None, "Help me with coding"], + ), + ( + {"role": "helpful", "task": "coding"}, + {"examples": "not a list"}, + 3, + [ + "You are a helpful assistant", + "not a list", + "Help me with coding", + ], + ), + ( + {"role": "helpful", "task": "coding"}, + { + "examples": [ + "invalid message", + {"role": "user", "content": "valid message"}, + ] + }, + 4, + [ + "You are a helpful assistant", + "['invalid message', {'role': 'user', 'content': 'valid message'}]", + "valid message", + "Help me with coding", + ], + ), + ], +) +def test_compile_with_placeholders( + variables, placeholders, expected_len, expected_contents +) -> None: + mock_prompt = Prompt_Chat( + name="test_prompt", + version=1, + type="chat", + config={}, + tags=[], + labels=[], + prompt=[ + {"role": "system", "content": "You are a {{role}} assistant"}, + {"type": "placeholder", "name": "examples"}, + {"role": "user", "content": "Help me with {{task}}"}, + ], + ) + + compile_kwargs = {**placeholders, **variables} + result = ChatPromptClient(mock_prompt).compile(**compile_kwargs) + + assert len(result) == expected_len + for i, expected_content in enumerate(expected_contents): + if expected_content is None: + assert "type" in result[i] and result[i]["type"] == "placeholder" + elif isinstance(result[i], str): + assert result[i] == expected_content + else: + assert "content" in result[i] + assert result[i]["content"] == expected_content + + +@pytest.fixture +def langfuse(): + from langfuse._client.resource_manager import LangfuseResourceManager + + langfuse_instance = Langfuse() + langfuse_instance.api = Mock() + + if langfuse_instance._resources is None: + langfuse_instance._resources = Mock(spec=LangfuseResourceManager) + langfuse_instance._resources.prompt_cache = PromptCache() + + return langfuse_instance + + +def test_get_fresh_prompt(langfuse): + prompt_name = "test_get_fresh_prompt" + prompt = Prompt_Text( + name=prompt_name, + version=1, + prompt="Make me laugh", + type="text", + labels=[], + config={}, + tags=[], + ) + + mock_server_call = langfuse.api.prompts.get + mock_server_call.return_value = prompt + + result = langfuse.get_prompt(prompt_name, fallback="fallback") + mock_server_call.assert_called_once_with( + prompt_name, + version=None, + label=None, + request_options=None, + ) + + assert result == TextPromptClient(prompt) + + +def test_throw_if_name_unspecified(langfuse): + with pytest.raises(ValueError) as exc_info: + langfuse.get_prompt("") + + assert "Prompt name cannot be empty" in str(exc_info.value) + + +def test_throw_when_failing_fetch_and_no_cache(langfuse): + mock_server_call = langfuse.api.prompts.get + mock_server_call.side_effect = Exception("Prompt not found") + + with pytest.raises(Exception) as exc_info: + langfuse.get_prompt("failing_fetch_and_no_cache") + + assert "Prompt not found" in str(exc_info.value) + + +def test_using_custom_prompt_timeouts(langfuse): + prompt_name = "test_using_custom_prompt_timeouts" + prompt = Prompt_Text( + name=prompt_name, + version=1, + prompt="Make me laugh", + type="text", + labels=[], + config={}, + tags=[], + ) + + mock_server_call = langfuse.api.prompts.get + mock_server_call.return_value = prompt + + result = langfuse.get_prompt( + prompt_name, fallback="fallback", fetch_timeout_seconds=1000 + ) + mock_server_call.assert_called_once_with( + prompt_name, + version=None, + label=None, + request_options={"timeout_in_seconds": 1000}, + ) + + assert result == TextPromptClient(prompt) + + +def test_throw_if_cache_ttl_seconds_positional_argument(langfuse): + with pytest.raises(TypeError) as exc_info: + langfuse.get_prompt("test ttl seconds in positional arg", 20) + + assert "positional arguments" in str(exc_info.value) + + +def test_get_valid_cached_prompt(langfuse): + prompt_name = "test_get_valid_cached_prompt" + prompt = Prompt_Text( + name=prompt_name, + version=1, + prompt="Make me laugh", + type="text", + labels=[], + config={}, + tags=[], + ) + prompt_client = TextPromptClient(prompt) + + mock_server_call = langfuse.api.prompts.get + mock_server_call.return_value = prompt + + result_call_1 = langfuse.get_prompt(prompt_name, fallback="fallback") + assert mock_server_call.call_count == 1 + assert result_call_1 == prompt_client + + result_call_2 = langfuse.get_prompt(prompt_name) + assert mock_server_call.call_count == 1 + assert result_call_2 == prompt_client + + +def test_get_valid_cached_chat_prompt_by_label(langfuse): + prompt_name = "test_get_valid_cached_chat_prompt_by_label" + prompt = Prompt_Chat( + name=prompt_name, + version=1, + prompt=[{"role": "system", "content": "Make me laugh"}], + labels=["test"], + type="chat", + config={}, + tags=[], + ) + prompt_client = ChatPromptClient(prompt) + + mock_server_call = langfuse.api.prompts.get + mock_server_call.return_value = prompt + + result_call_1 = langfuse.get_prompt(prompt_name, label="test") + assert mock_server_call.call_count == 1 + assert result_call_1 == prompt_client + + result_call_2 = langfuse.get_prompt(prompt_name, label="test") + assert mock_server_call.call_count == 1 + assert result_call_2 == prompt_client + + +def test_get_valid_cached_chat_prompt_by_version(langfuse): + prompt_name = "test_get_valid_cached_chat_prompt_by_version" + prompt = Prompt_Chat( + name=prompt_name, + version=1, + prompt=[{"role": "system", "content": "Make me laugh"}], + labels=["test"], + type="chat", + config={}, + tags=[], + ) + prompt_client = ChatPromptClient(prompt) + + mock_server_call = langfuse.api.prompts.get + mock_server_call.return_value = prompt + + result_call_1 = langfuse.get_prompt(prompt_name, version=1) + assert mock_server_call.call_count == 1 + assert result_call_1 == prompt_client + + result_call_2 = langfuse.get_prompt(prompt_name, version=1) + assert mock_server_call.call_count == 1 + assert result_call_2 == prompt_client + + +def test_get_valid_cached_production_chat_prompt(langfuse): + prompt_name = "test_get_valid_cached_production_chat_prompt" + prompt = Prompt_Chat( + name=prompt_name, + version=1, + prompt=[{"role": "system", "content": "Make me laugh"}], + labels=["test"], + type="chat", + config={}, + tags=[], + ) + prompt_client = ChatPromptClient(prompt) + + mock_server_call = langfuse.api.prompts.get + mock_server_call.return_value = prompt + + result_call_1 = langfuse.get_prompt(prompt_name) + assert mock_server_call.call_count == 1 + assert result_call_1 == prompt_client + + result_call_2 = langfuse.get_prompt(prompt_name, label="production") + assert mock_server_call.call_count == 1 + assert result_call_2 == prompt_client + + +def test_get_valid_cached_chat_prompt(langfuse): + prompt_name = "test_get_valid_cached_chat_prompt" + prompt = Prompt_Chat( + name=prompt_name, + version=1, + prompt=[{"role": "system", "content": "Make me laugh"}], + labels=[], + type="chat", + config={}, + tags=[], + ) + prompt_client = ChatPromptClient(prompt) + + mock_server_call = langfuse.api.prompts.get + mock_server_call.return_value = prompt + + result_call_1 = langfuse.get_prompt(prompt_name) + assert mock_server_call.call_count == 1 + assert result_call_1 == prompt_client + + result_call_2 = langfuse.get_prompt(prompt_name) + assert mock_server_call.call_count == 1 + assert result_call_2 == prompt_client + + +@patch.object(PromptCacheItem, "get_epoch_seconds") +def test_get_fresh_prompt_when_expired_cache_custom_ttl(mock_time, langfuse: Langfuse): + mock_time.return_value = 0 + ttl_seconds = 20 + + prompt_name = "test_get_fresh_prompt_when_expired_cache_custom_ttl" + prompt = Prompt_Text( + name=prompt_name, + version=1, + prompt="Make me laugh", + config={"temperature": 0.9}, + labels=[], + type="text", + tags=[], + ) + prompt_client = TextPromptClient(prompt) + + mock_server_call = langfuse.api.prompts.get + mock_server_call.return_value = prompt + + result_call_1 = langfuse.get_prompt(prompt_name, cache_ttl_seconds=ttl_seconds) + assert mock_server_call.call_count == 1 + assert result_call_1 == prompt_client + + mock_time.return_value = ttl_seconds - 1 + + result_call_2 = langfuse.get_prompt(prompt_name) + assert mock_server_call.call_count == 1 + assert result_call_2 == prompt_client + + mock_time.return_value = ttl_seconds + 1 + + result_call_3 = langfuse.get_prompt(prompt_name) + + while True: + if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0: + break + sleep(0.1) + + assert mock_server_call.call_count == 2 + assert result_call_3 == prompt_client + + +@patch.object(PromptCacheItem, "get_epoch_seconds") +def test_disable_caching_when_ttl_zero(mock_time, langfuse: Langfuse): + mock_time.return_value = 0 + prompt_name = "test_disable_caching_when_ttl_zero" + + prompt1 = Prompt_Text( + name=prompt_name, + version=1, + prompt="Make me laugh", + labels=[], + type="text", + config={}, + tags=[], + ) + prompt2 = Prompt_Text( + name=prompt_name, + version=2, + prompt="Tell me a joke", + labels=[], + type="text", + config={}, + tags=[], + ) + prompt3 = Prompt_Text( + name=prompt_name, + version=3, + prompt="Share a funny story", + labels=[], + type="text", + config={}, + tags=[], + ) + + mock_server_call = langfuse.api.prompts.get + mock_server_call.side_effect = [prompt1, prompt2, prompt3] + + result1 = langfuse.get_prompt(prompt_name, cache_ttl_seconds=0) + assert mock_server_call.call_count == 1 + assert result1 == TextPromptClient(prompt1) + + result2 = langfuse.get_prompt(prompt_name, cache_ttl_seconds=0) + assert mock_server_call.call_count == 2 + assert result2 == TextPromptClient(prompt2) + + result3 = langfuse.get_prompt(prompt_name, cache_ttl_seconds=0) + assert mock_server_call.call_count == 3 + assert result3 == TextPromptClient(prompt3) + + assert result1 != result2 != result3 + + +@patch.object(PromptCacheItem, "get_epoch_seconds") +def test_get_stale_prompt_when_expired_cache_default_ttl(mock_time, langfuse: Langfuse): + import logging + + logging.basicConfig(level=logging.DEBUG) + mock_time.return_value = 0 + + prompt_name = "test_get_stale_prompt_when_expired_cache_default_ttl" + prompt = Prompt_Text( + name=prompt_name, + version=1, + prompt="Make me laugh", + labels=[], + type="text", + config={}, + tags=[], + ) + prompt_client = TextPromptClient(prompt) + + mock_server_call = langfuse.api.prompts.get + mock_server_call.return_value = prompt + + result_call_1 = langfuse.get_prompt(prompt_name) + assert mock_server_call.call_count == 1 + assert result_call_1 == prompt_client + + updated_prompt = Prompt_Text( + name=prompt_name, + version=2, + prompt="Make me laugh", + labels=[], + type="text", + config={}, + tags=[], + ) + mock_server_call.return_value = updated_prompt + + mock_time.return_value = DEFAULT_PROMPT_CACHE_TTL_SECONDS + 1 + + stale_result = langfuse.get_prompt(prompt_name) + assert stale_result == prompt_client + + langfuse.get_prompt(prompt_name) + langfuse.get_prompt(prompt_name) + langfuse.get_prompt(prompt_name) + langfuse.get_prompt(prompt_name) + + while True: + if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0: + break + sleep(0.1) + + assert mock_server_call.call_count == 2 + + updated_result = langfuse.get_prompt(prompt_name) + assert updated_result.version == 2 + assert updated_result == TextPromptClient(updated_prompt) + + +@patch.object(PromptCacheItem, "get_epoch_seconds") +def test_get_fresh_prompt_when_expired_cache_default_ttl(mock_time, langfuse: Langfuse): + mock_time.return_value = 0 + + prompt_name = "test_get_fresh_prompt_when_expired_cache_default_ttl" + prompt = Prompt_Text( + name=prompt_name, + version=1, + prompt="Make me laugh", + labels=[], + type="text", + config={}, + tags=[], + ) + prompt_client = TextPromptClient(prompt) + + mock_server_call = langfuse.api.prompts.get + mock_server_call.return_value = prompt + + result_call_1 = langfuse.get_prompt(prompt_name) + assert mock_server_call.call_count == 1 + assert result_call_1 == prompt_client + + mock_time.return_value = DEFAULT_PROMPT_CACHE_TTL_SECONDS - 1 + + result_call_2 = langfuse.get_prompt(prompt_name) + assert mock_server_call.call_count == 1 + assert result_call_2 == prompt_client + + mock_time.return_value = DEFAULT_PROMPT_CACHE_TTL_SECONDS + 1 + + result_call_3 = langfuse.get_prompt(prompt_name) + while True: + if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0: + break + sleep(0.1) + + assert mock_server_call.call_count == 2 + assert result_call_3 == prompt_client + + +@patch.object(PromptCacheItem, "get_epoch_seconds") +def test_get_expired_prompt_when_failing_fetch(mock_time, langfuse: Langfuse): + mock_time.return_value = 0 + + prompt_name = "test_get_expired_prompt_when_failing_fetch" + prompt = Prompt_Text( + name=prompt_name, + version=1, + prompt="Make me laugh", + labels=[], + type="text", + config={}, + tags=[], + ) + prompt_client = TextPromptClient(prompt) + + mock_server_call = langfuse.api.prompts.get + mock_server_call.return_value = prompt + + result_call_1 = langfuse.get_prompt(prompt_name) + assert mock_server_call.call_count == 1 + assert result_call_1 == prompt_client + + mock_time.return_value = DEFAULT_PROMPT_CACHE_TTL_SECONDS + 1 + mock_server_call.side_effect = Exception("Server error") + + result_call_2 = langfuse.get_prompt(prompt_name, max_retries=1) + while True: + if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0: + break + sleep(0.1) + + assert mock_server_call.call_count == 3 + assert result_call_2 == prompt_client + + +@patch.object(PromptCacheItem, "get_epoch_seconds") +def test_evict_prompt_cache_entry_when_refresh_returns_not_found( + mock_time, langfuse: Langfuse +) -> None: + mock_time.return_value = 0 + + prompt_name = "test_evict_prompt_cache_entry_when_refresh_returns_not_found" + ttl_seconds = 5 + fallback_prompt = "fallback text prompt" + + prompt = Prompt_Text( + name=prompt_name, + version=1, + prompt="Make me laugh", + labels=[], + type="text", + config={}, + tags=[], + ) + prompt_client = TextPromptClient(prompt) + cache_key = PromptCache.generate_cache_key(prompt_name, version=None, label=None) + + mock_server_call = langfuse.api.prompts.get + mock_server_call.return_value = prompt + + initial_result = langfuse.get_prompt( + prompt_name, + cache_ttl_seconds=ttl_seconds, + max_retries=0, + ) + assert initial_result == prompt_client + assert langfuse._resources.prompt_cache.get(cache_key) is not None + + mock_time.return_value = ttl_seconds + 1 + + def raise_not_found(*_args: object, **_kwargs: object) -> None: + raise NotFoundError({"message": "Prompt not found"}) + + mock_server_call.side_effect = raise_not_found + + stale_result = langfuse.get_prompt( + prompt_name, + cache_ttl_seconds=ttl_seconds, + max_retries=0, + ) + assert stale_result == prompt_client + + while True: + if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0: + break + sleep(0.1) + + assert langfuse._resources.prompt_cache.get(cache_key) is None + + fallback_result = langfuse.get_prompt( + prompt_name, + cache_ttl_seconds=ttl_seconds, + fallback=fallback_prompt, + max_retries=0, + ) + assert fallback_result.is_fallback + assert fallback_result.prompt == fallback_prompt + + +def test_get_fresh_prompt_when_version_changes(langfuse: Langfuse): + prompt_name = "test_get_fresh_prompt_when_version_changes" + prompt = Prompt_Text( + name=prompt_name, + version=1, + prompt="Make me laugh", + labels=[], + type="text", + config={}, + tags=[], + ) + prompt_client = TextPromptClient(prompt) + + mock_server_call = langfuse.api.prompts.get + mock_server_call.return_value = prompt + + result_call_1 = langfuse.get_prompt(prompt_name, version=1) + assert mock_server_call.call_count == 1 + assert result_call_1 == prompt_client + + version_changed_prompt = Prompt_Text( + name=prompt_name, + version=2, + labels=[], + prompt="Make me laugh", + type="text", + config={}, + tags=[], + ) + version_changed_prompt_client = TextPromptClient(version_changed_prompt) + mock_server_call.return_value = version_changed_prompt + + result_call_2 = langfuse.get_prompt(prompt_name, version=2) + assert mock_server_call.call_count == 2 + assert result_call_2 == version_changed_prompt_client diff --git a/tests/test_prompt_atexit.py b/tests/unit/test_prompt_atexit.py similarity index 100% rename from tests/test_prompt_atexit.py rename to tests/unit/test_prompt_atexit.py diff --git a/tests/test_prompt_compilation.py b/tests/unit/test_prompt_compilation.py similarity index 100% rename from tests/test_prompt_compilation.py rename to tests/unit/test_prompt_compilation.py diff --git a/tests/test_propagate_attributes.py b/tests/unit/test_propagate_attributes.py similarity index 99% rename from tests/test_propagate_attributes.py rename to tests/unit/test_propagate_attributes.py index b3be9f830..18b84cd52 100644 --- a/tests/test_propagate_attributes.py +++ b/tests/unit/test_propagate_attributes.py @@ -17,7 +17,7 @@ from langfuse._client.constants import LANGFUSE_SDK_EXPERIMENT_ENVIRONMENT from langfuse._client.datasets import DatasetClient from langfuse.api import Dataset, DatasetItem, DatasetStatus -from tests.test_otel import TestOTelBase +from tests.unit.test_otel import TestOTelBase class TestPropagateAttributesBase(TestOTelBase): diff --git a/tests/test_resource_manager.py b/tests/unit/test_resource_manager.py similarity index 92% rename from tests/test_resource_manager.py rename to tests/unit/test_resource_manager.py index 72f9f7d7e..02cda7e76 100644 --- a/tests/test_resource_manager.py +++ b/tests/unit/test_resource_manager.py @@ -5,11 +5,15 @@ from langfuse._client.resource_manager import LangfuseResourceManager -def test_get_client_preserves_all_settings(): +def test_get_client_preserves_all_settings(monkeypatch): """Test that get_client() preserves environment and all client settings.""" with LangfuseResourceManager._lock: LangfuseResourceManager._instances.clear() + monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "pk-comprehensive-default") + monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-comprehensive-default") + monkeypatch.setenv("LANGFUSE_BASE_URL", "http://localhost:3000") + def should_export(span): return span.name != "drop" diff --git a/tests/test_serializer.py b/tests/unit/test_serializer.py similarity index 100% rename from tests/test_serializer.py rename to tests/unit/test_serializer.py diff --git a/tests/test_span_filter.py b/tests/unit/test_span_filter.py similarity index 100% rename from tests/test_span_filter.py rename to tests/unit/test_span_filter.py diff --git a/tests/test_utils.py b/tests/unit/test_utils.py similarity index 100% rename from tests/test_utils.py rename to tests/unit/test_utils.py diff --git a/tests/test_version.py b/tests/unit/test_version.py similarity index 100% rename from tests/test_version.py rename to tests/unit/test_version.py From 38b5c9a7e47a62ef055523c200f8648a01681012 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 15:45:21 +0200 Subject: [PATCH 02/23] speed up unit test suite --- .github/workflows/ci.yml | 2 +- langfuse/_utils/prompt_cache.py | 35 ++++++++++++--------- tests/unit/test_otel.py | 42 +++++++++++++++---------- tests/unit/test_prompt.py | 30 ++++++------------ tests/unit/test_prompt_atexit.py | 20 +++++++----- tests/unit/test_propagate_attributes.py | 12 ++----- tests/unit/test_resource_manager.py | 3 ++ 7 files changed, 73 insertions(+), 71 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 201bb35a3..db65ffa01 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -89,7 +89,7 @@ jobs: - name: Run the automated tests run: | python --version - uv run --frozen pytest -n auto --dist loadfile -s -v --log-cli-level=INFO tests/unit + uv run --frozen pytest -n auto --dist worksteal -s -v --log-cli-level=INFO tests/unit e2e-tests: runs-on: ubuntu-latest diff --git a/langfuse/_utils/prompt_cache.py b/langfuse/_utils/prompt_cache.py index ef465b038..b927bc22f 100644 --- a/langfuse/_utils/prompt_cache.py +++ b/langfuse/_utils/prompt_cache.py @@ -3,7 +3,7 @@ import atexit import os from datetime import datetime -from queue import Empty, Queue +from queue import Queue from threading import Thread from typing import Callable, Dict, List, Optional, Set @@ -18,6 +18,7 @@ ) DEFAULT_PROMPT_CACHE_REFRESH_WORKERS = 1 +_SHUTDOWN_SENTINEL = object() class PromptCacheItem: @@ -46,26 +47,29 @@ def __init__(self, queue: Queue, identifier: int): def run(self) -> None: while self.running: + task = self._queue.get() + + if task is _SHUTDOWN_SENTINEL: + self._queue.task_done() + continue + + logger.debug( + f"PromptCacheRefreshConsumer processing task, {self._identifier}" + ) try: - task = self._queue.get(timeout=1) - logger.debug( - f"PromptCacheRefreshConsumer processing task, {self._identifier}" + task() + # Task failed, but we still consider it processed + except Exception as e: + logger.warning( + f"PromptCacheRefreshConsumer encountered an error, cache was not refreshed: {self._identifier}, {e}" ) - try: - task() - # Task failed, but we still consider it processed - except Exception as e: - logger.warning( - f"PromptCacheRefreshConsumer encountered an error, cache was not refreshed: {self._identifier}, {e}" - ) - self._queue.task_done() - except Empty: - pass + self._queue.task_done() def pause(self) -> None: """Pause the consumer.""" self.running = False + self._queue.put(_SHUTDOWN_SENTINEL) class PromptCacheTaskManager(object): @@ -99,6 +103,9 @@ def add_task(self, key: str, task: Callable[[], None]) -> None: def active_tasks(self) -> int: return len(self._processing_keys) + def wait_for_idle(self) -> None: + self._queue.join() + def _wrap_task(self, key: str, task: Callable[[], None]) -> Callable[[], None]: def wrapped() -> None: logger.debug(f"Refreshing prompt cache for key: {key}") diff --git a/tests/unit/test_otel.py b/tests/unit/test_otel.py index b4b985780..cd0520ad0 100644 --- a/tests/unit/test_otel.py +++ b/tests/unit/test_otel.py @@ -54,10 +54,17 @@ class TestOTelBase: @pytest.fixture(scope="function", autouse=True) def cleanup_otel(self): """Reset OpenTelemetry state between tests.""" - original_provider = trace_api.get_tracer_provider() + from opentelemetry.util._once import Once + + trace_api._TRACER_PROVIDER = None + trace_api._PROXY_TRACER_PROVIDER = trace_api.ProxyTracerProvider() + trace_api._TRACER_PROVIDER_SET_ONCE = Once() + yield - trace_api.set_tracer_provider(original_provider) LangfuseResourceManager.reset() + trace_api._TRACER_PROVIDER = None + trace_api._PROXY_TRACER_PROVIDER = trace_api.ProxyTracerProvider() + trace_api._TRACER_PROVIDER_SET_ONCE = Once() @pytest.fixture def memory_exporter(self): @@ -97,7 +104,7 @@ def mock_init(self, **kwargs): self, span_exporter=memory_exporter, max_export_batch_size=512, - schedule_delay_millis=5000, + schedule_delay_millis=1, ) monkeypatch.setattr( @@ -1870,7 +1877,7 @@ def update_random_metadata(thread_id): update = random.choice(updates) # Sleep a tiny bit to simulate work and increase chances of thread interleaving - time.sleep(random.uniform(0.001, 0.01)) + time.sleep(random.uniform(0.0005, 0.001)) # Apply the update to current_metadata (in a real system, this would update OTEL span) with metadata_lock: @@ -2001,7 +2008,7 @@ def mock_processor_init(self, **kwargs): self, span_exporter=exporter, max_export_batch_size=512, - schedule_delay_millis=5000, + schedule_delay_millis=1, ) monkeypatch.setattr( @@ -2118,7 +2125,7 @@ def create_spans_project1(): metadata={"project": "project1", "index": i}, ) # Small sleep to ensure overlap with other thread - time.sleep(0.01) + time.sleep(0.001) span.end() def create_spans_project2(): @@ -2128,7 +2135,7 @@ def create_spans_project2(): metadata={"project": "project2", "index": i}, ) # Small sleep to ensure overlap with other thread - time.sleep(0.01) + time.sleep(0.001) span.end() # Start threads @@ -2378,7 +2385,7 @@ def mock_processor_init(self, **kwargs): self, span_exporter=exporter, max_export_batch_size=512, - schedule_delay_millis=5000, + schedule_delay_millis=1, ) monkeypatch.setattr( @@ -2757,7 +2764,7 @@ async def async_task(parent_span, task_id): child_span = parent_span.start_observation(name=f"async-task-{task_id}") # Simulate async work - await asyncio.sleep(0.1) + await asyncio.sleep(0.01) # Update span with results child_span.update( @@ -2948,7 +2955,7 @@ async def test_span_metadata_updates_in_async_context( # Define async tasks that update different parts of metadata async def update_temperature(): - await asyncio.sleep(0.1) # Simulate some async work + await asyncio.sleep(0.01) # Simulate some async work main_span.update( metadata={ "llm_config": { @@ -2960,7 +2967,7 @@ async def update_temperature(): ) async def update_model(): - await asyncio.sleep(0.05) # Simulate some async work + await asyncio.sleep(0.005) # Simulate some async work main_span.update( metadata={ "llm_config": { @@ -2970,7 +2977,7 @@ async def update_model(): ) async def add_context_length(): - await asyncio.sleep(0.15) # Simulate some async work + await asyncio.sleep(0.015) # Simulate some async work main_span.update( metadata={ "llm_config": { @@ -2982,7 +2989,7 @@ async def add_context_length(): ) async def update_user_id(): - await asyncio.sleep(0.08) # Simulate some async work + await asyncio.sleep(0.008) # Simulate some async work main_span.update( metadata={ "request_info": { @@ -3047,7 +3054,7 @@ def test_metrics_and_timing(self, langfuse_client, memory_exporter): span = langfuse_client.start_observation(name="timing-test-span") # Add a small delay - time.sleep(0.1) + time.sleep(0.01) # End the span span.end() @@ -3089,10 +3096,10 @@ def test_metrics_and_timing(self, langfuse_client, memory_exporter): ) / 1_000_000_000 assert span_duration_seconds > 0, "Span duration should be positive" - # Since we slept for 0.1 seconds, the span duration should be at least 0.05 seconds + # Since we slept for 0.01 seconds, the span duration should be at least 0.005 seconds # but we'll be generous with the upper bound due to potential system delays - assert span_duration_seconds >= 0.05, ( - f"Span duration ({span_duration_seconds}s) should be at least 0.05s" + assert span_duration_seconds >= 0.005, ( + f"Span duration ({span_duration_seconds}s) should be at least 0.005s" ) @@ -3349,6 +3356,7 @@ def langfuse_client(self, monkeypatch): public_key="test-public-key", secret_key="test-secret-key", base_url="http://test-host", + tracing_enabled=False, ) return client diff --git a/tests/unit/test_prompt.py b/tests/unit/test_prompt.py index 071cff837..dca601b06 100644 --- a/tests/unit/test_prompt.py +++ b/tests/unit/test_prompt.py @@ -1,4 +1,3 @@ -from time import sleep from unittest.mock import Mock, patch import pytest @@ -139,6 +138,10 @@ def langfuse(): return langfuse_instance +def wait_for_prompt_refresh(langfuse: Langfuse) -> None: + langfuse._resources.prompt_cache._task_manager.wait_for_idle() + + def test_get_fresh_prompt(langfuse): prompt_name = "test_get_fresh_prompt" prompt = Prompt_Text( @@ -376,10 +379,7 @@ def test_get_fresh_prompt_when_expired_cache_custom_ttl(mock_time, langfuse: Lan result_call_3 = langfuse.get_prompt(prompt_name) - while True: - if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0: - break - sleep(0.1) + wait_for_prompt_refresh(langfuse) assert mock_server_call.call_count == 2 assert result_call_3 == prompt_client @@ -483,10 +483,7 @@ def test_get_stale_prompt_when_expired_cache_default_ttl(mock_time, langfuse: La langfuse.get_prompt(prompt_name) langfuse.get_prompt(prompt_name) - while True: - if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0: - break - sleep(0.1) + wait_for_prompt_refresh(langfuse) assert mock_server_call.call_count == 2 @@ -527,10 +524,7 @@ def test_get_fresh_prompt_when_expired_cache_default_ttl(mock_time, langfuse: La mock_time.return_value = DEFAULT_PROMPT_CACHE_TTL_SECONDS + 1 result_call_3 = langfuse.get_prompt(prompt_name) - while True: - if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0: - break - sleep(0.1) + wait_for_prompt_refresh(langfuse) assert mock_server_call.call_count == 2 assert result_call_3 == prompt_client @@ -563,10 +557,7 @@ def test_get_expired_prompt_when_failing_fetch(mock_time, langfuse: Langfuse): mock_server_call.side_effect = Exception("Server error") result_call_2 = langfuse.get_prompt(prompt_name, max_retries=1) - while True: - if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0: - break - sleep(0.1) + wait_for_prompt_refresh(langfuse) assert mock_server_call.call_count == 3 assert result_call_2 == prompt_client @@ -619,10 +610,7 @@ def raise_not_found(*_args: object, **_kwargs: object) -> None: ) assert stale_result == prompt_client - while True: - if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0: - break - sleep(0.1) + wait_for_prompt_refresh(langfuse) assert langfuse._resources.prompt_cache.get(cache_key) is None diff --git a/tests/unit/test_prompt_atexit.py b/tests/unit/test_prompt_atexit.py index 2eac27ceb..ea64dcfe3 100644 --- a/tests/unit/test_prompt_atexit.py +++ b/tests/unit/test_prompt_atexit.py @@ -20,13 +20,15 @@ def test_prompts_atexit(): print("Adding prompt cache", PromptCache) prompt_cache = PromptCache(max_prompt_refresh_workers=10) -# example task that takes 2 seconds but we will force it to exit earlier -def wait_2_sec(): - time.sleep(2) +# example task that stays in flight briefly while the process exits +def wait_briefly(): + time.sleep(0.1) # 8 times for i in range(8): - prompt_cache.add_refresh_prompt_task(f"key_wait_2_sec_i_{i}", lambda: wait_2_sec()) + prompt_cache.add_refresh_prompt_task( + f"key_wait_briefly_i_{i}", lambda: wait_briefly() + ) """ process = subprocess.Popen( @@ -74,12 +76,14 @@ async def main(): print("Adding prompt cache", PromptCache) prompt_cache = PromptCache(max_prompt_refresh_workers=10) - # example task that takes 2 seconds but we will force it to exit earlier - def wait_2_sec(): - time.sleep(2) + # example task that stays in flight briefly while the process exits + def wait_briefly(): + time.sleep(0.1) async def add_new_prompt_refresh(i: int): - prompt_cache.add_refresh_prompt_task(f"key_wait_2_sec_i_{i}", lambda: wait_2_sec()) + prompt_cache.add_refresh_prompt_task( + f"key_wait_briefly_i_{i}", lambda: wait_briefly() + ) # 8 times tasks = [add_new_prompt_refresh(i) for i in range(8)] diff --git a/tests/unit/test_propagate_attributes.py b/tests/unit/test_propagate_attributes.py index 18b84cd52..1e8753437 100644 --- a/tests/unit/test_propagate_attributes.py +++ b/tests/unit/test_propagate_attributes.py @@ -6,7 +6,6 @@ """ import concurrent.futures -import time from datetime import datetime import pytest @@ -1460,7 +1459,7 @@ async def create_trace_with_user(user_id: str): """Create a trace with specific user_id.""" with langfuse_client.start_as_current_observation(name=f"trace-{user_id}"): with propagate_attributes(user_id=user_id): - await asyncio.sleep(0.01) # Simulate async work + await asyncio.sleep(0.001) # Simulate async work span = langfuse_client.start_observation(name=f"span-{user_id}") span.end() @@ -2305,7 +2304,7 @@ async def test_experiment_propagates_user_id_in_async_context( local_data = [{"input": "test input", "expected_output": "expected output"}] async def async_task(*, item, **kwargs): - await asyncio.sleep(0.01) + await asyncio.sleep(0.001) return f"processed: {item['input']}" with propagate_attributes(user_id="async-experiment-user"): @@ -2316,7 +2315,6 @@ async def async_task(*, item, **kwargs): ) langfuse_client.flush() - time.sleep(0.1) root_span = self.get_span_by_name(memory_exporter, "experiment-item-run") self.verify_span_attribute( @@ -2361,7 +2359,6 @@ def task_with_child_spans(*, item, **kwargs): # Flush to ensure spans are exported langfuse_client.flush() - time.sleep(0.1) # Get the root span root_spans = self.get_spans_by_name(memory_exporter, "experiment-item-run") @@ -2556,7 +2553,6 @@ def task_with_children(*, item, **kwargs): ) langfuse_client.flush() - time.sleep(0.1) # Verify root has dataset-specific attributes root_spans = self.get_spans_by_name(memory_exporter, "experiment-item-run") @@ -2684,7 +2680,6 @@ def task_with_nested_spans(*, item, **kwargs): ) langfuse_client.flush() - time.sleep(0.1) root_spans = self.get_spans_by_name(memory_exporter, "experiment-item-run") first_root = root_spans[0] @@ -2742,8 +2737,6 @@ def task_with_nested_spans(*, item, **kwargs): def test_experiment_metadata_merging(self, langfuse_client, memory_exporter): """Test that experiment metadata and item metadata are both propagated correctly.""" - import time - from langfuse._client.attributes import _serialize # Rich metadata @@ -2780,7 +2773,6 @@ def task_with_child(*, item, **kwargs): ) langfuse_client.flush() - time.sleep(0.1) # Verify root span has environment set root_span = self.get_span_by_name(memory_exporter, "experiment-item-run") diff --git a/tests/unit/test_resource_manager.py b/tests/unit/test_resource_manager.py index 02cda7e76..0afce9621 100644 --- a/tests/unit/test_resource_manager.py +++ b/tests/unit/test_resource_manager.py @@ -25,6 +25,7 @@ def should_export(span): "sample_rate": 0.8, "should_export_span": should_export, "additional_headers": {"X-Custom": "value"}, + "tracing_enabled": False, } original_client = Langfuse(**settings) @@ -62,6 +63,7 @@ def should_export_b(span): "timeout": 10, "sample_rate": 0.5, "should_export_span": should_export_a, + "tracing_enabled": False, } # Settings for client B @@ -73,6 +75,7 @@ def should_export_b(span): "timeout": 20, "sample_rate": 0.9, "should_export_span": should_export_b, + "tracing_enabled": False, } client_a = Langfuse(**settings_a) From 48fcd294de25a0ae015eb6c91bdad463b1ea1d17 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 16:00:34 +0200 Subject: [PATCH 03/23] speed up unit shutdown without weakening assertions --- langfuse/_task_manager/media_manager.py | 14 +++++++ .../_task_manager/media_upload_consumer.py | 1 + .../_task_manager/score_ingestion_consumer.py | 13 ++++++- tests/unit/test_otel.py | 8 ++-- tests/unit/test_prompt_atexit.py | 16 ++++---- tests/unit/test_resource_manager.py | 37 +++++++++++++++++++ 6 files changed, 76 insertions(+), 13 deletions(-) diff --git a/langfuse/_task_manager/media_manager.py b/langfuse/_task_manager/media_manager.py index d1da17a27..598f5b879 100644 --- a/langfuse/_task_manager/media_manager.py +++ b/langfuse/_task_manager/media_manager.py @@ -18,6 +18,7 @@ T = TypeVar("T") P = ParamSpec("P") +_SHUTDOWN_SENTINEL = object() class MediaManager: @@ -40,6 +41,11 @@ def __init__( def process_next_media_upload(self) -> None: try: upload_job = self._queue.get(block=True, timeout=1) + + if upload_job is _SHUTDOWN_SENTINEL: + self._queue.task_done() + return + logger.debug( f"Media: Processing upload for media_id={upload_job['media_id']} in trace_id={upload_job['trace_id']}" ) @@ -54,6 +60,14 @@ def process_next_media_upload(self) -> None: ) self._queue.task_done() + def signal_shutdown(self) -> None: + try: + self._queue.put(_SHUTDOWN_SENTINEL, block=False) + except Full: + # If the queue is full, the consumer will keep draining work and + # observe the paused flag on the next loop iteration. + pass + def _find_and_process_media( self, *, diff --git a/langfuse/_task_manager/media_upload_consumer.py b/langfuse/_task_manager/media_upload_consumer.py index b9058066b..6f69d4363 100644 --- a/langfuse/_task_manager/media_upload_consumer.py +++ b/langfuse/_task_manager/media_upload_consumer.py @@ -42,3 +42,4 @@ def pause(self) -> None: f"Thread: Pausing media upload consumer thread #{self._identifier}" ) self.running = False + self._media_manager.signal_shutdown() diff --git a/langfuse/_task_manager/score_ingestion_consumer.py b/langfuse/_task_manager/score_ingestion_consumer.py index dcb575263..ea5c2b34e 100644 --- a/langfuse/_task_manager/score_ingestion_consumer.py +++ b/langfuse/_task_manager/score_ingestion_consumer.py @@ -2,7 +2,7 @@ import os import threading import time -from queue import Empty, Queue +from queue import Empty, Full, Queue from typing import Any, List, Optional import backoff @@ -17,6 +17,7 @@ MAX_EVENT_SIZE_BYTES = int(os.environ.get("LANGFUSE_MAX_EVENT_SIZE_BYTES", 1_000_000)) MAX_BATCH_SIZE_BYTES = int(os.environ.get("LANGFUSE_MAX_BATCH_SIZE_BYTES", 2_500_000)) +_SHUTDOWN_SENTINEL = object() class ScoreIngestionMetadata(BaseModel): @@ -71,6 +72,10 @@ def _next(self) -> list: block=True, timeout=self._flush_interval - elapsed ) + if event is _SHUTDOWN_SENTINEL: + self._ingestion_queue.task_done() + break + # convert pydantic models to dicts if "body" in event and isinstance(event["body"], BaseModel): event["body"] = event["body"].model_dump(exclude_none=True) @@ -139,6 +144,12 @@ def upload(self) -> None: def pause(self) -> None: """Pause the consumer.""" self.running = False + try: + self._ingestion_queue.put(_SHUTDOWN_SENTINEL, block=False) + except Full: + # If the queue is full, the consumer will wake up naturally while + # draining items, so a dedicated shutdown signal is not required. + pass def _upload_batch(self, batch: List[Any]) -> None: logger.debug( diff --git a/tests/unit/test_otel.py b/tests/unit/test_otel.py index cd0520ad0..e7eb74280 100644 --- a/tests/unit/test_otel.py +++ b/tests/unit/test_otel.py @@ -3054,7 +3054,7 @@ def test_metrics_and_timing(self, langfuse_client, memory_exporter): span = langfuse_client.start_observation(name="timing-test-span") # Add a small delay - time.sleep(0.01) + time.sleep(0.1) # End the span span.end() @@ -3096,10 +3096,10 @@ def test_metrics_and_timing(self, langfuse_client, memory_exporter): ) / 1_000_000_000 assert span_duration_seconds > 0, "Span duration should be positive" - # Since we slept for 0.01 seconds, the span duration should be at least 0.005 seconds + # Since we slept for 0.1 seconds, the span duration should be at least 0.05 seconds # but we'll be generous with the upper bound due to potential system delays - assert span_duration_seconds >= 0.005, ( - f"Span duration ({span_duration_seconds}s) should be at least 0.005s" + assert span_duration_seconds >= 0.05, ( + f"Span duration ({span_duration_seconds}s) should be at least 0.05s" ) diff --git a/tests/unit/test_prompt_atexit.py b/tests/unit/test_prompt_atexit.py index ea64dcfe3..ccd1b0f19 100644 --- a/tests/unit/test_prompt_atexit.py +++ b/tests/unit/test_prompt_atexit.py @@ -20,14 +20,14 @@ def test_prompts_atexit(): print("Adding prompt cache", PromptCache) prompt_cache = PromptCache(max_prompt_refresh_workers=10) -# example task that stays in flight briefly while the process exits -def wait_briefly(): - time.sleep(0.1) +# example task that takes 2 seconds but we will force it to exit earlier +def wait_2_sec(): + time.sleep(2) # 8 times for i in range(8): prompt_cache.add_refresh_prompt_task( - f"key_wait_briefly_i_{i}", lambda: wait_briefly() + f"key_wait_2_sec_i_{i}", lambda: wait_2_sec() ) """ @@ -76,13 +76,13 @@ async def main(): print("Adding prompt cache", PromptCache) prompt_cache = PromptCache(max_prompt_refresh_workers=10) - # example task that stays in flight briefly while the process exits - def wait_briefly(): - time.sleep(0.1) + # example task that takes 2 seconds but we will force it to exit earlier + def wait_2_sec(): + time.sleep(2) async def add_new_prompt_refresh(i: int): prompt_cache.add_refresh_prompt_task( - f"key_wait_briefly_i_{i}", lambda: wait_briefly() + f"key_wait_2_sec_i_{i}", lambda: wait_2_sec() ) # 8 times diff --git a/tests/unit/test_resource_manager.py b/tests/unit/test_resource_manager.py index 0afce9621..fbe65ba85 100644 --- a/tests/unit/test_resource_manager.py +++ b/tests/unit/test_resource_manager.py @@ -1,8 +1,14 @@ """Test the LangfuseResourceManager and get_client() function.""" +from queue import Queue +from unittest.mock import Mock + from langfuse import Langfuse from langfuse._client.get_client import get_client from langfuse._client.resource_manager import LangfuseResourceManager +from langfuse._task_manager.media_manager import MediaManager +from langfuse._task_manager.media_upload_consumer import MediaUploadConsumer +from langfuse._task_manager.score_ingestion_consumer import ScoreIngestionConsumer def test_get_client_preserves_all_settings(monkeypatch): @@ -101,3 +107,34 @@ def should_export_b(span): client_a.shutdown() client_b.shutdown() + + +def test_score_ingestion_consumer_pause_wakes_blocked_thread(): + consumer = ScoreIngestionConsumer( + ingestion_queue=Queue(), + identifier=0, + client=Mock(), + public_key="pk-test", + flush_interval=30, + ) + + consumer.start() + consumer.pause() + consumer.join(timeout=0.5) + + assert not consumer.is_alive() + + +def test_media_upload_consumer_pause_wakes_blocked_thread(): + media_manager = MediaManager( + api_client=Mock(), + httpx_client=Mock(), + media_upload_queue=Queue(), + ) + consumer = MediaUploadConsumer(identifier=0, media_manager=media_manager) + + consumer.start() + consumer.pause() + consumer.join(timeout=0.5) + + assert not consumer.is_alive() From a20e8113de283b5e8e88ec9d99d89e04f1ca996f Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 17:08:18 +0200 Subject: [PATCH 04/23] fix post-merge propagate attributes test --- tests/unit/test_propagate_attributes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/test_propagate_attributes.py b/tests/unit/test_propagate_attributes.py index 630ed4d15..67cd703c3 100644 --- a/tests/unit/test_propagate_attributes.py +++ b/tests/unit/test_propagate_attributes.py @@ -2492,7 +2492,6 @@ def test_experiment_id_is_stable_across_local_items( ) langfuse_client.flush() - time.sleep(0.1) root_spans = self.get_spans_by_name(memory_exporter, "experiment-item-run") experiment_ids = { From 4cd23c38c48d4cc795cf316737647ebf9e8bf6dc Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 17:39:13 +0200 Subject: [PATCH 05/23] stabilize e2e readbacks in ci --- .github/workflows/ci.yml | 2 +- tests/e2e/test_core_sdk.py | 10 +-- tests/support/api_wrapper.py | 32 +++++---- tests/support/retry.py | 64 ++++++++++++++++++ tests/support/utils.py | 59 +++++++++++++++-- tests/unit/test_e2e_support.py | 116 +++++++++++++++++++++++++++++++++ 6 files changed, 259 insertions(+), 24 deletions(-) create mode 100644 tests/support/retry.py create mode 100644 tests/unit/test_e2e_support.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d6120beb4..991c6e1e6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -199,7 +199,7 @@ jobs: - name: Run the end-to-end tests run: | python --version - uv run --frozen pytest -s -v --log-cli-level=INFO tests/e2e + uv run --frozen pytest -n auto --dist loadfile -s -v --log-cli-level=INFO tests/e2e all-tests-passed: # This allows us to have a branch protection rule for tests and deploys with matrix diff --git a/tests/e2e/test_core_sdk.py b/tests/e2e/test_core_sdk.py index da00b1748..ece3ee3a2 100644 --- a/tests/e2e/test_core_sdk.py +++ b/tests/e2e/test_core_sdk.py @@ -152,7 +152,7 @@ def test_create_session_score(): sleep(2) # Retrieve and verify - score = langfuse.api.scores.get_by_id(score_id) + score = get_api().scores.get_by_id(score_id) # find the score by name (server may transform the id format) assert score is not None @@ -1841,18 +1841,18 @@ def test_get_observations(): def test_get_trace_not_found(): # Attempt to fetch a non-existent trace using the API with pytest.raises(Exception): - get_api().trace.get(create_uuid()) + get_api(retry=False).trace.get(create_uuid()) def test_get_observation_not_found(): # Attempt to fetch a non-existent observation using the API with pytest.raises(Exception): - get_api().legacy.observations_v1.get(create_uuid()) + get_api(retry=False).legacy.observations_v1.get(create_uuid()) def test_get_traces_empty(): # Fetch traces with a filter that should return no results - response = get_api().trace.list(name=create_uuid()) + response = get_api(retry=False).trace.list(name=create_uuid()) assert len(response.data) == 0 assert response.meta.total_items == 0 @@ -1860,7 +1860,7 @@ def test_get_traces_empty(): def test_get_observations_empty(): # Fetch observations with a filter that should return no results - response = get_api().legacy.observations_v1.get_many(name=create_uuid()) + response = get_api(retry=False).legacy.observations_v1.get_many(name=create_uuid()) assert len(response.data) == 0 assert response.meta.total_items == 0 diff --git a/tests/support/api_wrapper.py b/tests/support/api_wrapper.py index 6067e6bfa..f4d66f00a 100644 --- a/tests/support/api_wrapper.py +++ b/tests/support/api_wrapper.py @@ -1,8 +1,10 @@ import os -from time import sleep import httpx +from langfuse.api.commons.errors.not_found_error import NotFoundError +from tests.support.retry import is_not_found_payload, retry_until_ready + class LangfuseAPI: def __init__(self, username=None, password=None, base_url=None): @@ -11,28 +13,32 @@ def __init__(self, username=None, password=None, base_url=None): self.auth = (username, password) self.BASE_URL = base_url if base_url else os.environ["LANGFUSE_BASE_URL"] + def _get_json(self, url, params=None): + def _request(): + response = httpx.get(url, params=params, auth=self.auth) + payload = response.json() + + if response.status_code == 404 and is_not_found_payload(payload): + raise NotFoundError(body=payload, headers=dict(response.headers)) + + return payload + + return retry_until_ready(_request) + def get_observation(self, observation_id): - sleep(1) url = f"{self.BASE_URL}/api/public/observations/{observation_id}" - response = httpx.get(url, auth=self.auth) - return response.json() + return self._get_json(url) def get_scores(self, page=None, limit=None, user_id=None, name=None): - sleep(1) params = {"page": page, "limit": limit, "userId": user_id, "name": name} url = f"{self.BASE_URL}/api/public/scores" - response = httpx.get(url, params=params, auth=self.auth) - return response.json() + return self._get_json(url, params=params) def get_traces(self, page=None, limit=None, user_id=None, name=None): - sleep(1) params = {"page": page, "limit": limit, "userId": user_id, "name": name} url = f"{self.BASE_URL}/api/public/traces" - response = httpx.get(url, params=params, auth=self.auth) - return response.json() + return self._get_json(url, params=params) def get_trace(self, trace_id): - sleep(1) url = f"{self.BASE_URL}/api/public/traces/{trace_id}" - response = httpx.get(url, auth=self.auth) - return response.json() + return self._get_json(url) diff --git a/tests/support/retry.py b/tests/support/retry.py new file mode 100644 index 000000000..8922fd1f1 --- /dev/null +++ b/tests/support/retry.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import os +from time import monotonic, sleep +from typing import Callable, TypeVar + +from langfuse.api.commons.errors.not_found_error import NotFoundError +from langfuse.api.core.api_error import ApiError + +T = TypeVar("T") + +DEFAULT_RETRY_TIMEOUT_SECONDS = float( + os.environ.get("LANGFUSE_E2E_READ_TIMEOUT_SECONDS", "12") +) +DEFAULT_RETRY_INTERVAL_SECONDS = float( + os.environ.get("LANGFUSE_E2E_READ_INTERVAL_SECONDS", "0.25") +) + + +def is_eventual_consistency_error(error: Exception) -> bool: + if isinstance(error, NotFoundError): + return True + + if not isinstance(error, ApiError): + return False + + body = error.body + return isinstance(body, dict) and body.get("error") == "LangfuseNotFoundError" + + +def is_not_found_payload(payload: object) -> bool: + return isinstance(payload, dict) and payload.get("error") == "LangfuseNotFoundError" + + +def retry_until_ready( + operation: Callable[[], T], + *, + is_retryable_error: Callable[[Exception], bool] = is_eventual_consistency_error, + is_result_ready: Callable[[T], bool] | None = None, + timeout_seconds: float = DEFAULT_RETRY_TIMEOUT_SECONDS, + interval_seconds: float = DEFAULT_RETRY_INTERVAL_SECONDS, +) -> T: + deadline = monotonic() + timeout_seconds + last_error: Exception | None = None + + while True: + try: + result = operation() + except Exception as error: + if not is_retryable_error(error) or monotonic() >= deadline: + raise + + last_error = error + else: + if is_result_ready is None or is_result_ready(result): + return result + + if monotonic() >= deadline: + return result + + sleep(interval_seconds) + + if monotonic() >= deadline and last_error is not None: + raise last_error diff --git a/tests/support/utils.py b/tests/support/utils.py index 7d6530b45..4a9499693 100644 --- a/tests/support/utils.py +++ b/tests/support/utils.py @@ -1,23 +1,72 @@ import base64 import os -from time import sleep +from typing import Any from uuid import uuid4 from langfuse.api import LangfuseAPI +from tests.support.retry import retry_until_ready + +READ_METHOD_NAMES = {"get", "get_by_id", "get_many", "get_run", "list"} +PAGINATION_ARGUMENTS = {"limit", "page"} + + +def _has_filters(kwargs: dict[str, Any]) -> bool: + return any( + key not in PAGINATION_ARGUMENTS and value is not None + for key, value in kwargs.items() + ) + + +class _RetryingApiProxy: + def __init__(self, target: Any): + self._target = target + + def __getattr__(self, name: str) -> Any: + attr = getattr(self._target, name) + + if callable(attr): + if name not in READ_METHOD_NAMES: + return attr + + def _call(*args: Any, **kwargs: Any) -> Any: + return retry_until_ready( + lambda: attr(*args, **kwargs), + is_result_ready=_result_ready(name, kwargs), + ) + + return _call + + if isinstance(attr, (str, bytes, int, float, bool, list, dict, tuple, set)): + return attr + + if attr is None: + return None + + return _RetryingApiProxy(attr) + + +def _result_ready(method_name: str, kwargs: dict[str, Any]): + if method_name not in {"get_many", "list"} or not _has_filters(kwargs): + return None + + def _has_data(result: Any) -> bool: + data = getattr(result, "data", None) + return data is None or len(data) > 0 + + return _has_data def create_uuid(): return str(uuid4()) -def get_api(): - sleep(2) - - return LangfuseAPI( +def get_api(*, retry: bool = True): + client = LangfuseAPI( username=os.environ.get("LANGFUSE_PUBLIC_KEY"), password=os.environ.get("LANGFUSE_SECRET_KEY"), base_url=os.environ.get("LANGFUSE_BASE_URL"), ) + return _RetryingApiProxy(client) if retry else client def encode_file_to_base64(image_path) -> str: diff --git a/tests/unit/test_e2e_support.py b/tests/unit/test_e2e_support.py new file mode 100644 index 000000000..71932e8a5 --- /dev/null +++ b/tests/unit/test_e2e_support.py @@ -0,0 +1,116 @@ +from types import SimpleNamespace + +from langfuse.api.commons.errors.not_found_error import NotFoundError +from tests.support.api_wrapper import LangfuseAPI as SupportLangfuseAPI +from tests.support.utils import get_api + + +def test_get_api_retries_not_found(monkeypatch): + monkeypatch.setattr("tests.support.retry.sleep", lambda _: None) + + attempts = {"count": 0} + + class FakeTraceService: + def get(self, trace_id): + attempts["count"] += 1 + + if attempts["count"] < 3: + raise NotFoundError( + body={ + "error": "LangfuseNotFoundError", + "message": f"Trace {trace_id} not found within authorized project", + } + ) + + return {"id": trace_id} + + class FakeClient: + trace = FakeTraceService() + + monkeypatch.setattr("tests.support.utils.LangfuseAPI", lambda **_: FakeClient()) + + trace = get_api().trace.get("trace-123") + + assert trace == {"id": "trace-123"} + assert attempts["count"] == 3 + + +def test_get_api_retries_filtered_lists(monkeypatch): + monkeypatch.setattr("tests.support.retry.sleep", lambda _: None) + + attempts = {"count": 0} + + class FakeTraceService: + def list(self, **kwargs): + attempts["count"] += 1 + + if attempts["count"] < 3: + return SimpleNamespace(data=[]) + + return SimpleNamespace(data=[kwargs["name"]]) + + class FakeClient: + trace = FakeTraceService() + + monkeypatch.setattr("tests.support.utils.LangfuseAPI", lambda **_: FakeClient()) + + response = get_api().trace.list(name="ready-trace") + + assert response.data == ["ready-trace"] + assert attempts["count"] == 3 + + +def test_get_api_retry_can_be_disabled(monkeypatch): + attempts = {"count": 0} + + class FakeTraceService: + def list(self, **kwargs): + attempts["count"] += 1 + return SimpleNamespace(data=[]) + + class FakeClient: + trace = FakeTraceService() + + monkeypatch.setattr("tests.support.utils.LangfuseAPI", lambda **_: FakeClient()) + + response = get_api(retry=False).trace.list(name="missing-trace") + + assert response.data == [] + assert attempts["count"] == 1 + + +def test_raw_api_wrapper_retries_not_found_payload(monkeypatch): + monkeypatch.setattr("tests.support.retry.sleep", lambda _: None) + + attempts = {"count": 0} + + class FakeResponse: + def __init__(self, status_code, payload): + self.status_code = status_code + self._payload = payload + self.headers = {} + + def json(self): + return self._payload + + def fake_get(*args, **kwargs): + attempts["count"] += 1 + + if attempts["count"] < 3: + return FakeResponse( + 404, + { + "error": "LangfuseNotFoundError", + "message": "Trace trace-123 not found within authorized project", + }, + ) + + return FakeResponse(200, {"id": "trace-123", "observations": []}) + + monkeypatch.setattr("tests.support.api_wrapper.httpx.get", fake_get) + + api = SupportLangfuseAPI(username="user", password="pass", base_url="http://test") + trace = api.get_trace("trace-123") + + assert trace["id"] == "trace-123" + assert attempts["count"] == 3 From 3c3b26422bb056d71a09fe71b99e1b1bfb275cd8 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 19:53:03 +0200 Subject: [PATCH 06/23] stabilize remaining e2e checks --- .github/workflows/ci.yml | 29 +- tests/e2e/test_core_sdk.py | 125 +++++-- tests/e2e/test_decorators.py | 581 +++++++++++++++++++-------------- tests/e2e/test_experiments.py | 8 +- tests/e2e/test_media.py | 23 +- tests/support/api_wrapper.py | 108 +++++- tests/support/utils.py | 40 ++- tests/unit/test_e2e_support.py | 26 +- 8 files changed, 635 insertions(+), 305 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 991c6e1e6..3aeb32748 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -98,6 +98,8 @@ jobs: LANGFUSE_BASE_URL: "http://localhost:3000" LANGFUSE_PUBLIC_KEY: "pk-lf-1234567890" LANGFUSE_SECRET_KEY: "sk-lf-1234567890" + LANGFUSE_E2E_READ_TIMEOUT_SECONDS: "30" + LANGFUSE_E2E_READ_INTERVAL_SECONDS: "0.5" OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} # SERPAPI_API_KEY: ${{ secrets.SERPAPI_API_KEY }} HUGGINGFACEHUB_API_TOKEN: ${{ secrets.HUGGINGFACEHUB_API_TOKEN }} @@ -106,6 +108,16 @@ jobs: name: E2E tests on Python 3.13 steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - name: Install uv and set Python version + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8 + with: + version: "0.11.2" + python-version: "3.13" + enable-cache: true + - name: Install the project dependencies + run: uv sync --locked + - name: Check uv Python version + run: uv run --frozen python --version - uses: pnpm/action-setup@fc06bc1257f339d1d5d8b3a19a8cae5388b55320 # v5 with: version: 10.33.0 @@ -183,23 +195,10 @@ jobs: done echo "Langfuse server is up and running!" - - name: Install uv and set Python version - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8 - with: - version: "0.11.2" - python-version: "3.13" - enable-cache: true - - - name: Check Python version - run: python --version - - - name: Install the project dependencies - run: uv sync --locked - - name: Run the end-to-end tests run: | - python --version - uv run --frozen pytest -n auto --dist loadfile -s -v --log-cli-level=INFO tests/e2e + uv run --frozen python --version + uv run --frozen pytest -n 4 --dist loadfile -s -v --log-cli-level=INFO tests/e2e all-tests-passed: # This allows us to have a branch protection rule for tests and deploys with matrix diff --git a/tests/e2e/test_core_sdk.py b/tests/e2e/test_core_sdk.py index ece3ee3a2..f3cc542c5 100644 --- a/tests/e2e/test_core_sdk.py +++ b/tests/e2e/test_core_sdk.py @@ -13,6 +13,8 @@ from tests.support.utils import ( create_uuid, get_api, + wait_for_result, + wait_for_trace, ) @@ -228,7 +230,7 @@ def test_create_boolean_score(): # Ensure data is sent langfuse.flush() - sleep(2) + api_wrapper.get_trace(trace_id) # Create a boolean score score_id = create_uuid() @@ -251,10 +253,14 @@ def test_create_boolean_score(): # Ensure data is sent langfuse.flush() - sleep(2) # Retrieve and verify - trace = api_wrapper.get_trace(trace_id) + trace = api_wrapper.get_trace( + trace_id, + is_result_ready=lambda trace: any( + score["name"] == "this-is-a-score" for score in trace.get("scores", []) + ), + ) # Find the score we created by name created_score = next( @@ -283,7 +289,7 @@ def test_create_categorical_score(): # Ensure data is sent langfuse.flush() - sleep(2) + api_wrapper.get_trace(trace_id) # Create a categorical score score_id = create_uuid() @@ -305,10 +311,14 @@ def test_create_categorical_score(): # Ensure data is sent langfuse.flush() - sleep(2) # Retrieve and verify - trace = api_wrapper.get_trace(trace_id) + trace = api_wrapper.get_trace( + trace_id, + is_result_ready=lambda trace: any( + score["name"] == "this-is-a-score" for score in trace.get("scores", []) + ), + ) # Find the score we created by name created_score = next( @@ -337,7 +347,7 @@ def test_create_score_with_custom_timestamp(): # Ensure data is sent langfuse.flush() - sleep(2) + api_wrapper.get_trace(trace_id) custom_timestamp = datetime.now(timezone.utc) - timedelta(hours=1) score_id = create_uuid() @@ -352,10 +362,15 @@ def test_create_score_with_custom_timestamp(): # Ensure data is sent langfuse.flush() - sleep(2) # Retrieve and verify - trace = api_wrapper.get_trace(trace_id) + trace = api_wrapper.get_trace( + trace_id, + is_result_ready=lambda trace: any( + score["name"] == "custom-timestamp-score" + for score in trace.get("scores", []) + ), + ) # Find the score we created by name created_score = next( @@ -398,10 +413,18 @@ def test_create_trace(): # Ensure data is sent to the API langfuse.flush() - sleep(2) # Retrieve the trace from the API - trace = LangfuseAPI().get_trace(trace_id) + trace = LangfuseAPI().get_trace( + trace_id, + is_result_ready=lambda trace: ( + trace.get("name") == trace_name + and trace.get("userId") == "test" + and trace.get("metadata", {}).get("key") == "value" + and trace.get("tags") == ["tag1", "tag2"] + and trace.get("public") is True + ), + ) # Verify all trace properties assert trace["name"] == trace_name @@ -437,11 +460,20 @@ def test_create_update_trace(): # Ensure data is sent to the API langfuse.flush() - sleep(2) assert isinstance(trace_id, str) # Retrieve and verify trace - trace = get_api().trace.get(trace_id) + trace = wait_for_trace( + trace_id, + is_result_ready=lambda trace: ( + trace.name == trace_name + and trace.user_id == "test" + and trace.metadata is not None + and trace.metadata.get("key") == "value" + and trace.metadata.get("key2") == "value2" + and trace.public is True + ), + ) assert trace.name == trace_name assert trace.user_id == "test" @@ -1735,16 +1767,20 @@ def test_fetch_traces(): # Ensure data is sent langfuse.flush() - sleep(3) - # Fetch all traces with the same name - # Note: Using session_id in the query is causing a server error, - # but we keep the session_id in the trace data to ensure it's being stored correctly - all_traces = get_api().trace.list(name=name, limit=10) + expected_trace_ids = set(trace_ids) + api = get_api(retry=False) + + # Fetch all traces with the same name. + all_traces = wait_for_result( + lambda: api.trace.list(name=name, limit=10), + is_result_ready=lambda response: ( + {trace.id for trace in response.data} == expected_trace_ids + ), + ) # Verify we got all traces assert len(all_traces.data) == 3 - assert all_traces.meta.total_items == 3 # Verify trace properties for trace in all_traces.data: @@ -1753,11 +1789,19 @@ def test_fetch_traces(): assert trace.input == {"key": "value"} assert trace.output == "output-value" - # Test pagination by fetching just one trace - paginated_response = get_api().trace.list(name=name, limit=1, page=2) - assert len(paginated_response.data) == 1 - assert paginated_response.meta.total_items == 3 - assert paginated_response.meta.total_pages == 3 + # Test pagination by fetching the first three pages one at a time and + # confirming they collectively cover the created traces. + paginated_ids = set() + for page in range(1, 4): + paginated_response = wait_for_result( + lambda page=page: api.trace.list(name=name, limit=1, page=page), + is_result_ready=lambda response: ( + len(response.data) == 1 and response.data[0].id in expected_trace_ids + ), + ) + paginated_ids.add(paginated_response.data[0].id) + + assert paginated_ids == expected_trace_ids def test_get_observation(): @@ -1812,10 +1856,16 @@ def test_get_observations(): # Ensure data is sent langfuse.flush() - sleep(2) + api = get_api(retry=False) # Fetch observations using the API - observations = get_api().legacy.observations_v1.get_many(name=name, limit=10) + expected_generation_ids = {gen1_id, gen2_id} + observations = wait_for_result( + lambda: api.legacy.observations_v1.get_many(name=name, limit=10), + is_result_ready=lambda response: expected_generation_ids.issubset( + {obs.id for obs in response.data} + ), + ) # Verify fetched observations assert len(observations.data) == 2 @@ -1829,13 +1879,22 @@ def test_get_observations(): assert gen1_id in gen_ids assert gen2_id in gen_ids - # Test pagination - paginated_response = get_api().legacy.observations_v1.get_many( - name=name, limit=1, page=2 - ) - assert len(paginated_response.data) == 1 - assert paginated_response.meta.total_items == 2 # Parent span + 2 generations - assert paginated_response.meta.total_pages == 2 + # Test pagination by confirming both created generations can be reached + # across separate pages. + paginated_ids = set() + for page in range(1, 3): + paginated_response = wait_for_result( + lambda page=page: api.legacy.observations_v1.get_many( + name=name, limit=1, page=page + ), + is_result_ready=lambda response: ( + len(response.data) == 1 + and response.data[0].id in expected_generation_ids + ), + ) + paginated_ids.add(paginated_response.data[0].id) + + assert paginated_ids == expected_generation_ids def test_get_trace_not_found(): diff --git a/tests/e2e/test_decorators.py b/tests/e2e/test_decorators.py index 754d41343..ec18584a2 100644 --- a/tests/e2e/test_decorators.py +++ b/tests/e2e/test_decorators.py @@ -16,7 +16,7 @@ from langfuse._client.resource_manager import LangfuseResourceManager from langfuse.langchain import CallbackHandler from langfuse.media import LangfuseMedia -from tests.support.utils import get_api +from tests.support.utils import get_api, wait_for_trace mock_metadata = {"key": "metadata"} mock_deep_metadata = {"key": "mock_deep_metadata"} @@ -32,6 +32,32 @@ def removeMockResourceManagerInstances(): LangfuseResourceManager._instances.pop(public_key) +def _get_observation_by_name(trace_data, name): + return next( + observation + for observation in trace_data.observations + if observation.name == name + ) + + +def _is_descendant(trace_data, child_id, ancestor_id): + observations_by_id = { + observation.id: observation for observation in trace_data.observations + } + current_id = child_id + + while current_id in observations_by_id: + current = observations_by_id[current_id] + parent_id = current.parent_observation_id + if parent_id == ancestor_id: + return True + if parent_id not in observations_by_id: + return False + current_id = parent_id + + return False + + def test_nested_observations(): mock_name = "test_nested_observations" langfuse = get_client() @@ -402,39 +428,44 @@ def level_1_function(*args, **kwargs): langfuse.flush() - trace_data = get_api().trace.get(mock_trace_id) + trace_data = wait_for_trace( + mock_trace_id, + is_result_ready=lambda trace: ( + trace.session_id == mock_session_id + and trace.name == mock_name + and { + "level_1_function", + "level_2_function", + "level_3_function", + "langchain_operations", + "ChatPromptTemplate", + }.issubset({observation.name for observation in trace.observations}) + ), + ) assert len(trace_data.observations) > 2 - # Check correct nesting - adjacencies = defaultdict(list) - for o in trace_data.observations: - adjacencies[o.parent_observation_id].append(o) - - assert len(adjacencies) > 2 - # trace parameters if set anywhere in the call stack assert trace_data.session_id == mock_session_id assert trace_data.name == mock_name - # Check that the langchain_operations is at the correct level - level_1_observation = next( - o - for o in trace_data.observations - if o.parent_observation_id not in [o.id for o in trace_data.observations] - ) - level_2_observation = adjacencies[level_1_observation.id][0] - level_3_observation = adjacencies[level_2_observation.id][0] - langchain_observation = adjacencies[level_3_observation.id][0] + level_1_observation = _get_observation_by_name(trace_data, "level_1_function") + level_2_observation = _get_observation_by_name(trace_data, "level_2_function") + level_3_observation = _get_observation_by_name(trace_data, "level_3_function") + langchain_observation = _get_observation_by_name(trace_data, "langchain_operations") + prompt_observation = _get_observation_by_name(trace_data, "ChatPromptTemplate") assert level_1_observation.name == "level_1_function" + assert _is_descendant(trace_data, level_2_observation.id, level_1_observation.id) assert level_2_observation.name == "level_2_function" assert level_2_observation.metadata["key"] == mock_metadata["key"] + assert _is_descendant(trace_data, level_3_observation.id, level_2_observation.id) assert level_3_observation.name == "level_3_function" assert level_3_observation.metadata["key"] == mock_deep_metadata["key"] + assert _is_descendant(trace_data, langchain_observation.id, level_3_observation.id) assert langchain_observation.name == "langchain_operations" # Check that LangChain components are captured - assert any([o.name == "ChatPromptTemplate" for o in trace_data.observations]) + assert _is_descendant(trace_data, prompt_observation.id, langchain_observation.id) def test_get_current_trace_url(): @@ -488,26 +519,37 @@ def level_1_function(*args, **kwargs): *mock_args, **mock_kwargs, langfuse_trace_id=mock_trace_id ) langfuse.flush() - sleep(1) assert result == "level_3" # Wrapped function returns correctly # ID setting for span or trace - trace_data = get_api().trace.get(mock_trace_id) + trace_data = wait_for_trace( + mock_trace_id, + is_result_ready=lambda trace: { + "test-observation-score", + "test-trace-score", + "another-test-trace-score", + }.issubset({score.name for score in trace.scores}), + ) assert ( len(trace_data.observations) == 3 ) # Top-most function is trace, so it's not an observations assert trace_data.name == mock_name # Check for correct scoring - scores = trace_data.scores + scores_by_name = defaultdict(list) + for score in trace_data.scores: + scores_by_name[score.name].append(score) - assert len(scores) == 3 + assert len(scores_by_name["test-trace-score"]) == 1 + assert len(scores_by_name["another-test-trace-score"]) == 1 + assert len(scores_by_name["test-observation-score"]) == 1 trace_scores = [ - s for s in scores if s.trace_id == mock_trace_id and s.observation_id is None + scores_by_name["test-trace-score"][0], + scores_by_name["another-test-trace-score"][0], ] - observation_score = [s for s in scores if s.observation_id is not None][0] + observation_score = scores_by_name["test-observation-score"][0] assert any( [ @@ -861,27 +903,29 @@ async def level_1_function(*args, **kwargs): assert result == "level_1" # Wrapped function returns correctly # ID setting for span or trace - trace_data = get_api().trace.get(mock_trace_id) - assert len(trace_data.observations) == 3 + trace_data = wait_for_trace( + mock_trace_id, + is_result_ready=lambda trace: ( + trace.session_id == mock_session_id + and trace.name == mock_name + and { + "level_1_function", + "level_2_function", + "OpenAI-generation", + }.issubset({observation.name for observation in trace.observations}) + ), + ) + assert len(trace_data.observations) >= 3 # trace parameters if set anywhere in the call stack assert trace_data.session_id == mock_session_id assert trace_data.name == mock_name - # Check correct nesting - adjacencies = defaultdict(list) - for o in trace_data.observations: - adjacencies[o.parent_observation_id or o.trace_id].append(o) - - assert len(adjacencies) == 3 - - level_1_observation = next( - o - for o in trace_data.observations - if o.parent_observation_id not in [o.id for o in trace_data.observations] - ) - level_2_observation = adjacencies[level_1_observation.id][0] - level_3_observation = adjacencies[level_2_observation.id][0] + level_1_observation = _get_observation_by_name(trace_data, "level_1_function") + level_2_observation = _get_observation_by_name(trace_data, "level_2_function") + level_3_observation = _get_observation_by_name(trace_data, "OpenAI-generation") + assert _is_descendant(trace_data, level_2_observation.id, level_1_observation.id) + assert _is_descendant(trace_data, level_3_observation.id, level_2_observation.id) assert level_2_observation.metadata["key"] == mock_metadata["key"] @@ -1008,8 +1052,14 @@ def function(): assert result == mock_output - trace_data = get_api().trace.get(mock_trace_id) - assert trace_data.observations[0].output == mock_output + trace_data = wait_for_trace( + mock_trace_id, + is_result_ready=lambda trace: any( + observation.name == "function" and observation.output == mock_output + for observation in trace.observations + ), + ) + assert _get_observation_by_name(trace_data, "function").output == mock_output def test_media(): @@ -1049,7 +1099,21 @@ def main(): langfuse.flush() - trace_data = get_api().trace.get(mock_trace_id) + trace_data = wait_for_trace( + mock_trace_id, + is_result_ready=lambda trace: ( + "@@@langfuseMedia:type=application/pdf|id=" + in (trace.input or {}).get("context", {}).get("nested", "") + and "@@@langfuseMedia:type=application/pdf|id=" + in (trace.output or {}).get("context", {}).get("nested", "") + and any( + "@@@langfuseMedia:type=application/pdf|id=" + in observation.metadata.get("context", {}).get("nested", "") + for observation in trace.observations + if observation.metadata + ) + ), + ) assert ( "@@@langfuseMedia:type=application/pdf|id=" @@ -1091,7 +1155,15 @@ def main(): langfuse.flush() - trace_data = get_api().trace.get(mock_trace_id) + trace_data = wait_for_trace( + mock_trace_id, + is_result_ready=lambda trace: ( + trace.metadata is not None + and trace.metadata.get("key1") == "value1" + and trace.metadata.get("key2") == "value2" + and trace.tags == ["tag1", "tag2"] + ), + ) assert trace_data.metadata["key1"] == "value1" assert trace_data.metadata["key2"] == "value2" @@ -1105,124 +1177,141 @@ def test_multiproject_context_propagation_basic(): client1 = Langfuse() # Reads from environment Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2") - # Verify both instances are registered - assert len(LangfuseResourceManager._instances) == 2 + try: + # Verify both instances are registered + assert len(LangfuseResourceManager._instances) == 2 - mock_name = "test_multiproject_context_propagation_basic" - # Use known public key from environment - env_public_key = os.environ[LANGFUSE_PUBLIC_KEY] - # In multi-project setup, must specify which client to use - langfuse = get_client(public_key=env_public_key) - mock_trace_id = langfuse.create_trace_id() + mock_name = "test_multiproject_context_propagation_basic" + # Use known public key from environment + env_public_key = os.environ[LANGFUSE_PUBLIC_KEY] + # In multi-project setup, must specify which client to use + langfuse = get_client(public_key=env_public_key) + mock_trace_id = langfuse.create_trace_id() - @observe(as_type="generation", capture_output=False) - def level_3_function(): - # This function should inherit the public key from level_1_function - # and NOT need langfuse_public_key parameter - langfuse_client = get_client() - langfuse_client.update_current_generation(metadata={"level": "3"}) - with propagate_attributes(trace_name=mock_name): - pass - return "level_3" - - @observe() - def level_2_function(): - # This function should also inherit the public key - level_3_function() - langfuse_client = get_client() - langfuse_client.update_current_span(metadata={"level": "2"}) - return "level_2" + @observe(as_type="generation", capture_output=False) + def level_3_function(): + # This function should inherit the public key from level_1_function + # and NOT need langfuse_public_key parameter + langfuse_client = get_client() + langfuse_client.update_current_generation(metadata={"level": "3"}) + with propagate_attributes(trace_name=mock_name): + pass + return "level_3" - @observe() - def level_1_function(*args, **kwargs): - # Only this top-level function receives langfuse_public_key - level_2_function() - langfuse_client = get_client() - langfuse_client.update_current_span(metadata={"level": "1"}) - return "level_1" + @observe() + def level_2_function(): + # This function should also inherit the public key + level_3_function() + langfuse_client = get_client() + langfuse_client.update_current_span(metadata={"level": "2"}) + return "level_2" - result = level_1_function( - *mock_args, - **mock_kwargs, - langfuse_trace_id=mock_trace_id, - langfuse_public_key=env_public_key, # Only provided to top-level function - ) + @observe() + def level_1_function(*args, **kwargs): + # Only this top-level function receives langfuse_public_key + level_2_function() + langfuse_client = get_client() + langfuse_client.update_current_span(metadata={"level": "1"}) + return "level_1" - # Use the correct client for flushing - client1.flush() + result = level_1_function( + *mock_args, + **mock_kwargs, + langfuse_trace_id=mock_trace_id, + langfuse_public_key=env_public_key, # Only provided to top-level function + ) - assert result == "level_1" + # Use the correct client for flushing + client1.flush() - # Verify trace was created properly - trace_data = get_api().trace.get(mock_trace_id) - assert len(trace_data.observations) == 3 - assert trace_data.name == mock_name + assert result == "level_1" - # Reset instances to not leak to other test suites - removeMockResourceManagerInstances() + # Verify trace was created properly + trace_data = wait_for_trace( + mock_trace_id, + is_result_ready=lambda trace: ( + trace.name == mock_name and len(trace.observations) == 3 + ), + ) + assert len(trace_data.observations) == 3 + assert trace_data.name == mock_name + finally: + removeMockResourceManagerInstances() def test_multiproject_context_propagation_deep_nesting(): client1 = Langfuse() # Reads from environment Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2") - # Verify both instances are registered - assert len(LangfuseResourceManager._instances) == 2 - - mock_name = "test_multiproject_context_propagation_deep_nesting" - env_public_key = os.environ[LANGFUSE_PUBLIC_KEY] - langfuse = get_client(public_key=env_public_key) - mock_trace_id = langfuse.create_trace_id() - - @observe(as_type="generation") - def level_4_function(): - langfuse_client = get_client() - langfuse_client.update_current_generation(metadata={"level": "4"}) - return "level_4" + try: + # Verify both instances are registered + assert len(LangfuseResourceManager._instances) == 2 - @observe() - def level_3_function(): - result = level_4_function() - langfuse_client = get_client() - langfuse_client.update_current_span(metadata={"level": "3"}) - return result + mock_name = "test_multiproject_context_propagation_deep_nesting" + env_public_key = os.environ[LANGFUSE_PUBLIC_KEY] + langfuse = get_client(public_key=env_public_key) + mock_trace_id = langfuse.create_trace_id() - @observe() - def level_2_function(): - result = level_3_function() - langfuse_client = get_client() - langfuse_client.update_current_span(metadata={"level": "2"}) - return result + @observe(as_type="generation") + def level_4_function(): + langfuse_client = get_client() + langfuse_client.update_current_generation(metadata={"level": "4"}) + return "level_4" - @observe() - def level_1_function(*args, **kwargs): - with propagate_attributes(trace_name=mock_name): - result = level_2_function() + @observe() + def level_3_function(): + result = level_4_function() langfuse_client = get_client() - langfuse_client.update_current_span(metadata={"level": "1"}) + langfuse_client.update_current_span(metadata={"level": "3"}) return result - result = level_1_function( - langfuse_trace_id=mock_trace_id, langfuse_public_key=env_public_key - ) - client1.flush() - - assert result == "level_4" - - trace_data = get_api().trace.get(mock_trace_id) - assert len(trace_data.observations) == 4 - assert trace_data.name == mock_name + @observe() + def level_2_function(): + result = level_3_function() + langfuse_client = get_client() + langfuse_client.update_current_span(metadata={"level": "2"}) + return result - # Verify all levels were captured - levels = [ - str(obs.metadata.get("level")) - for obs in trace_data.observations - if obs.metadata - ] - assert set(levels) == {"1", "2", "3", "4"} + @observe() + def level_1_function(*args, **kwargs): + with propagate_attributes(trace_name=mock_name): + result = level_2_function() + langfuse_client = get_client() + langfuse_client.update_current_span(metadata={"level": "1"}) + return result + + result = level_1_function( + langfuse_trace_id=mock_trace_id, langfuse_public_key=env_public_key + ) + client1.flush() + + assert result == "level_4" + + trace_data = wait_for_trace( + mock_trace_id, + is_result_ready=lambda trace: ( + trace.name == mock_name + and len(trace.observations) == 4 + and {"1", "2", "3", "4"} + == { + str(observation.metadata.get("level")) + for observation in trace.observations + if observation.metadata + } + ), + ) + assert len(trace_data.observations) == 4 + assert trace_data.name == mock_name - # Reset instances to not leak to other test suites - removeMockResourceManagerInstances() + # Verify all levels were captured + levels = [ + str(obs.metadata.get("level")) + for obs in trace_data.observations + if obs.metadata + ] + assert set(levels) == {"1", "2", "3", "4"} + finally: + removeMockResourceManagerInstances() def test_multiproject_context_propagation_override(): @@ -1230,52 +1319,59 @@ def test_multiproject_context_propagation_override(): client1 = Langfuse() # Reads from environment client2 = Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2") - # Verify both instances are registered - assert len(LangfuseResourceManager._instances) == 2 - - mock_name = "test_multiproject_context_propagation_override" - env_public_key = os.environ[LANGFUSE_PUBLIC_KEY] - langfuse = get_client(public_key=env_public_key) - mock_trace_id = langfuse.create_trace_id() - - primary_public_key = env_public_key - override_public_key = "pk-test-project2" - - @observe(as_type="generation") - def level_3_function(): - # This function explicitly overrides the inherited public key - langfuse_client = get_client(public_key=override_public_key) - langfuse_client.update_current_generation(metadata={"used_override": "true"}) - return "level_3" + try: + # Verify both instances are registered + assert len(LangfuseResourceManager._instances) == 2 + + mock_name = "test_multiproject_context_propagation_override" + env_public_key = os.environ[LANGFUSE_PUBLIC_KEY] + langfuse = get_client(public_key=env_public_key) + mock_trace_id = langfuse.create_trace_id() + + primary_public_key = env_public_key + override_public_key = "pk-test-project2" + + @observe(as_type="generation") + def level_3_function(): + # This function explicitly overrides the inherited public key + langfuse_client = get_client(public_key=override_public_key) + langfuse_client.update_current_generation( + metadata={"used_override": "true"} + ) + return "level_3" - @observe() - def level_2_function(): - # This function should use the overridden key when calling level_3 - level_3_function(langfuse_public_key=override_public_key) - langfuse_client = get_client(public_key=primary_public_key) - langfuse_client.update_current_span(metadata={"level": "2"}) - return "level_2" + @observe() + def level_2_function(): + # This function should use the overridden key when calling level_3 + level_3_function(langfuse_public_key=override_public_key) + langfuse_client = get_client(public_key=primary_public_key) + langfuse_client.update_current_span(metadata={"level": "2"}) + return "level_2" - @observe() - def level_1_function(*args, **kwargs): - with propagate_attributes(trace_name=mock_name): - level_2_function() - return "level_1" + @observe() + def level_1_function(*args, **kwargs): + with propagate_attributes(trace_name=mock_name): + level_2_function() + return "level_1" - result = level_1_function( - langfuse_trace_id=mock_trace_id, langfuse_public_key=primary_public_key - ) - client1.flush() - client2.flush() + result = level_1_function( + langfuse_trace_id=mock_trace_id, langfuse_public_key=primary_public_key + ) + client1.flush() + client2.flush() - assert result == "level_1" + assert result == "level_1" - trace_data = get_api().trace.get(mock_trace_id) - assert len(trace_data.observations) == 2 - assert trace_data.name == mock_name - - # Reset instances to not leak to other test suites - removeMockResourceManagerInstances() + trace_data = wait_for_trace( + mock_trace_id, + is_result_ready=lambda trace: ( + trace.name == mock_name and len(trace.observations) == 2 + ), + ) + assert len(trace_data.observations) == 2 + assert trace_data.name == mock_name + finally: + removeMockResourceManagerInstances() def test_multiproject_context_propagation_no_public_key(): @@ -1339,68 +1435,79 @@ async def test_multiproject_async_context_propagation_basic(): client1 = Langfuse() # Reads from environment Langfuse(public_key="pk-test-project2", secret_key="sk-test-project2") - # Verify both instances are registered - assert len(LangfuseResourceManager._instances) == 2 - - mock_name = "test_multiproject_async_context_propagation_basic" - env_public_key = os.environ[LANGFUSE_PUBLIC_KEY] - langfuse = get_client(public_key=env_public_key) - mock_trace_id = langfuse.create_trace_id() - - @observe(as_type="generation", capture_output=False) - async def async_level_3_function(): - # This function should inherit the public key from level_1_function - # and NOT need langfuse_public_key parameter - await asyncio.sleep(0.01) # Simulate async work - langfuse_client = get_client() - langfuse_client.update_current_generation( - metadata={"level": "3", "async": True} - ) - with propagate_attributes(trace_name=mock_name): - pass - return "async_level_3" + try: + # Verify both instances are registered + assert len(LangfuseResourceManager._instances) == 2 - @observe() - async def async_level_2_function(): - # This function should also inherit the public key - result = await async_level_3_function() - langfuse_client = get_client() - langfuse_client.update_current_span(metadata={"level": "2", "async": True}) - return result + mock_name = "test_multiproject_async_context_propagation_basic" + env_public_key = os.environ[LANGFUSE_PUBLIC_KEY] + langfuse = get_client(public_key=env_public_key) + mock_trace_id = langfuse.create_trace_id() - @observe() - async def async_level_1_function(*args, **kwargs): - # Only this top-level function receives langfuse_public_key - result = await async_level_2_function() - langfuse_client = get_client() - langfuse_client.update_current_span(metadata={"level": "1", "async": True}) - return result - - result = await async_level_1_function( - *mock_args, - **mock_kwargs, - langfuse_trace_id=mock_trace_id, - langfuse_public_key=env_public_key, # Only provided to top-level function - ) + @observe(as_type="generation", capture_output=False) + async def async_level_3_function(): + # This function should inherit the public key from level_1_function + # and NOT need langfuse_public_key parameter + await asyncio.sleep(0.01) # Simulate async work + langfuse_client = get_client() + langfuse_client.update_current_generation( + metadata={"level": "3", "async": True} + ) + with propagate_attributes(trace_name=mock_name): + pass + return "async_level_3" - # Use the correct client for flushing - client1.flush() + @observe() + async def async_level_2_function(): + # This function should also inherit the public key + result = await async_level_3_function() + langfuse_client = get_client() + langfuse_client.update_current_span(metadata={"level": "2", "async": True}) + return result - assert result == "async_level_3" + @observe() + async def async_level_1_function(*args, **kwargs): + # Only this top-level function receives langfuse_public_key + result = await async_level_2_function() + langfuse_client = get_client() + langfuse_client.update_current_span(metadata={"level": "1", "async": True}) + return result - # Verify trace was created properly - trace_data = get_api().trace.get(mock_trace_id) - assert len(trace_data.observations) == 3 - assert trace_data.name == mock_name + result = await async_level_1_function( + *mock_args, + **mock_kwargs, + langfuse_trace_id=mock_trace_id, + langfuse_public_key=env_public_key, # Only provided to top-level function + ) - # Verify all observations have async metadata - async_flags = [ - obs.metadata.get("async") for obs in trace_data.observations if obs.metadata - ] - assert all(async_flags) + # Use the correct client for flushing + client1.flush() + + assert result == "async_level_3" + + # Verify trace was created properly + trace_data = wait_for_trace( + mock_trace_id, + is_result_ready=lambda trace: ( + trace.name == mock_name + and len(trace.observations) == 3 + and all( + observation.metadata.get("async") + for observation in trace.observations + if observation.metadata + ) + ), + ) + assert len(trace_data.observations) == 3 + assert trace_data.name == mock_name - # Reset instances to not leak to other test suites - removeMockResourceManagerInstances() + # Verify all observations have async metadata + async_flags = [ + obs.metadata.get("async") for obs in trace_data.observations if obs.metadata + ] + assert all(async_flags) + finally: + removeMockResourceManagerInstances() @pytest.mark.asyncio diff --git a/tests/e2e/test_experiments.py b/tests/e2e/test_experiments.py index cd17e80e0..5f18d9c0b 100644 --- a/tests/e2e/test_experiments.py +++ b/tests/e2e/test_experiments.py @@ -12,7 +12,7 @@ ExperimentItem, ExperimentItemResult, ) -from tests.support.utils import create_uuid, get_api +from tests.support.utils import create_uuid, get_api, wait_for_trace @pytest.fixture @@ -786,13 +786,15 @@ def mock_task_with_boolean_results(*, item: ExperimentItem, **kwargs): time.sleep(3) # Verify scores are persisted via API with correct data types - api = get_api() for i, item_result in enumerate(result.item_results): trace_id = item_result.trace_id assert trace_id is not None, f"Item {i} should have a trace_id" # Fetch trace from API to verify score persistence - trace = api.trace.get(trace_id) + trace = wait_for_trace( + trace_id, + is_result_ready=lambda trace: len(trace.scores) > 0, + ) assert trace is not None, f"Trace {trace_id} should exist" for score in trace.scores: diff --git a/tests/e2e/test_media.py b/tests/e2e/test_media.py index b262c0a86..d322e1788 100644 --- a/tests/e2e/test_media.py +++ b/tests/e2e/test_media.py @@ -4,7 +4,7 @@ from langfuse._client.client import Langfuse from langfuse.media import LangfuseMedia -from tests.support.utils import get_api +from tests.support.utils import wait_for_trace def test_replace_media_reference_string_in_object(): @@ -30,7 +30,17 @@ def test_replace_media_reference_string_in_object(): langfuse.flush() - fetched_trace = get_api().trace.get(span.trace_id) + fetched_trace = wait_for_trace( + span.trace_id, + is_result_ready=lambda trace: ( + bool(trace.observations) + and re.match( + r"^@@@langfuseMedia:type=audio/wav\|id=.+\|source=base64_data_uri@@@$", + trace.observations[0].metadata.get("context", {}).get("nested", ""), + ) + is not None + ), + ) media_ref = fetched_trace.observations[0].metadata["context"]["nested"] assert re.match( r"^@@@langfuseMedia:type=audio/wav\|id=.+\|source=base64_data_uri@@@$", @@ -51,7 +61,14 @@ def test_replace_media_reference_string_in_object(): langfuse.flush() - fetched_trace2 = get_api().trace.get(span2.trace_id) + fetched_trace2 = wait_for_trace( + span2.trace_id, + is_result_ready=lambda trace: ( + bool(trace.observations) + and trace.observations[0].metadata.get("context", {}).get("nested") + == fetched_trace.observations[0].metadata["context"]["nested"] + ), + ) assert ( fetched_trace2.observations[0].metadata["context"]["nested"] == fetched_trace.observations[0].metadata["context"]["nested"] diff --git a/tests/support/api_wrapper.py b/tests/support/api_wrapper.py index f4d66f00a..c4519252f 100644 --- a/tests/support/api_wrapper.py +++ b/tests/support/api_wrapper.py @@ -3,7 +3,12 @@ import httpx from langfuse.api.commons.errors.not_found_error import NotFoundError -from tests.support.retry import is_not_found_payload, retry_until_ready +from tests.support.retry import ( + DEFAULT_RETRY_INTERVAL_SECONDS, + DEFAULT_RETRY_TIMEOUT_SECONDS, + is_not_found_payload, + retry_until_ready, +) class LangfuseAPI: @@ -13,7 +18,16 @@ def __init__(self, username=None, password=None, base_url=None): self.auth = (username, password) self.BASE_URL = base_url if base_url else os.environ["LANGFUSE_BASE_URL"] - def _get_json(self, url, params=None): + def _get_json( + self, + url, + params=None, + *, + retry=True, + is_result_ready=None, + timeout_seconds=DEFAULT_RETRY_TIMEOUT_SECONDS, + interval_seconds=DEFAULT_RETRY_INTERVAL_SECONDS, + ): def _request(): response = httpx.get(url, params=params, auth=self.auth) payload = response.json() @@ -23,22 +37,94 @@ def _request(): return payload - return retry_until_ready(_request) + if not retry: + return _request() - def get_observation(self, observation_id): + return retry_until_ready( + _request, + is_result_ready=is_result_ready, + timeout_seconds=timeout_seconds, + interval_seconds=interval_seconds, + ) + + def get_observation( + self, + observation_id, + *, + retry=True, + is_result_ready=None, + timeout_seconds=DEFAULT_RETRY_TIMEOUT_SECONDS, + interval_seconds=DEFAULT_RETRY_INTERVAL_SECONDS, + ): url = f"{self.BASE_URL}/api/public/observations/{observation_id}" - return self._get_json(url) + return self._get_json( + url, + retry=retry, + is_result_ready=is_result_ready, + timeout_seconds=timeout_seconds, + interval_seconds=interval_seconds, + ) - def get_scores(self, page=None, limit=None, user_id=None, name=None): + def get_scores( + self, + page=None, + limit=None, + user_id=None, + name=None, + *, + retry=True, + is_result_ready=None, + timeout_seconds=DEFAULT_RETRY_TIMEOUT_SECONDS, + interval_seconds=DEFAULT_RETRY_INTERVAL_SECONDS, + ): params = {"page": page, "limit": limit, "userId": user_id, "name": name} url = f"{self.BASE_URL}/api/public/scores" - return self._get_json(url, params=params) + return self._get_json( + url, + params=params, + retry=retry, + is_result_ready=is_result_ready, + timeout_seconds=timeout_seconds, + interval_seconds=interval_seconds, + ) - def get_traces(self, page=None, limit=None, user_id=None, name=None): + def get_traces( + self, + page=None, + limit=None, + user_id=None, + name=None, + *, + retry=True, + is_result_ready=None, + timeout_seconds=DEFAULT_RETRY_TIMEOUT_SECONDS, + interval_seconds=DEFAULT_RETRY_INTERVAL_SECONDS, + ): params = {"page": page, "limit": limit, "userId": user_id, "name": name} url = f"{self.BASE_URL}/api/public/traces" - return self._get_json(url, params=params) + return self._get_json( + url, + params=params, + retry=retry, + is_result_ready=is_result_ready, + timeout_seconds=timeout_seconds, + interval_seconds=interval_seconds, + ) - def get_trace(self, trace_id): + def get_trace( + self, + trace_id, + *, + retry=True, + is_result_ready=None, + timeout_seconds=DEFAULT_RETRY_TIMEOUT_SECONDS, + interval_seconds=DEFAULT_RETRY_INTERVAL_SECONDS, + ): url = f"{self.BASE_URL}/api/public/traces/{trace_id}" - return self._get_json(url) + return self._get_json( + url, + retry=retry, + is_result_ready=is_result_ready, + timeout_seconds=timeout_seconds, + interval_seconds=interval_seconds, + ) diff --git a/tests/support/utils.py b/tests/support/utils.py index 4a9499693..a29274d3a 100644 --- a/tests/support/utils.py +++ b/tests/support/utils.py @@ -1,13 +1,18 @@ import base64 import os -from typing import Any +from typing import Any, Callable, TypeVar from uuid import uuid4 from langfuse.api import LangfuseAPI -from tests.support.retry import retry_until_ready +from tests.support.retry import ( + DEFAULT_RETRY_INTERVAL_SECONDS, + DEFAULT_RETRY_TIMEOUT_SECONDS, + retry_until_ready, +) READ_METHOD_NAMES = {"get", "get_by_id", "get_many", "get_run", "list"} PAGINATION_ARGUMENTS = {"limit", "page"} +T = TypeVar("T") def _has_filters(kwargs: dict[str, Any]) -> bool: @@ -69,6 +74,37 @@ def get_api(*, retry: bool = True): return _RetryingApiProxy(client) if retry else client +def wait_for_result( + operation: Callable[[], T], + *, + is_result_ready: Callable[[T], bool] | None = None, + timeout_seconds: float = DEFAULT_RETRY_TIMEOUT_SECONDS, + interval_seconds: float = DEFAULT_RETRY_INTERVAL_SECONDS, +) -> T: + return retry_until_ready( + operation, + is_result_ready=is_result_ready, + timeout_seconds=timeout_seconds, + interval_seconds=interval_seconds, + ) + + +def wait_for_trace( + trace_id: str, + *, + is_result_ready: Callable[[Any], bool] | None = None, + timeout_seconds: float = DEFAULT_RETRY_TIMEOUT_SECONDS, + interval_seconds: float = DEFAULT_RETRY_INTERVAL_SECONDS, +): + api = get_api(retry=False) + return wait_for_result( + lambda: api.trace.get(trace_id), + is_result_ready=is_result_ready, + timeout_seconds=timeout_seconds, + interval_seconds=interval_seconds, + ) + + def encode_file_to_base64(image_path) -> str: with open(image_path, "rb") as file: return base64.b64encode(file.read()).decode("utf-8") diff --git a/tests/unit/test_e2e_support.py b/tests/unit/test_e2e_support.py index 71932e8a5..79d7d5891 100644 --- a/tests/unit/test_e2e_support.py +++ b/tests/unit/test_e2e_support.py @@ -2,7 +2,7 @@ from langfuse.api.commons.errors.not_found_error import NotFoundError from tests.support.api_wrapper import LangfuseAPI as SupportLangfuseAPI -from tests.support.utils import get_api +from tests.support.utils import get_api, wait_for_trace def test_get_api_retries_not_found(monkeypatch): @@ -114,3 +114,27 @@ def fake_get(*args, **kwargs): assert trace["id"] == "trace-123" assert attempts["count"] == 3 + + +def test_wait_for_trace_retries_until_predicate_matches(monkeypatch): + monkeypatch.setattr("tests.support.retry.sleep", lambda _: None) + + attempts = {"count": 0} + + class FakeTraceService: + def get(self, trace_id): + attempts["count"] += 1 + return {"id": trace_id, "observations": [1] * attempts["count"]} + + class FakeClient: + trace = FakeTraceService() + + monkeypatch.setattr("tests.support.utils.LangfuseAPI", lambda **_: FakeClient()) + + trace = wait_for_trace( + "trace-123", is_result_ready=lambda trace: len(trace["observations"]) == 3 + ) + + assert trace["id"] == "trace-123" + assert len(trace["observations"]) == 3 + assert attempts["count"] == 3 From 78395e7c8530123264455a166ca6118a4f2bb325 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 20:02:10 +0200 Subject: [PATCH 07/23] reduce e2e ci load --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3aeb32748..54ec1d7cc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -98,7 +98,7 @@ jobs: LANGFUSE_BASE_URL: "http://localhost:3000" LANGFUSE_PUBLIC_KEY: "pk-lf-1234567890" LANGFUSE_SECRET_KEY: "sk-lf-1234567890" - LANGFUSE_E2E_READ_TIMEOUT_SECONDS: "30" + LANGFUSE_E2E_READ_TIMEOUT_SECONDS: "60" LANGFUSE_E2E_READ_INTERVAL_SECONDS: "0.5" OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} # SERPAPI_API_KEY: ${{ secrets.SERPAPI_API_KEY }} @@ -198,7 +198,7 @@ jobs: - name: Run the end-to-end tests run: | uv run --frozen python --version - uv run --frozen pytest -n 4 --dist loadfile -s -v --log-cli-level=INFO tests/e2e + uv run --frozen pytest -n 2 --dist loadfile -s -v --log-cli-level=INFO tests/e2e all-tests-passed: # This allows us to have a branch protection rule for tests and deploys with matrix From b5b0bc6d7c2c8d7d8a3fd13a114973f4ef1092d5 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 20:07:51 +0200 Subject: [PATCH 08/23] speed up langfuse server startup in ci --- .github/workflows/ci.yml | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 54ec1d7cc..25622546e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -124,45 +124,36 @@ jobs: - name: Clone langfuse server run: | - git clone https://github.com/langfuse/langfuse.git ./langfuse-server && echo $(cd ./langfuse-server && git rev-parse HEAD) + git clone --depth 1 --filter=blob:none --single-branch https://github.com/langfuse/langfuse.git ./langfuse-server + echo "$(cd ./langfuse-server && git rev-parse HEAD)" - name: Setup node (for langfuse server) uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0 with: node-version: 24 - - - name: Cache langfuse server dependencies - uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5 - with: - path: ./langfuse-server/node_modules - key: | - langfuse-server-${{ hashFiles('./langfuse-server/package-lock.json') }} - langfuse-server- + cache: pnpm + cache-dependency-path: ./langfuse-server/pnpm-lock.yaml - name: Run langfuse server run: | cd ./langfuse-server - echo "::group::Run langfuse server" - TELEMETRY_ENABLED=false docker compose up -d postgres - echo "::endgroup::" - - echo "::group::Logs from langfuse server" - TELEMETRY_ENABLED=false docker compose logs + echo "::group::Start backing services" + TELEMETRY_ENABLED=false docker compose up -d postgres redis clickhouse minio echo "::endgroup::" echo "::group::Install dependencies (necessary to run seeder)" - pnpm i + pnpm install --frozen-lockfile --prefer-offline echo "::endgroup::" echo "::group::Seed db" cp .env.dev.example .env pnpm run db:migrate pnpm run db:seed + rm -f .env echo "::endgroup::" - rm -rf .env - echo "::group::Run server" + echo "::group::Start langfuse web and worker" TELEMETRY_ENABLED=false \ LANGFUSE_S3_MEDIA_UPLOAD_ENDPOINT=http://localhost:9090 \ @@ -172,7 +163,7 @@ jobs: QUEUE_CONSUMER_EVENT_PROPAGATION_QUEUE_IS_ENABLED=true \ LANGFUSE_ENABLE_EVENTS_TABLE_V2_APIS=true \ LANGFUSE_ENABLE_EVENTS_TABLE_OBSERVATIONS=true \ - docker compose up -d + docker compose up -d langfuse-web langfuse-worker echo "::endgroup::" From c9ed554fcd01891206cb44eb0be19c9b83c78b96 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 20:19:09 +0200 Subject: [PATCH 09/23] use langfuse init bootstrap in ci --- .github/workflows/ci.yml | 59 ++++++++++++++++------------------------ 1 file changed, 23 insertions(+), 36 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 25622546e..d46ece2de 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -98,6 +98,15 @@ jobs: LANGFUSE_BASE_URL: "http://localhost:3000" LANGFUSE_PUBLIC_KEY: "pk-lf-1234567890" LANGFUSE_SECRET_KEY: "sk-lf-1234567890" + LANGFUSE_INIT_ORG_ID: "0c6c96f4-0ca0-4f16-92a8-6dd7d7c6a501" + LANGFUSE_INIT_ORG_NAME: "SDK Test Org" + LANGFUSE_INIT_PROJECT_ID: "7a88fb47-b4e2-43b8-a06c-a5ce950dc53a" + LANGFUSE_INIT_PROJECT_NAME: "SDK Test Project" + LANGFUSE_INIT_PROJECT_PUBLIC_KEY: "pk-lf-1234567890" + LANGFUSE_INIT_PROJECT_SECRET_KEY: "sk-lf-1234567890" + LANGFUSE_INIT_USER_EMAIL: "sdk-tests@langfuse.local" + LANGFUSE_INIT_USER_NAME: "SDK Tests" + LANGFUSE_INIT_USER_PASSWORD: "langfuse-ci-password" LANGFUSE_E2E_READ_TIMEOUT_SECONDS: "60" LANGFUSE_E2E_READ_INTERVAL_SECONDS: "0.5" OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} @@ -118,44 +127,21 @@ jobs: run: uv sync --locked - name: Check uv Python version run: uv run --frozen python --version - - uses: pnpm/action-setup@fc06bc1257f339d1d5d8b3a19a8cae5388b55320 # v5 - with: - version: 10.33.0 - - - name: Clone langfuse server + - name: Prepare langfuse server compose run: | - git clone --depth 1 --filter=blob:none --single-branch https://github.com/langfuse/langfuse.git ./langfuse-server - echo "$(cd ./langfuse-server && git rev-parse HEAD)" - - - name: Setup node (for langfuse server) - uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0 - with: - node-version: 24 - cache: pnpm - cache-dependency-path: ./langfuse-server/pnpm-lock.yaml + mkdir -p ./langfuse-server + LANGFUSE_SERVER_SHA="$(git ls-remote https://github.com/langfuse/langfuse.git HEAD | cut -f1)" + curl -fsSL "https://raw.githubusercontent.com/langfuse/langfuse/${LANGFUSE_SERVER_SHA}/docker-compose.yml" \ + -o ./langfuse-server/docker-compose.yml + echo "${LANGFUSE_SERVER_SHA}" - name: Run langfuse server run: | cd ./langfuse-server - echo "::group::Start backing services" - TELEMETRY_ENABLED=false docker compose up -d postgres redis clickhouse minio - echo "::endgroup::" - - echo "::group::Install dependencies (necessary to run seeder)" - pnpm install --frozen-lockfile --prefer-offline - echo "::endgroup::" - - echo "::group::Seed db" - cp .env.dev.example .env - pnpm run db:migrate - pnpm run db:seed - rm -f .env - echo "::endgroup::" - - echo "::group::Start langfuse web and worker" - + echo "::group::Start langfuse server" TELEMETRY_ENABLED=false \ + NEXT_PUBLIC_LANGFUSE_RUN_NEXT_INIT=true \ LANGFUSE_S3_MEDIA_UPLOAD_ENDPOINT=http://localhost:9090 \ LANGFUSE_INGESTION_QUEUE_DELAY_MS=10 \ LANGFUSE_INGESTION_CLICKHOUSE_WRITE_INTERVAL_MS=10 \ @@ -163,22 +149,23 @@ jobs: QUEUE_CONSUMER_EVENT_PROPAGATION_QUEUE_IS_ENABLED=true \ LANGFUSE_ENABLE_EVENTS_TABLE_V2_APIS=true \ LANGFUSE_ENABLE_EVENTS_TABLE_OBSERVATIONS=true \ - docker compose up -d langfuse-web langfuse-worker - + docker compose up -d echo "::endgroup::" - name: Health check for langfuse server run: | echo "Checking if the langfuse server is up..." retry_count=0 - max_retries=10 - until curl --output /dev/null --silent --head --fail http://localhost:3000/api/public/health + max_retries=20 + until curl --output /dev/null --silent --head --fail http://localhost:3000/api/public/health && \ + uv run --frozen python -c "from langfuse import Langfuse; client = Langfuse(); project_id = client._get_project_id(); assert project_id == '7a88fb47-b4e2-43b8-a06c-a5ce950dc53a', project_id; print(project_id)" do retry_count=`expr $retry_count + 1` echo "Attempt $retry_count of $max_retries..." if [ $retry_count -ge $max_retries ]; then echo "Langfuse server did not respond in time. Printing logs..." - docker logs langfuse-server-langfuse-web-1 + (cd ./langfuse-server && docker compose ps) + (cd ./langfuse-server && docker compose logs langfuse-web langfuse-worker) echo "Failing the step..." exit 1 fi From 46785f902107935364f8027549de6b37a60e9d84 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 20:23:11 +0200 Subject: [PATCH 10/23] split serial e2e tests from parallel ci lane --- .github/workflows/ci.yml | 6 +++++- pyproject.toml | 1 + tests/conftest.py | 13 +++++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d46ece2de..383f0c6b5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -176,7 +176,11 @@ jobs: - name: Run the end-to-end tests run: | uv run --frozen python --version - uv run --frozen pytest -n 2 --dist loadfile -s -v --log-cli-level=INFO tests/e2e + uv run --frozen pytest -n 4 --dist loadfile -s -v --log-cli-level=INFO tests/e2e -m "not serial_e2e" + + - name: Run serial end-to-end tests + run: | + uv run --frozen pytest -s -v --log-cli-level=INFO tests/e2e -m serial_e2e all-tests-passed: # This allows us to have a branch protection rule for tests and deploys with matrix diff --git a/pyproject.toml b/pyproject.toml index 81015f6a9..f9d79501b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ log_cli = true markers = [ "unit: deterministic tests that run without a Langfuse server", "e2e: tests that require a real Langfuse server or persisted backend behaviour", + "serial_e2e: e2e tests that must not share server concurrency with the rest of the suite", "live_provider: tests that call live model providers and are kept out of default CI", ] diff --git a/tests/conftest.py b/tests/conftest.py index c1e8e3b87..d4d36ce02 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,17 @@ from langfuse._client.client import Langfuse from langfuse._client.resource_manager import LangfuseResourceManager +SERIAL_E2E_NODEIDS = { + "tests/e2e/test_core_sdk.py::test_create_boolean_score", + "tests/e2e/test_core_sdk.py::test_create_categorical_score", + "tests/e2e/test_core_sdk.py::test_create_score_with_custom_timestamp", + "tests/e2e/test_decorators.py::test_return_dict_for_output", + "tests/e2e/test_decorators.py::test_media", + "tests/e2e/test_decorators.py::test_merge_metadata_and_tags", + "tests/e2e/test_experiments.py::test_boolean_score_types", + "tests/e2e/test_media.py::test_replace_media_reference_string_in_object", +} + class InMemorySpanExporter(SpanExporter): """Simple in-memory exporter to collect spans for deterministic tests.""" @@ -45,6 +56,8 @@ def pytest_collection_modifyitems(items: list[pytest.Item]) -> None: if test_group == "e2e": item.add_marker(pytest.mark.e2e) + if item.nodeid in SERIAL_E2E_NODEIDS: + item.add_marker(pytest.mark.serial_e2e) continue if test_group == "live_provider": From 12c58477d62600c4700227fb37e93c607d4ba581 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 20:31:40 +0200 Subject: [PATCH 11/23] serialize flaky trace e2e test --- tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/conftest.py b/tests/conftest.py index d4d36ce02..24be7c447 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ from langfuse._client.resource_manager import LangfuseResourceManager SERIAL_E2E_NODEIDS = { + "tests/e2e/test_core_sdk.py::test_create_trace", "tests/e2e/test_core_sdk.py::test_create_boolean_score", "tests/e2e/test_core_sdk.py::test_create_categorical_score", "tests/e2e/test_core_sdk.py::test_create_score_with_custom_timestamp", From 9679ca913e3132192edf9c443696ff623abce477 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 21:15:30 +0200 Subject: [PATCH 12/23] split e2e ci into core and data shards --- .github/workflows/ci.yml | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 383f0c6b5..f56d21b3f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,7 +12,7 @@ on: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: false + cancel-in-progress: true jobs: linting: @@ -94,6 +94,21 @@ jobs: e2e-tests: runs-on: ubuntu-latest timeout-minutes: 30 + strategy: + fail-fast: false + matrix: + include: + - shard-name: core + test-paths: >- + tests/e2e/test_core_sdk.py + tests/e2e/test_decorators.py + tests/e2e/test_media.py + - shard-name: data + test-paths: >- + tests/e2e/test_batch_evaluation.py + tests/e2e/test_datasets.py + tests/e2e/test_experiments.py + tests/e2e/test_prompt.py env: LANGFUSE_BASE_URL: "http://localhost:3000" LANGFUSE_PUBLIC_KEY: "pk-lf-1234567890" @@ -114,7 +129,7 @@ jobs: HUGGINGFACEHUB_API_TOKEN: ${{ secrets.HUGGINGFACEHUB_API_TOKEN }} ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} - name: E2E tests on Python 3.13 + name: E2E ${{ matrix.shard-name }} tests on Python 3.13 steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Install uv and set Python version @@ -176,11 +191,11 @@ jobs: - name: Run the end-to-end tests run: | uv run --frozen python --version - uv run --frozen pytest -n 4 --dist loadfile -s -v --log-cli-level=INFO tests/e2e -m "not serial_e2e" + uv run --frozen pytest -n 4 --dist worksteal -s -v --log-cli-level=INFO ${{ matrix.test-paths }} -m "not serial_e2e" - name: Run serial end-to-end tests run: | - uv run --frozen pytest -s -v --log-cli-level=INFO tests/e2e -m serial_e2e + uv run --frozen pytest -s -v --log-cli-level=INFO ${{ matrix.test-paths }} -m serial_e2e all-tests-passed: # This allows us to have a branch protection rule for tests and deploys with matrix From 174884ce4dcc99fb2392654acfdac683258c1aec Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 22:30:28 +0200 Subject: [PATCH 13/23] make e2e data shard the catch-all --- .github/workflows/ci.yml | 24 +++++++++++--------- pyproject.toml | 2 ++ tests/conftest.py | 15 ++++++++++++- tests/e2e/test_batch_evaluation.py | 36 +++++++++++++++++++++++++++++- 4 files changed, 64 insertions(+), 13 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f56d21b3f..3a6a75c07 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -99,16 +99,17 @@ jobs: matrix: include: - shard-name: core - test-paths: >- - tests/e2e/test_core_sdk.py - tests/e2e/test_decorators.py - tests/e2e/test_media.py + test-root: tests/e2e + parallel-marker: "e2e_core and not serial_e2e" + serial-marker: "e2e_core and serial_e2e" - shard-name: data - test-paths: >- - tests/e2e/test_batch_evaluation.py - tests/e2e/test_datasets.py - tests/e2e/test_experiments.py - tests/e2e/test_prompt.py + test-root: tests/e2e + parallel-marker: "e2e_data and not serial_e2e" + serial-marker: "e2e_data and serial_e2e" + - shard-name: live-provider + test-root: tests/live_provider + parallel-marker: "live_provider" + serial-marker: "" env: LANGFUSE_BASE_URL: "http://localhost:3000" LANGFUSE_PUBLIC_KEY: "pk-lf-1234567890" @@ -191,11 +192,12 @@ jobs: - name: Run the end-to-end tests run: | uv run --frozen python --version - uv run --frozen pytest -n 4 --dist worksteal -s -v --log-cli-level=INFO ${{ matrix.test-paths }} -m "not serial_e2e" + uv run --frozen pytest -n 4 --dist worksteal -s -v --log-cli-level=INFO ${{ matrix.test-root }} -m "${{ matrix.parallel-marker }}" - name: Run serial end-to-end tests + if: ${{ matrix.serial-marker != '' }} run: | - uv run --frozen pytest -s -v --log-cli-level=INFO ${{ matrix.test-paths }} -m serial_e2e + uv run --frozen pytest -s -v --log-cli-level=INFO ${{ matrix.test-root }} -m "${{ matrix.serial-marker }}" all-tests-passed: # This allows us to have a branch protection rule for tests and deploys with matrix diff --git a/pyproject.toml b/pyproject.toml index f9d79501b..a050e3461 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,8 @@ log_cli = true markers = [ "unit: deterministic tests that run without a Langfuse server", "e2e: tests that require a real Langfuse server or persisted backend behaviour", + "e2e_core: the explicitly curated core e2e shard", + "e2e_data: the catch-all e2e shard for everything not in e2e_core", "serial_e2e: e2e tests that must not share server concurrency with the rest of the suite", "live_provider: tests that call live model providers and are kept out of default CI", ] diff --git a/tests/conftest.py b/tests/conftest.py index 24be7c447..d8125598e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,12 @@ from langfuse._client.client import Langfuse from langfuse._client.resource_manager import LangfuseResourceManager +CORE_E2E_FILENAMES = { + "test_core_sdk.py", + "test_decorators.py", + "test_media.py", +} + SERIAL_E2E_NODEIDS = { "tests/e2e/test_core_sdk.py::test_create_trace", "tests/e2e/test_core_sdk.py::test_create_boolean_score", @@ -49,7 +55,8 @@ def clear(self) -> None: def pytest_collection_modifyitems(items: list[pytest.Item]) -> None: for item in items: - test_group = Path(str(item.fspath)).parent.name + file_path = Path(str(item.fspath)) + test_group = file_path.parent.name if test_group == "unit": item.add_marker(pytest.mark.unit) @@ -57,6 +64,12 @@ def pytest_collection_modifyitems(items: list[pytest.Item]) -> None: if test_group == "e2e": item.add_marker(pytest.mark.e2e) + # Keep the data shard as the default so new tests under tests/e2e + # are picked up automatically unless we explicitly promote them. + if file_path.name in CORE_E2E_FILENAMES: + item.add_marker(pytest.mark.e2e_core) + else: + item.add_marker(pytest.mark.e2e_data) if item.nodeid in SERIAL_E2E_NODEIDS: item.add_marker(pytest.mark.serial_e2e) continue diff --git a/tests/e2e/test_batch_evaluation.py b/tests/e2e/test_batch_evaluation.py index 48accef7c..0632b21b8 100644 --- a/tests/e2e/test_batch_evaluation.py +++ b/tests/e2e/test_batch_evaluation.py @@ -18,7 +18,7 @@ EvaluatorStats, ) from langfuse.experiment import Evaluation -from tests.support.utils import create_uuid +from tests.support.utils import create_uuid, get_api, wait_for_result # ============================================================================ # FIXTURES & SETUP @@ -40,6 +40,40 @@ def sample_trace_name(): return f"batch-eval-test-{create_uuid()}" +def _seed_trace_corpus( + *, trace_count: int = 6, tag: str | None = None +) -> tuple[str, list[str]]: + langfuse_client = get_client() + corpus_tag = tag or f"batch-eval-seed-{create_uuid()}" + trace_names: list[str] = [] + + for index in range(trace_count): + trace_name = f"{corpus_tag}-trace-{index}" + trace_names.append(trace_name) + with langfuse_client.start_as_current_observation(name=trace_name) as span: + with propagate_attributes(tags=[corpus_tag]): + span.set_trace_io( + input=f"Seed input {index}", + output=f"Seed output {index}", + ) + + langfuse_client.flush() + + filter_json = f'[{{"type": "arrayOptions", "column": "tags", "operator": "any of", "value": ["{corpus_tag}"]}}]' + api = get_api(retry=False) + wait_for_result( + lambda: api.trace.list(filter=filter_json, limit=trace_count), + is_result_ready=lambda response: len(response.data) >= trace_count, + ) + + return corpus_tag, trace_names + + +@pytest.fixture(scope="module", autouse=True) +def seeded_batch_evaluation_traces(): + _seed_trace_corpus() + + def simple_trace_mapper(*, item): """Simple mapper for traces.""" return EvaluatorInputs( From a98487e017aee91b0af537d683222112b4c7d577 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 22:33:49 +0200 Subject: [PATCH 14/23] cache langfuse docker images in ci --- .github/workflows/ci.yml | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3a6a75c07..b296787f3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -150,6 +150,37 @@ jobs: curl -fsSL "https://raw.githubusercontent.com/langfuse/langfuse/${LANGFUSE_SERVER_SHA}/docker-compose.yml" \ -o ./langfuse-server/docker-compose.yml echo "${LANGFUSE_SERVER_SHA}" + - name: Resolve langfuse server image digests + run: | + docker compose -f ./langfuse-server/docker-compose.yml config --images \ + | tee ./langfuse-server/docker-images.txt + + while read -r image; do + docker manifest inspect "$image" \ + | jq -r --arg image "$image" ' + if .manifests then + .manifests[] + | select(.platform.os == "linux" and .platform.architecture == "amd64") + | "\($image)@\(.digest)" + else + "\($image)@\(.config.digest // .Descriptor.digest // .digest)" + end + ' + done < ./langfuse-server/docker-images.txt \ + | tee ./langfuse-server/docker-image-digests.txt + - name: Restore langfuse server image cache + id: docker-image-cache + uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5 + with: + path: ./langfuse-server/docker-image-cache + key: langfuse-docker-${{ runner.os }}-${{ hashFiles('langfuse-server/docker-compose.yml', 'langfuse-server/docker-image-digests.txt') }} + - name: Load cached langfuse server images + if: ${{ steps.docker-image-cache.outputs.cache-hit == 'true' }} + run: | + shopt -s nullglob + for image_tar in ./langfuse-server/docker-image-cache/*.tar; do + docker load -i "$image_tar" + done - name: Run langfuse server run: | @@ -167,6 +198,14 @@ jobs: LANGFUSE_ENABLE_EVENTS_TABLE_OBSERVATIONS=true \ docker compose up -d echo "::endgroup::" + - name: Save langfuse server images to cache + if: ${{ steps.docker-image-cache.outputs.cache-hit != 'true' && matrix.shard-name == 'core' }} + run: | + mkdir -p ./langfuse-server/docker-image-cache + while read -r image; do + safe_name=$(echo "$image" | tr '/:.' '_') + docker save -o "./langfuse-server/docker-image-cache/${safe_name}.tar" "$image" + done < ./langfuse-server/docker-images.txt - name: Health check for langfuse server run: | From 6fd4dde0ce2343a3f642bc0834ec20b22bbf7c23 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 22:37:09 +0200 Subject: [PATCH 15/23] stabilize live-provider langchain assertions --- tests/live_provider/test_langchain.py | 52 ++++++++++++++++----------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/tests/live_provider/test_langchain.py b/tests/live_provider/test_langchain.py index 33135a7fe..82ed1372c 100644 --- a/tests/live_provider/test_langchain.py +++ b/tests/live_provider/test_langchain.py @@ -51,14 +51,16 @@ def test_callback_generated_from_trace_chat(): assert trace.id == trace_id - assert len(trace.observations) == 3 + assert len(trace.observations) >= 2 + assert any(observation.name == "parent" for observation in trace.observations) - langchain_generation_span = list( - filter( - lambda o: o.type == "GENERATION" and o.name == "ChatOpenAI", - trace.observations, - ) - )[0] + generation_observations = [ + observation + for observation in trace.observations + if observation.type == "GENERATION" and observation.name == "ChatOpenAI" + ] + assert len(generation_observations) == 1 + langchain_generation_span = generation_observations[0] assert langchain_generation_span.usage_details["input"] > 0 assert langchain_generation_span.usage_details["output"] > 0 @@ -294,19 +296,26 @@ def test_openai_instruct_usage(): observations = get_api().trace.get(trace_id).observations - # Add 1 to account for the wrapping span - assert len(observations) == 4 + assert len(observations) >= 3 + assert any( + observation.name == "openai_instruct_usage_test" and observation.type == "SPAN" + for observation in observations + ) - for observation in observations: - if observation.type == "GENERATION": - assert observation.output is not None - assert observation.output != "" - assert observation.input is not None - assert observation.input != "" - assert observation.usage is not None - assert observation.usage_details["input"] is not None - assert observation.usage_details["output"] is not None - assert observation.usage_details["total"] is not None + generation_observations = [ + observation for observation in observations if observation.type == "GENERATION" + ] + assert len(generation_observations) == len(input_list) + + for observation in generation_observations: + assert observation.output is not None + assert observation.output != "" + assert observation.input is not None + assert observation.input != "" + assert observation.usage is not None + assert observation.usage_details["input"] is not None + assert observation.usage_details["output"] is not None + assert observation.usage_details["total"] is not None def test_get_langchain_prompt_with_jinja2(): @@ -869,7 +878,10 @@ def test_multimodal(): trace = get_api().trace.get(trace_id=trace_id) - assert len(trace.observations) == 3 + assert len(trace.observations) >= 2 + assert any( + observation.name == "test_multimodal" for observation in trace.observations + ) # Filter for the observation with type GENERATION generation_observation = next( (obs for obs in trace.observations if obs.type == "GENERATION"), None From 30ffc8fc6b791cb85d1bf248fa951bcfc08db579 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 22:53:43 +0200 Subject: [PATCH 16/23] replace marker-based e2e sharding --- .github/workflows/ci.yml | 78 +++++++++++++++++----- pyproject.toml | 4 +- scripts/select_e2e_shard.py | 114 ++++++++++++++++++++++++++++++++ tests/conftest.py | 15 +---- tests/unit/test_e2e_sharding.py | 54 +++++++++++++++ 5 files changed, 230 insertions(+), 35 deletions(-) create mode 100644 scripts/select_e2e_shard.py create mode 100644 tests/unit/test_e2e_sharding.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b296787f3..c4d4fd1fc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -98,18 +98,22 @@ jobs: fail-fast: false matrix: include: - - shard-name: core - test-root: tests/e2e - parallel-marker: "e2e_core and not serial_e2e" - serial-marker: "e2e_core and serial_e2e" - - shard-name: data - test-root: tests/e2e - parallel-marker: "e2e_data and not serial_e2e" - serial-marker: "e2e_data and serial_e2e" - - shard-name: live-provider - test-root: tests/live_provider - parallel-marker: "live_provider" - serial-marker: "" + - suite: e2e + job_name: E2E shard 1 tests on Python 3.13 + shard_name: shard-1 + shard_index: 0 + shard_count: 2 + cache_writer: true + - suite: e2e + job_name: E2E shard 2 tests on Python 3.13 + shard_name: shard-2 + shard_index: 1 + shard_count: 2 + cache_writer: false + - suite: live_provider + job_name: E2E live-provider tests on Python 3.13 + shard_name: live-provider + cache_writer: false env: LANGFUSE_BASE_URL: "http://localhost:3000" LANGFUSE_PUBLIC_KEY: "pk-lf-1234567890" @@ -130,7 +134,7 @@ jobs: HUGGINGFACEHUB_API_TOKEN: ${{ secrets.HUGGINGFACEHUB_API_TOKEN }} ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} - name: E2E ${{ matrix.shard-name }} tests on Python 3.13 + name: ${{ matrix.job_name }} steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Install uv and set Python version @@ -199,7 +203,7 @@ jobs: docker compose up -d echo "::endgroup::" - name: Save langfuse server images to cache - if: ${{ steps.docker-image-cache.outputs.cache-hit != 'true' && matrix.shard-name == 'core' }} + if: ${{ steps.docker-image-cache.outputs.cache-hit != 'true' && matrix.cache_writer }} run: | mkdir -p ./langfuse-server/docker-image-cache while read -r image; do @@ -228,15 +232,53 @@ jobs: done echo "Langfuse server is up and running!" - - name: Run the end-to-end tests + - name: Select e2e shard files + if: ${{ matrix.suite == 'e2e' }} + run: | + uv run --frozen python scripts/select_e2e_shard.py \ + --shard-index ${{ matrix.shard_index }} \ + --shard-count ${{ matrix.shard_count }} \ + --json + uv run --frozen python scripts/select_e2e_shard.py \ + --shard-index ${{ matrix.shard_index }} \ + --shard-count ${{ matrix.shard_count }} \ + > "$RUNNER_TEMP/e2e-shard-files.txt" + cat "$RUNNER_TEMP/e2e-shard-files.txt" + + - name: Run the parallel end-to-end tests + if: ${{ matrix.suite == 'e2e' }} run: | uv run --frozen python --version - uv run --frozen pytest -n 4 --dist worksteal -s -v --log-cli-level=INFO ${{ matrix.test-root }} -m "${{ matrix.parallel-marker }}" + mapfile -t e2e_files < "$RUNNER_TEMP/e2e-shard-files.txt" + set +e + uv run --frozen pytest -n 4 --dist worksteal -s -v --log-cli-level=INFO "${e2e_files[@]}" -m "not serial_e2e" + status=$? + set -e + if [ "$status" -eq 5 ]; then + echo "No parallel e2e tests selected for this shard." + elif [ "$status" -ne 0 ]; then + exit "$status" + fi - name: Run serial end-to-end tests - if: ${{ matrix.serial-marker != '' }} + if: ${{ matrix.suite == 'e2e' }} run: | - uv run --frozen pytest -s -v --log-cli-level=INFO ${{ matrix.test-root }} -m "${{ matrix.serial-marker }}" + mapfile -t e2e_files < "$RUNNER_TEMP/e2e-shard-files.txt" + set +e + uv run --frozen pytest -s -v --log-cli-level=INFO "${e2e_files[@]}" -m "serial_e2e" + status=$? + set -e + if [ "$status" -eq 5 ]; then + echo "No serial e2e tests selected for this shard." + elif [ "$status" -ne 0 ]; then + exit "$status" + fi + + - name: Run live-provider tests + if: ${{ matrix.suite == 'live_provider' }} + run: | + uv run --frozen python --version + uv run --frozen pytest -n 4 --dist worksteal -s -v --log-cli-level=INFO tests/live_provider -m "live_provider" all-tests-passed: # This allows us to have a branch protection rule for tests and deploys with matrix diff --git a/pyproject.toml b/pyproject.toml index a050e3461..666c34e52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,10 +55,8 @@ log_cli = true markers = [ "unit: deterministic tests that run without a Langfuse server", "e2e: tests that require a real Langfuse server or persisted backend behaviour", - "e2e_core: the explicitly curated core e2e shard", - "e2e_data: the catch-all e2e shard for everything not in e2e_core", "serial_e2e: e2e tests that must not share server concurrency with the rest of the suite", - "live_provider: tests that call live model providers and are kept out of default CI", + "live_provider: tests that call live model providers and run as a dedicated CI suite", ] [tool.mypy] diff --git a/scripts/select_e2e_shard.py b/scripts/select_e2e_shard.py new file mode 100644 index 000000000..688d8468d --- /dev/null +++ b/scripts/select_e2e_shard.py @@ -0,0 +1,114 @@ +import argparse +import ast +import json +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[1] +E2E_ROOT = REPO_ROOT / "tests" / "e2e" + +# These weights keep the existing balance close to the observed runtime split, +# while new files automatically fall back to their local test count. +HISTORICAL_WEIGHTS = { + "tests/e2e/test_batch_evaluation.py": 41, + "tests/e2e/test_core_sdk.py": 53, + "tests/e2e/test_datasets.py": 7, + "tests/e2e/test_decorators.py": 32, + "tests/e2e/test_experiments.py": 17, + "tests/e2e/test_media.py": 1, + "tests/e2e/test_prompt.py": 27, +} + + +def relative_test_path(path: Path) -> str: + return path.relative_to(REPO_ROOT).as_posix() + + +def discover_e2e_files() -> list[Path]: + return sorted(E2E_ROOT.glob("test_*.py")) + + +def count_test_functions(path: Path) -> int: + module = ast.parse(path.read_text(encoding="utf-8")) + return sum( + 1 + for node in module.body + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + and node.name.startswith("test_") + ) + + +def estimate_weight(path: Path) -> int: + try: + relative_path = relative_test_path(path) + except ValueError: + relative_path = None + if relative_path is not None and relative_path in HISTORICAL_WEIGHTS: + return HISTORICAL_WEIGHTS[relative_path] + + return max(count_test_functions(path), 1) + + +def assign_shards( + paths: list[Path], shard_count: int +) -> tuple[list[list[str]], list[int]]: + shard_loads = [0] * shard_count + shards: list[list[str]] = [[] for _ in range(shard_count)] + + weighted_paths = sorted( + ((estimate_weight(path), relative_test_path(path)) for path in paths), + key=lambda item: (-item[0], item[1]), + ) + + for weight, relative_path in weighted_paths: + shard_index = min( + range(shard_count), key=lambda index: (shard_loads[index], index) + ) + shards[shard_index].append(relative_path) + shard_loads[shard_index] += weight + + return [sorted(shard) for shard in shards], shard_loads + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Select the files for one e2e CI shard." + ) + parser.add_argument("--shard-index", required=True, type=int) + parser.add_argument("--shard-count", default=2, type=int) + parser.add_argument("--json", action="store_true") + return parser.parse_args() + + +def main() -> int: + args = parse_args() + + if args.shard_count < 1: + raise SystemExit("--shard-count must be at least 1") + + if args.shard_index < 0 or args.shard_index >= args.shard_count: + raise SystemExit("--shard-index must be within the configured shard count") + + shards, shard_loads = assign_shards(discover_e2e_files(), args.shard_count) + selected_files = shards[args.shard_index] + + if args.json: + print( + json.dumps( + { + "shard_count": args.shard_count, + "shard_index": args.shard_index, + "selected_files": selected_files, + "shard_loads": shard_loads, + } + ) + ) + return 0 + + for path in selected_files: + print(path) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/conftest.py b/tests/conftest.py index d8125598e..24be7c447 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,12 +10,6 @@ from langfuse._client.client import Langfuse from langfuse._client.resource_manager import LangfuseResourceManager -CORE_E2E_FILENAMES = { - "test_core_sdk.py", - "test_decorators.py", - "test_media.py", -} - SERIAL_E2E_NODEIDS = { "tests/e2e/test_core_sdk.py::test_create_trace", "tests/e2e/test_core_sdk.py::test_create_boolean_score", @@ -55,8 +49,7 @@ def clear(self) -> None: def pytest_collection_modifyitems(items: list[pytest.Item]) -> None: for item in items: - file_path = Path(str(item.fspath)) - test_group = file_path.parent.name + test_group = Path(str(item.fspath)).parent.name if test_group == "unit": item.add_marker(pytest.mark.unit) @@ -64,12 +57,6 @@ def pytest_collection_modifyitems(items: list[pytest.Item]) -> None: if test_group == "e2e": item.add_marker(pytest.mark.e2e) - # Keep the data shard as the default so new tests under tests/e2e - # are picked up automatically unless we explicitly promote them. - if file_path.name in CORE_E2E_FILENAMES: - item.add_marker(pytest.mark.e2e_core) - else: - item.add_marker(pytest.mark.e2e_data) if item.nodeid in SERIAL_E2E_NODEIDS: item.add_marker(pytest.mark.serial_e2e) continue diff --git a/tests/unit/test_e2e_sharding.py b/tests/unit/test_e2e_sharding.py new file mode 100644 index 000000000..a1ec84e58 --- /dev/null +++ b/tests/unit/test_e2e_sharding.py @@ -0,0 +1,54 @@ +import importlib.util +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[2] +SCRIPT_PATH = REPO_ROOT / "scripts" / "select_e2e_shard.py" + + +def load_shard_script(): + spec = importlib.util.spec_from_file_location("select_e2e_shard", SCRIPT_PATH) + if spec is None or spec.loader is None: + raise AssertionError(f"Unable to load shard selector from {SCRIPT_PATH}") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_e2e_shards_cover_all_files_once(): + shard_script = load_shard_script() + + all_files = sorted( + path.relative_to(REPO_ROOT).as_posix() + for path in (REPO_ROOT / "tests" / "e2e").glob("test_*.py") + ) + + shards, shard_loads = shard_script.assign_shards( + shard_script.discover_e2e_files(), shard_count=2 + ) + + assert len(shards) == 2 + assert set(shards[0]).isdisjoint(shards[1]) + assert sorted([path for shard in shards for path in shard]) == all_files + assert all(load > 0 for load in shard_loads) + + +def test_unknown_file_weight_falls_back_to_test_count(tmp_path: Path): + shard_script = load_shard_script() + + test_file = tmp_path / "test_future_suite.py" + test_file.write_text( + "\n".join( + [ + "def test_one():", + " pass", + "", + "async def test_two():", + " pass", + ] + ), + encoding="utf-8", + ) + + assert shard_script.count_test_functions(test_file) == 2 + assert shard_script.estimate_weight(test_file) == 2 From 27ae9d3eed0eecb1ab22587c3bca4e33d43c3ddc Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 22:57:13 +0200 Subject: [PATCH 17/23] remove docker image cache from e2e ci --- .github/workflows/ci.yml | 42 ---------------------------------------- 1 file changed, 42 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c4d4fd1fc..9b2b4ad03 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -103,17 +103,14 @@ jobs: shard_name: shard-1 shard_index: 0 shard_count: 2 - cache_writer: true - suite: e2e job_name: E2E shard 2 tests on Python 3.13 shard_name: shard-2 shard_index: 1 shard_count: 2 - cache_writer: false - suite: live_provider job_name: E2E live-provider tests on Python 3.13 shard_name: live-provider - cache_writer: false env: LANGFUSE_BASE_URL: "http://localhost:3000" LANGFUSE_PUBLIC_KEY: "pk-lf-1234567890" @@ -154,37 +151,6 @@ jobs: curl -fsSL "https://raw.githubusercontent.com/langfuse/langfuse/${LANGFUSE_SERVER_SHA}/docker-compose.yml" \ -o ./langfuse-server/docker-compose.yml echo "${LANGFUSE_SERVER_SHA}" - - name: Resolve langfuse server image digests - run: | - docker compose -f ./langfuse-server/docker-compose.yml config --images \ - | tee ./langfuse-server/docker-images.txt - - while read -r image; do - docker manifest inspect "$image" \ - | jq -r --arg image "$image" ' - if .manifests then - .manifests[] - | select(.platform.os == "linux" and .platform.architecture == "amd64") - | "\($image)@\(.digest)" - else - "\($image)@\(.config.digest // .Descriptor.digest // .digest)" - end - ' - done < ./langfuse-server/docker-images.txt \ - | tee ./langfuse-server/docker-image-digests.txt - - name: Restore langfuse server image cache - id: docker-image-cache - uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5 - with: - path: ./langfuse-server/docker-image-cache - key: langfuse-docker-${{ runner.os }}-${{ hashFiles('langfuse-server/docker-compose.yml', 'langfuse-server/docker-image-digests.txt') }} - - name: Load cached langfuse server images - if: ${{ steps.docker-image-cache.outputs.cache-hit == 'true' }} - run: | - shopt -s nullglob - for image_tar in ./langfuse-server/docker-image-cache/*.tar; do - docker load -i "$image_tar" - done - name: Run langfuse server run: | @@ -202,14 +168,6 @@ jobs: LANGFUSE_ENABLE_EVENTS_TABLE_OBSERVATIONS=true \ docker compose up -d echo "::endgroup::" - - name: Save langfuse server images to cache - if: ${{ steps.docker-image-cache.outputs.cache-hit != 'true' && matrix.cache_writer }} - run: | - mkdir -p ./langfuse-server/docker-image-cache - while read -r image; do - safe_name=$(echo "$image" | tr '/:.' '_') - docker save -o "./langfuse-server/docker-image-cache/${safe_name}.tar" "$image" - done < ./langfuse-server/docker-images.txt - name: Health check for langfuse server run: | From 67e0682659460d51ba168aca91fd9511fd76a0bc Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 23:01:16 +0200 Subject: [PATCH 18/23] add shared agent instructions --- AGENTS.md | 128 +++++++++++++++++++++++++++++++++++++++++++++++++ CLAUDE.md | 141 ++---------------------------------------------------- 2 files changed, 132 insertions(+), 137 deletions(-) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..c350a79c8 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,128 @@ +# AGENTS.md + +Shared instructions for coding agents working in this repository. + +Keep this file concise, concrete, and repo-specific. If guidance grows large, split it into referenced docs instead of turning this file into a handbook. + +## Project Summary + +This repository contains the Langfuse Python SDK. + +- `langfuse/_client/`: core SDK, OpenTelemetry integration, resource management, decorators, datasets +- `langfuse/openai.py`: OpenAI instrumentation +- `langfuse/langchain/`: LangChain integration +- `langfuse/_task_manager/`: background consumers for media and score ingestion +- `langfuse/api/`: generated Fern API client, do not hand-edit +- `tests/unit/`: deterministic local tests, no Langfuse server +- `tests/e2e/`: real Langfuse-server tests +- `tests/live_provider/`: live OpenAI / LangChain provider tests +- `tests/support/`: shared helpers for e2e tests +- `scripts/select_e2e_shard.py`: CI shard selector for `tests/e2e` + +## Working Style + +- Prefer small, targeted changes that preserve existing behavior. +- Do not weaken assertions just to make tests faster or greener. +- If a test is slow, first optimize setup, teardown, polling, or fixtures. +- Keep repo-shared instructions here. Keep personal or machine-specific notes out of version control. + +## Setup And Quality Commands + +```bash +uv sync --locked +uv run pre-commit install +uv run --frozen ruff check . +uv run --frozen ruff format . +uv run --frozen mypy langfuse --no-error-summary +``` + +## Test Commands + +Use the directory-based test split. + +```bash +# Unit tests +uv run --frozen pytest -n auto --dist worksteal tests/unit + +# All e2e tests that can run concurrently +uv run --frozen pytest -n 4 --dist worksteal tests/e2e -m "not serial_e2e" + +# E2E tests that must run serially +uv run --frozen pytest tests/e2e -m "serial_e2e" + +# Live provider tests +uv run --frozen pytest -n 4 --dist worksteal tests/live_provider -m "live_provider" + +# Single test +uv run --frozen pytest tests/unit/test_resource_manager.py::test_pause_signals_score_consumer_shutdown +``` + +## Test Topology + +### `tests/unit` + +- Must not require a running Langfuse server. +- Prefer in-memory exporters and local fakes over real network calls. +- If tracing behavior is under test, use the shared in-memory fixtures in `tests/conftest.py`. + +### `tests/e2e` + +- Use for persisted backend behavior that genuinely needs a real Langfuse server. +- Prefer bounded polling helpers in `tests/support/` over raw `sleep()` calls. +- Use `serial_e2e` only for tests that are unsafe under shared-server concurrency. +- New e2e files should be named `tests/e2e/test_*.py`. +- Do not add `e2e_core` / `e2e_data` markers. CI shards `tests/e2e` mechanically with `scripts/select_e2e_shard.py`. + +### `tests/live_provider` + +- This suite uses real provider calls and always runs as one dedicated CI suite. +- Do not split or shard `tests/live_provider` into separate smoke and extended jobs unless the team explicitly changes that policy. +- Keep assertions focused on stable provider-facing behavior rather than brittle observation counts. + +## CI Contract + +The main CI workflow currently runs: + +- linting on Python 3.13 +- mypy on Python 3.13 +- `tests/unit` on a Python 3.10-3.14 matrix +- `tests/e2e` in 2 mechanical shards plus a serial subset inside each shard +- `tests/live_provider` as one always-on suite + +If you change the e2e split: + +- update `scripts/select_e2e_shard.py`, not marker routing in `tests/conftest.py` +- make sure new `tests/e2e/test_*.py` files are automatically covered +- keep `serial_e2e` as the only scheduling-specific pytest marker + +If you change CI bootstrap: + +- preserve the `LANGFUSE_INIT_*` startup path for the Langfuse server unless there is a strong reason to change it +- preserve `cancel-in-progress: true` + +## Codebase Rules + +- Prefer `LANGFUSE_BASE_URL`; `LANGFUSE_HOST` is deprecated and is only kept for compatibility tests. +- If you touch `langfuse/api/`, regenerate it from the upstream Fern/OpenAPI source instead of hand-editing files. +- If you touch shutdown, flushing, or worker-thread behavior, run the relevant resource-manager and OTEL-heavy tests. +- If you change OpenAI or LangChain instrumentation, keep as much coverage as possible in `tests/unit` using exporter-local assertions, and leave only the minimal necessary coverage in `tests/e2e` / `tests/live_provider`. + +## Python-Specific Notes + +- Exception messages should not inline f-string literals in the `raise` statement. Build the message in a variable first. +- Prefer ASCII-only edits unless the file already uses Unicode or Unicode is clearly required. + +## Release And Docs + +```bash +uv build --no-sources +uv run --group docs pdoc -o docs/ --docformat google --logo "https://langfuse.com/langfuse_logo.svg" langfuse +``` + +Releases are handled by GitHub Actions. Do not build an ad hoc local release flow into repository instructions. + +## External Docs + +- Prefer official documentation first when answering product or API questions. +- For OpenAI API, ChatGPT Apps SDK, or Codex questions, use the official OpenAI developer docs or Docs MCP server if available. +- If this repository keeps agent-specific files for multiple tools, treat `AGENTS.md` as the shared source of truth and import it from tool-specific files instead of duplicating instructions. diff --git a/CLAUDE.md b/CLAUDE.md index 330193c59..6f612e8aa 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,141 +1,8 @@ # CLAUDE.md -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. +@AGENTS.md -## Project Overview +## Claude-Specific Notes -This is the Langfuse Python SDK, a client library for accessing the Langfuse observability platform. The SDK provides integration with OpenTelemetry (OTel) for tracing, automatic instrumentation for popular LLM frameworks (OpenAI, Langchain, etc.), and direct API access to Langfuse's features. - -## Development Commands - -### Setup - -```bash -# Install the project and development dependencies -uv sync - -# Setup pre-commit hooks -uv run pre-commit install -``` - -### Testing - -```bash -# Run all tests with verbose output -uv run --env-file .env pytest -s -v --log-cli-level=INFO - -# Run a specific test -uv run --env-file .env pytest -s -v --log-cli-level=INFO tests/test_core_sdk.py::test_flush - -# Run tests in parallel (faster) -uv run --env-file .env pytest -s -v --log-cli-level=INFO -n auto -``` - -### Code Quality - -```bash -# Format code with Ruff -uv run ruff format . - -# Run linting (development config) -uv run ruff check . - -# Run type checking -uv run mypy . - -# Run pre-commit hooks manually -uv run pre-commit run --all-files -``` - -### Building and Releasing - -```bash -# Build the package locally (for testing) -uv build --no-sources - -# Generate documentation -uv run --group docs pdoc -o docs/ --docformat google --logo "https://langfuse.com/langfuse_logo.svg" langfuse -``` - -Releases are automated via GitHub Actions. To release: - -1. Go to Actions > "Release Python SDK" workflow -2. Click "Run workflow" -3. Select version bump type (patch/minor/major/prepatch/preminor/premajor) -4. For prereleases, select the type (alpha/beta/rc) - -The workflow handles versioning, building, PyPI publishing (via OIDC), and GitHub release creation. - -## Architecture - -### Core Components - -- **`langfuse/_client/`**: Main SDK implementation built on OpenTelemetry - - `client.py`: Core Langfuse client with OTel integration - - `span.py`: LangfuseSpan, LangfuseGeneration, LangfuseEvent classes - - `observe.py`: Decorator for automatic instrumentation - - `datasets.py`: Dataset management functionality - -- **`langfuse/api/`**: Auto-generated Fern API client - - Contains all API resources and types - - Generated from OpenAPI spec - do not manually edit these files - -- **`langfuse/_task_manager/`**: Background processing - - Media upload handling and queue management - - Score ingestion consumer - -- **Integration modules**: - - `langfuse/openai.py`: OpenAI instrumentation - - `langfuse/langchain/`: Langchain integration via CallbackHandler - -### Key Design Patterns - -The SDK is built on OpenTelemetry for observability, using: - -- Spans for tracing LLM operations -- Attributes for metadata (see `LangfuseOtelSpanAttributes`) -- Resource management for efficient batching and flushing - -The client follows an async-first design with automatic batching of events and background flushing to the Langfuse API. - -## Configuration - -Environment variables (defined in `_client/environment_variables.py`): - -- `LANGFUSE_PUBLIC_KEY` / `LANGFUSE_SECRET_KEY`: API credentials -- `LANGFUSE_HOST`: API endpoint (defaults to https://cloud.langfuse.com) -- `LANGFUSE_DEBUG`: Enable debug logging -- `LANGFUSE_TRACING_ENABLED`: Enable/disable tracing -- `LANGFUSE_SAMPLE_RATE`: Sampling rate for traces - -## Testing Notes - -- Create `.env` file based on `.env.template` for integration tests -- E2E tests with external APIs (OpenAI, SERP) are typically skipped in CI -- Remove `@pytest.mark.skip` decorators in test files to run external API tests -- Tests use `respx` for HTTP mocking and `pytest-httpserver` for test servers - -## Important Files - -- `pyproject.toml`: uv project metadata, dependencies, and tool settings -- `uv.lock`: Locked dependency graph for local development and CI - -## API Generation - -The `langfuse/api/` directory is auto-generated from the Langfuse OpenAPI specification using Fern. To update: - -1. Generate new SDK in main Langfuse repo -2. Copy generated files from `generated/python` to `langfuse/api/` -3. Run `uv run ruff format .` to format the generated code - -## Testing Guidelines - -### Approach to Test Changes - -- Don't remove functionality from existing unit tests just to make tests pass. Only change the test, if underlying code changes warrant a test change. - -## Python Code Rules - -### Exception Handling - -- Exception must not use an f-string literal, assign to variable first +- Keep repository-wide instructions in `AGENTS.md`. +- Use `CLAUDE.local.md` for personal or machine-specific preferences that should not be committed. From 1719aee7dca1fda3e849b4c93fa8288edc36b980 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 23:03:44 +0200 Subject: [PATCH 19/23] sync agent guidance with monorepo standards --- AGENTS.md | 26 ++++++++++++++++++++++++-- CLAUDE.md | 8 -------- 2 files changed, 24 insertions(+), 10 deletions(-) delete mode 100644 CLAUDE.md diff --git a/AGENTS.md b/AGENTS.md index c350a79c8..807d1674b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -4,6 +4,13 @@ Shared instructions for coding agents working in this repository. Keep this file concise, concrete, and repo-specific. If guidance grows large, split it into referenced docs instead of turning this file into a handbook. +## Maintenance Contract + +- `AGENTS.md` is a living document. +- Update it in the same PR when repo-wide workflows, architecture, CI contracts, release processes, or durable coding defaults materially change. +- Do not edit this file for one-off task preferences. +- Keep this file as the canonical shared agent guide for this repository. + ## Project Summary This repository contains the Langfuse Python SDK. @@ -25,6 +32,8 @@ This repository contains the Langfuse Python SDK. - Do not weaken assertions just to make tests faster or greener. - If a test is slow, first optimize setup, teardown, polling, or fixtures. - Keep repo-shared instructions here. Keep personal or machine-specific notes out of version control. +- Keep tests independent and parallel-safe by default. +- For bug fixes, prefer writing or identifying the failing test first, confirm the failure, then implement the fix. ## Setup And Quality Commands @@ -100,12 +109,21 @@ If you change CI bootstrap: - preserve the `LANGFUSE_INIT_*` startup path for the Langfuse server unless there is a strong reason to change it - preserve `cancel-in-progress: true` -## Codebase Rules +## Repo Rules +- Keep changes scoped. Avoid unrelated refactors. - Prefer `LANGFUSE_BASE_URL`; `LANGFUSE_HOST` is deprecated and is only kept for compatibility tests. - If you touch `langfuse/api/`, regenerate it from the upstream Fern/OpenAPI source instead of hand-editing files. - If you touch shutdown, flushing, or worker-thread behavior, run the relevant resource-manager and OTEL-heavy tests. - If you change OpenAI or LangChain instrumentation, keep as much coverage as possible in `tests/unit` using exporter-local assertions, and leave only the minimal necessary coverage in `tests/e2e` / `tests/live_provider`. +- Never commit secrets or credentials. +- Keep `.env.template` in sync with required local-development environment variables. + +## Commit And PR Rules + +- Commit messages and PR titles should follow Conventional Commits: `type(scope): description` or `type: description`. +- Keep commits focused and atomic. +- In PR descriptions, list the main verification commands you ran. ## Python-Specific Notes @@ -125,4 +143,8 @@ Releases are handled by GitHub Actions. Do not build an ad hoc local release flo - Prefer official documentation first when answering product or API questions. - For OpenAI API, ChatGPT Apps SDK, or Codex questions, use the official OpenAI developer docs or Docs MCP server if available. -- If this repository keeps agent-specific files for multiple tools, treat `AGENTS.md` as the shared source of truth and import it from tool-specific files instead of duplicating instructions. + +## Git Safety + +- Do not use destructive git commands such as `git reset --hard` unless explicitly requested. +- Do not revert unrelated working-tree changes. diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 6f612e8aa..000000000 --- a/CLAUDE.md +++ /dev/null @@ -1,8 +0,0 @@ -# CLAUDE.md - -@AGENTS.md - -## Claude-Specific Notes - -- Keep repository-wide instructions in `AGENTS.md`. -- Use `CLAUDE.local.md` for personal or machine-specific preferences that should not be committed. From 871fc312c73c17843316502d3893537426098fac Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 23:05:48 +0200 Subject: [PATCH 20/23] fix(tests): wait for generation visibility in e2e --- tests/e2e/test_core_sdk.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/e2e/test_core_sdk.py b/tests/e2e/test_core_sdk.py index f3cc542c5..dd8746a52 100644 --- a/tests/e2e/test_core_sdk.py +++ b/tests/e2e/test_core_sdk.py @@ -1230,10 +1230,14 @@ def test_end_generation(): # Ensure data is sent langfuse.flush() - sleep(2) # Retrieve and verify - trace = api_wrapper.get_trace(trace_id) + trace = api_wrapper.get_trace( + trace_id, + is_result_ready=lambda trace: any( + obs["name"] == "query-generation" for obs in trace.get("observations", []) + ), + ) # Find generation by name generations = [ From eb98e3ad5225ceb927336be51c716c34f400f0f7 Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 23:11:36 +0200 Subject: [PATCH 21/23] fix(prompt-cache): avoid redundant refresh races --- langfuse/_client/client.py | 3 +- langfuse/_utils/prompt_cache.py | 68 ++++++++++++++++++++++++--------- tests/unit/test_prompt.py | 52 +++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 18 deletions(-) diff --git a/langfuse/_client/client.py b/langfuse/_client/client.py index 85ec83a4e..a2760f8c0 100644 --- a/langfuse/_client/client.py +++ b/langfuse/_client/client.py @@ -3482,8 +3482,9 @@ def refresh_task() -> None: fetch_timeout_seconds=fetch_timeout_seconds, ) - self._resources.prompt_cache.add_refresh_prompt_task( + self._resources.prompt_cache.add_refresh_prompt_task_if_current( cache_key, + cached_prompt, refresh_task, ) langfuse_logger.debug( diff --git a/langfuse/_utils/prompt_cache.py b/langfuse/_utils/prompt_cache.py index b927bc22f..adefcaf83 100644 --- a/langfuse/_utils/prompt_cache.py +++ b/langfuse/_utils/prompt_cache.py @@ -4,7 +4,7 @@ import os from datetime import datetime from queue import Queue -from threading import Thread +from threading import RLock, Thread from typing import Callable, Dict, List, Optional, Set from langfuse._client.environment_variables import ( @@ -77,12 +77,14 @@ class PromptCacheTaskManager(object): _threads: int _queue: Queue _processing_keys: Set[str] + _lock: RLock def __init__(self, threads: int = 1): self._queue = Queue() self._consumers = [] self._threads = threads self._processing_keys = set() + self._lock = RLock() for i in range(self._threads): consumer = PromptCacheRefreshConsumer(self._queue, i) @@ -92,16 +94,20 @@ def __init__(self, threads: int = 1): atexit.register(self.shutdown) def add_task(self, key: str, task: Callable[[], None]) -> None: - if key not in self._processing_keys: - logger.debug(f"Adding prompt cache refresh task for key: {key}") - self._processing_keys.add(key) - wrapped_task = self._wrap_task(key, task) - self._queue.put((wrapped_task)) - else: - logger.debug(f"Prompt cache refresh task already submitted for key: {key}") + with self._lock: + if key not in self._processing_keys: + logger.debug(f"Adding prompt cache refresh task for key: {key}") + self._processing_keys.add(key) + wrapped_task = self._wrap_task(key, task) + self._queue.put((wrapped_task)) + else: + logger.debug( + f"Prompt cache refresh task already submitted for key: {key}" + ) def active_tasks(self) -> int: - return len(self._processing_keys) + with self._lock: + return len(self._processing_keys) def wait_for_idle(self) -> None: self._queue.join() @@ -112,7 +118,8 @@ def wrapped() -> None: try: task() finally: - self._processing_keys.remove(key) + with self._lock: + self._processing_keys.remove(key) logger.debug(f"Refreshed prompt cache for key: {key}") return wrapped @@ -139,6 +146,7 @@ def shutdown(self) -> None: class PromptCache: _cache: Dict[str, PromptCacheItem] + _lock: RLock _task_manager: PromptCacheTaskManager """Task manager for refreshing cache""" @@ -147,34 +155,60 @@ def __init__( self, max_prompt_refresh_workers: int = DEFAULT_PROMPT_CACHE_REFRESH_WORKERS ): self._cache = {} + self._lock = RLock() self._task_manager = PromptCacheTaskManager(threads=max_prompt_refresh_workers) logger.debug("Prompt cache initialized.") def get(self, key: str) -> Optional[PromptCacheItem]: - return self._cache.get(key, None) + with self._lock: + return self._cache.get(key, None) def set(self, key: str, value: PromptClient, ttl_seconds: Optional[int]) -> None: if ttl_seconds is None: ttl_seconds = DEFAULT_PROMPT_CACHE_TTL_SECONDS - self._cache[key] = PromptCacheItem(value, ttl_seconds) + with self._lock: + self._cache[key] = PromptCacheItem(value, ttl_seconds) def delete(self, key: str) -> None: - self._cache.pop(key, None) + with self._lock: + self._cache.pop(key, None) def invalidate(self, prompt_name: str) -> None: """Invalidate all cached prompts with the given prompt name.""" - for key in list(self._cache): - if key.startswith(prompt_name): - del self._cache[key] + with self._lock: + for key in list(self._cache): + if key.startswith(prompt_name): + del self._cache[key] def add_refresh_prompt_task(self, key: str, fetch_func: Callable[[], None]) -> None: logger.debug(f"Submitting refresh task for key: {key}") self._task_manager.add_task(key, fetch_func) + def add_refresh_prompt_task_if_current( + self, + key: str, + expected_item: PromptCacheItem, + fetch_func: Callable[[], None], + ) -> None: + with self._lock: + current_item = self._cache.get(key) + if ( + current_item is not None + and current_item is not expected_item + and not current_item.is_expired() + ): + logger.debug( + f"Skipping refresh task for key: {key} because cache is already fresh." + ) + return + + self.add_refresh_prompt_task(key, fetch_func) + def clear(self) -> None: """Clear the entire prompt cache, removing all cached prompts.""" - self._cache.clear() + with self._lock: + self._cache.clear() @staticmethod def generate_cache_key( diff --git a/tests/unit/test_prompt.py b/tests/unit/test_prompt.py index dca601b06..508e6e2f3 100644 --- a/tests/unit/test_prompt.py +++ b/tests/unit/test_prompt.py @@ -492,6 +492,58 @@ def test_get_stale_prompt_when_expired_cache_default_ttl(mock_time, langfuse: La assert updated_result == TextPromptClient(updated_prompt) +@patch.object(PromptCacheItem, "get_epoch_seconds") +def test_skip_redundant_refresh_when_cache_already_updated( + mock_time, langfuse: Langfuse +) -> None: + prompt_name = "test_skip_redundant_refresh_when_cache_already_updated" + cache_key = PromptCache.generate_cache_key(prompt_name, version=None, label=None) + + mock_time.return_value = 0 + + initial_prompt = Prompt_Text( + name=prompt_name, + version=1, + prompt="Make me laugh", + labels=[], + type="text", + config={}, + tags=[], + ) + updated_prompt = Prompt_Text( + name=prompt_name, + version=2, + prompt="Make me laugh", + labels=[], + type="text", + config={}, + tags=[], + ) + + stale_result = TextPromptClient(initial_prompt) + fresh_result = TextPromptClient(updated_prompt) + + langfuse._resources.prompt_cache.set(cache_key, stale_result, None) + stale_item = langfuse._resources.prompt_cache.get(cache_key) + assert stale_item is not None + + mock_time.return_value = DEFAULT_PROMPT_CACHE_TTL_SECONDS + 1 + assert stale_item.is_expired() + + langfuse._resources.prompt_cache.set(cache_key, fresh_result, None) + + add_task_mock = Mock() + langfuse._resources.prompt_cache._task_manager.add_task = add_task_mock + + langfuse._resources.prompt_cache.add_refresh_prompt_task_if_current( + cache_key, + stale_item, + Mock(), + ) + + add_task_mock.assert_not_called() + + @patch.object(PromptCacheItem, "get_epoch_seconds") def test_get_fresh_prompt_when_expired_cache_default_ttl(mock_time, langfuse: Langfuse): mock_time.return_value = 0 From 6fed925e43b46b38bf9348b79f758a1843f5affc Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 23:27:39 +0200 Subject: [PATCH 22/23] ci: migrate GitHub Actions to Blacksmith runners --- .github/workflows/ci.yml | 10 +++++----- AGENTS.md | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9b2b4ad03..3a48515fa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,7 +16,7 @@ concurrency: jobs: linting: - runs-on: ubuntu-latest + runs-on: blacksmith-4vcpu-ubuntu-2404 steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Install uv and set Python version @@ -31,7 +31,7 @@ jobs: run: uv run --frozen ruff check . type-checking: - runs-on: ubuntu-latest + runs-on: blacksmith-4vcpu-ubuntu-2404 steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Install uv and set Python version @@ -53,7 +53,7 @@ jobs: run: uv run --frozen mypy langfuse --no-error-summary unit-tests: - runs-on: ubuntu-latest + runs-on: blacksmith-4vcpu-ubuntu-2404 timeout-minutes: 30 env: LANGFUSE_BASE_URL: "http://localhost:3000" @@ -92,7 +92,7 @@ jobs: uv run --frozen pytest -n auto --dist worksteal -s -v --log-cli-level=INFO tests/unit e2e-tests: - runs-on: ubuntu-latest + runs-on: blacksmith-4vcpu-ubuntu-2404 timeout-minutes: 30 strategy: fail-fast: false @@ -240,7 +240,7 @@ jobs: all-tests-passed: # This allows us to have a branch protection rule for tests and deploys with matrix - runs-on: ubuntu-latest + runs-on: blacksmith-4vcpu-ubuntu-2404 needs: [unit-tests, e2e-tests, linting, type-checking] if: always() steps: diff --git a/AGENTS.md b/AGENTS.md index 807d1674b..977adf5a9 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -92,6 +92,7 @@ uv run --frozen pytest tests/unit/test_resource_manager.py::test_pause_signals_s The main CI workflow currently runs: +- on Blacksmith Ubuntu 24.04 runners - linting on Python 3.13 - mypy on Python 3.13 - `tests/unit` on a Python 3.10-3.14 matrix From 170beb8201650d61aa3ed21aad5bbab39a1474ba Mon Sep 17 00:00:00 2001 From: Hassieb Pakzad <68423100+hassiebp@users.noreply.github.com> Date: Sat, 4 Apr 2026 23:36:13 +0200 Subject: [PATCH 23/23] revert(ci): switch back from Blacksmith runners --- .github/workflows/ci.yml | 10 +++++----- AGENTS.md | 2 -- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3a48515fa..9b2b4ad03 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,7 +16,7 @@ concurrency: jobs: linting: - runs-on: blacksmith-4vcpu-ubuntu-2404 + runs-on: ubuntu-latest steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Install uv and set Python version @@ -31,7 +31,7 @@ jobs: run: uv run --frozen ruff check . type-checking: - runs-on: blacksmith-4vcpu-ubuntu-2404 + runs-on: ubuntu-latest steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Install uv and set Python version @@ -53,7 +53,7 @@ jobs: run: uv run --frozen mypy langfuse --no-error-summary unit-tests: - runs-on: blacksmith-4vcpu-ubuntu-2404 + runs-on: ubuntu-latest timeout-minutes: 30 env: LANGFUSE_BASE_URL: "http://localhost:3000" @@ -92,7 +92,7 @@ jobs: uv run --frozen pytest -n auto --dist worksteal -s -v --log-cli-level=INFO tests/unit e2e-tests: - runs-on: blacksmith-4vcpu-ubuntu-2404 + runs-on: ubuntu-latest timeout-minutes: 30 strategy: fail-fast: false @@ -240,7 +240,7 @@ jobs: all-tests-passed: # This allows us to have a branch protection rule for tests and deploys with matrix - runs-on: blacksmith-4vcpu-ubuntu-2404 + runs-on: ubuntu-latest needs: [unit-tests, e2e-tests, linting, type-checking] if: always() steps: diff --git a/AGENTS.md b/AGENTS.md index 977adf5a9..eff827d76 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -91,8 +91,6 @@ uv run --frozen pytest tests/unit/test_resource_manager.py::test_pause_signals_s ## CI Contract The main CI workflow currently runs: - -- on Blacksmith Ubuntu 24.04 runners - linting on Python 3.13 - mypy on Python 3.13 - `tests/unit` on a Python 3.10-3.14 matrix