From db0deb6204c9c22fd2d0a58d84ba74f84cd091ff Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 21:43:45 +0000 Subject: [PATCH 01/20] centralise testing fake passwords --- test/integration/server/api/conftest.py | 3 +- test/integration/server/bootstrap/conftest.py | 12 ++++--- test/shared_fixtures.py | 33 ++++++++++++++++--- test/unit/server/api/conftest.py | 9 +++-- .../server/api/utils/test_utils_databases.py | 5 +-- test/unit/server/api/utils/test_utils_mcp.py | 18 +++++----- test/unit/server/api/v1/test_v1_mcp.py | 4 ++- .../server/bootstrap/test_bootstrap_models.py | 12 +++---- 8 files changed, 67 insertions(+), 29 deletions(-) diff --git a/test/integration/server/api/conftest.py b/test/integration/server/api/conftest.py index 6b25297f..7b20baf6 100644 --- a/test/integration/server/api/conftest.py +++ b/test/integration/server/api/conftest.py @@ -23,6 +23,7 @@ make_database, make_model, DEFAULT_LL_MODEL_CONFIG, + TEST_AUTH_TOKEN, ) import pytest @@ -42,7 +43,7 @@ # Test configuration - extends shared DB config with integration-specific settings TEST_CONFIG = { "client": "integration_test", - "auth_token": "integration-test-token", + "auth_token": TEST_AUTH_TOKEN, **TEST_DB_CONFIG, } diff --git a/test/integration/server/bootstrap/conftest.py b/test/integration/server/bootstrap/conftest.py index 00848e66..4049fd39 100644 --- a/test/integration/server/bootstrap/conftest.py +++ b/test/integration/server/bootstrap/conftest.py @@ -21,6 +21,10 @@ clean_env, BOOTSTRAP_ENV_VARS, DEFAULT_LL_MODEL_CONFIG, + TEST_INTEGRATION_DB_USER, + TEST_INTEGRATION_DB_PASSWORD, + TEST_INTEGRATION_DB_DSN, + TEST_API_KEY_ALT, ) import pytest @@ -107,9 +111,9 @@ def sample_database_config(): """Sample database configuration dict.""" return { "name": "INTEGRATION_DB", - "user": "integration_user", - "password": "integration_pass", - "dsn": "localhost:1521/INTPDB", + "user": TEST_INTEGRATION_DB_USER, + "password": TEST_INTEGRATION_DB_PASSWORD, + "dsn": TEST_INTEGRATION_DB_DSN, } @@ -121,7 +125,7 @@ def sample_model_config(): "type": "ll", "provider": "openai", "enabled": True, - "api_key": "test-api-key", + "api_key": TEST_API_KEY_ALT, "api_base": "https://api.openai.com/v1", "max_tokens": 4096, } diff --git a/test/shared_fixtures.py b/test/shared_fixtures.py index 22953e39..8c2a152d 100644 --- a/test/shared_fixtures.py +++ b/test/shared_fixtures.py @@ -28,6 +28,31 @@ from server.bootstrap.configfile import ConfigStore +################################################# +# Test Credentials Constants +################################################# +# Centralized fake credentials for testing. +# These are NOT real secrets - they are placeholder values used in tests. +# Using constants ensures consistent values across tests and allows +# security scanners to be configured to ignore this single location. 
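+# (For example, scanners such as detect-secrets or gitleaks can allowlist
+# this one file; the exact configuration is tool-specific.)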
+ +# Database credentials (fake - for testing only) +TEST_DB_USER = "test_user" +TEST_DB_PASSWORD = "test_password" # noqa: S105 - not a real password +TEST_DB_DSN = "localhost:1521/TESTPDB" +TEST_DB_WALLET_PASSWORD = "test_wallet_pass" # noqa: S105 - not a real password + +# API keys (fake - for testing only) +TEST_API_KEY = "test-key" # noqa: S105 - not a real API key +TEST_API_KEY_ALT = "test-api-key" # noqa: S105 - not a real API key +TEST_AUTH_TOKEN = "integration-test-token" # noqa: S105 - not a real token + +# Integration test database credentials (fake - for testing only) +TEST_INTEGRATION_DB_USER = "integration_user" +TEST_INTEGRATION_DB_PASSWORD = "integration_pass" # noqa: S105 - not a real password +TEST_INTEGRATION_DB_DSN = "localhost:1521/INTPDB" + + # Default test model settings - shared across test fixtures DEFAULT_LL_MODEL_CONFIG = { "model": "gpt-4o-mini", @@ -78,9 +103,9 @@ def make_database(): def _make_database( name: str = "TEST_DB", - user: str = "test_user", - password: str = "test_password", - dsn: str = "localhost:1521/TESTPDB", + user: str = TEST_DB_USER, + password: str = TEST_DB_PASSWORD, + dsn: str = TEST_DB_DSN, wallet_password: str = None, **kwargs, ) -> Database: @@ -108,7 +133,7 @@ def _make_model( model_type: str = "ll", provider: str = "openai", enabled: bool = True, - api_key: str = "test-key", + api_key: str = TEST_API_KEY, api_base: str = "https://api.openai.com/v1", **kwargs, ) -> Model: diff --git a/test/unit/server/api/conftest.py b/test/unit/server/api/conftest.py index 8ec4542a..d10ed367 100644 --- a/test/unit/server/api/conftest.py +++ b/test/unit/server/api/conftest.py @@ -19,6 +19,9 @@ make_ll_settings, make_settings, make_configuration, + TEST_DB_USER, + TEST_DB_PASSWORD, + TEST_DB_DSN, ) import pytest @@ -36,9 +39,9 @@ def make_database_auth(): def _make_database_auth(**overrides) -> DatabaseAuth: defaults = { - "user": "test_user", - "password": "test_password", - "dsn": "localhost:1521/TESTPDB", + "user": TEST_DB_USER, + "password": TEST_DB_PASSWORD, + "dsn": TEST_DB_DSN, "wallet_password": None, } defaults.update(overrides) diff --git a/test/unit/server/api/utils/test_utils_databases.py b/test/unit/server/api/utils/test_utils_databases.py index ebc92d80..62f54a32 100644 --- a/test/unit/server/api/utils/test_utils_databases.py +++ b/test/unit/server/api/utils/test_utils_databases.py @@ -13,6 +13,7 @@ # pylint: disable=too-few-public-methods from test.conftest import TEST_CONFIG +from test.shared_fixtures import TEST_DB_WALLET_PASSWORD from unittest.mock import patch, MagicMock import pytest @@ -177,7 +178,7 @@ def test_connect_raises_permission_error_invalid_credentials(self, db_container, # pylint: disable=unused-argument config = make_database( user="INVALID_USER", - password="wrong_password", + password=TEST_DB_WALLET_PASSWORD, # Using a fake password for invalid login test dsn=TEST_CONFIG["db_dsn"], ) @@ -214,7 +215,7 @@ def test_connect_wallet_location_defaults_to_config_dir(self, mock_connect, make """connect should default wallet_location to config_dir if not set (mocked - verifies call args).""" mock_conn = MagicMock() mock_connect.return_value = mock_conn - config = make_database(wallet_password="secret", config_dir="/path/to/config") + config = make_database(wallet_password=TEST_DB_WALLET_PASSWORD, config_dir="/path/to/config") utils_databases.connect(config) diff --git a/test/unit/server/api/utils/test_utils_mcp.py b/test/unit/server/api/utils/test_utils_mcp.py index 0301a321..f2c69a5a 100644 --- 
a/test/unit/server/api/utils/test_utils_mcp.py +++ b/test/unit/server/api/utils/test_utils_mcp.py @@ -10,13 +10,15 @@ import os import pytest +from test.shared_fixtures import TEST_API_KEY, TEST_API_KEY_ALT + from server.api.utils import mcp class TestGetClient: """Tests for the get_client function.""" - @patch.dict(os.environ, {"API_SERVER_KEY": "test-api-key"}) + @patch.dict(os.environ, {"API_SERVER_KEY": TEST_API_KEY_ALT}) def test_get_client_default_values(self): """get_client should return default configuration.""" result = mcp.get_client() @@ -27,23 +29,23 @@ def test_get_client_default_values(self): assert result["mcpServers"]["optimizer"]["transport"] == "streamable_http" assert "http://127.0.0.1:8000/mcp/" in result["mcpServers"]["optimizer"]["url"] - @patch.dict(os.environ, {"API_SERVER_KEY": "test-api-key"}) + @patch.dict(os.environ, {"API_SERVER_KEY": TEST_API_KEY_ALT}) def test_get_client_custom_server_port(self): """get_client should use custom server and port.""" result = mcp.get_client(server="http://custom.server", port=9000) assert "http://custom.server:9000/mcp/" in result["mcpServers"]["optimizer"]["url"] - @patch.dict(os.environ, {"API_SERVER_KEY": "secret-key"}) + @patch.dict(os.environ, {"API_SERVER_KEY": TEST_API_KEY_ALT}) def test_get_client_includes_auth_header(self): """get_client should include authorization header.""" result = mcp.get_client() headers = result["mcpServers"]["optimizer"]["headers"] assert "Authorization" in headers - assert headers["Authorization"] == "Bearer secret-key" + assert headers["Authorization"] == f"Bearer {TEST_API_KEY_ALT}" - @patch.dict(os.environ, {"API_SERVER_KEY": "test-key"}) + @patch.dict(os.environ, {"API_SERVER_KEY": TEST_API_KEY}) def test_get_client_langgraph_removes_type(self): """get_client should remove type field for langgraph client.""" result = mcp.get_client(client="langgraph") @@ -51,14 +53,14 @@ def test_get_client_langgraph_removes_type(self): assert "type" not in result["mcpServers"]["optimizer"] assert "transport" in result["mcpServers"]["optimizer"] - @patch.dict(os.environ, {"API_SERVER_KEY": "test-key"}) + @patch.dict(os.environ, {"API_SERVER_KEY": TEST_API_KEY}) def test_get_client_non_langgraph_keeps_type(self): """get_client should keep type field for non-langgraph clients.""" result = mcp.get_client(client="other") assert "type" in result["mcpServers"]["optimizer"] - @patch.dict(os.environ, {"API_SERVER_KEY": "test-key"}) + @patch.dict(os.environ, {"API_SERVER_KEY": TEST_API_KEY}) def test_get_client_none_client_keeps_type(self): """get_client should keep type field when client is None.""" result = mcp.get_client(client=None) @@ -73,7 +75,7 @@ def test_get_client_empty_api_key(self): headers = result["mcpServers"]["optimizer"]["headers"] assert headers["Authorization"] == "Bearer " - @patch.dict(os.environ, {"API_SERVER_KEY": "key"}) + @patch.dict(os.environ, {"API_SERVER_KEY": TEST_API_KEY}) def test_get_client_structure(self): """get_client should return expected structure.""" result = mcp.get_client() diff --git a/test/unit/server/api/v1/test_v1_mcp.py b/test/unit/server/api/v1/test_v1_mcp.py index dc10c82b..8cd6ba50 100644 --- a/test/unit/server/api/v1/test_v1_mcp.py +++ b/test/unit/server/api/v1/test_v1_mcp.py @@ -11,6 +11,8 @@ from unittest.mock import patch, MagicMock, AsyncMock import pytest +from test.shared_fixtures import TEST_API_KEY + from server.api.v1 import mcp @@ -41,7 +43,7 @@ async def test_get_client_returns_config(self, mock_get_client): "type": "streamableHttp", "transport": 
"streamable_http", "url": "http://127.0.0.1:8000/mcp/", - "headers": {"Authorization": "Bearer test-key"}, + "headers": {"Authorization": f"Bearer {TEST_API_KEY}"}, } } } diff --git a/test/unit/server/bootstrap/test_bootstrap_models.py b/test/unit/server/bootstrap/test_bootstrap_models.py index 8f8a8b1d..4950b09c 100644 --- a/test/unit/server/bootstrap/test_bootstrap_models.py +++ b/test/unit/server/bootstrap/test_bootstrap_models.py @@ -11,7 +11,7 @@ import os from unittest.mock import patch -from test.shared_fixtures import assert_model_list_valid, get_model_by_id +from test.shared_fixtures import assert_model_list_valid, get_model_by_id, TEST_API_KEY import pytest @@ -38,13 +38,13 @@ def test_main_includes_base_models(self): def test_main_enables_models_with_api_keys(self): """main() should enable models when API keys are present.""" - os.environ["OPENAI_API_KEY"] = "test-openai-key" + os.environ["OPENAI_API_KEY"] = TEST_API_KEY try: model_list = models_module.main() gpt_model = get_model_by_id(model_list, "gpt-4o-mini") assert gpt_model.enabled is True - assert gpt_model.api_key == "test-openai-key" + assert gpt_model.api_key == TEST_API_KEY finally: del os.environ["OPENAI_API_KEY"] @@ -57,7 +57,7 @@ def test_main_disables_models_without_api_keys(self): @pytest.mark.usefixtures("reset_config_store", "clean_env") def test_main_checks_url_accessibility(self): """main() should check URL accessibility for enabled models.""" - os.environ["OPENAI_API_KEY"] = "test-key" + os.environ["OPENAI_API_KEY"] = TEST_API_KEY with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: mock_accessible.return_value = (False, "Connection refused") @@ -73,8 +73,8 @@ def test_main_checks_url_accessibility(self): @pytest.mark.usefixtures("reset_config_store", "clean_env") def test_main_caches_url_accessibility_results(self): """main() should cache URL accessibility results for same URLs.""" - os.environ["OPENAI_API_KEY"] = "test-key" - os.environ["COHERE_API_KEY"] = "test-key" + os.environ["OPENAI_API_KEY"] = TEST_API_KEY + os.environ["COHERE_API_KEY"] = TEST_API_KEY with patch("server.bootstrap.models.is_url_accessible") as mock_accessible: mock_accessible.return_value = (True, "OK") From f2a9d0af4a91178cd70f1c76da5a5df1492e4112 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 23:17:44 +0000 Subject: [PATCH 02/20] increase coverage --- test/integration/server/api/v1/test_embed.py | 173 ++++++++++++++ test/unit/server/api/utils/test_utils_oci.py | 225 ++++++++++++++++++ test/unit/server/api/v1/test_v1_embed.py | 235 ++++++++++++++++++- 3 files changed, 632 insertions(+), 1 deletion(-) create mode 100644 test/integration/server/api/v1/test_embed.py diff --git a/test/integration/server/api/v1/test_embed.py b/test/integration/server/api/v1/test_embed.py new file mode 100644 index 00000000..debdb019 --- /dev/null +++ b/test/integration/server/api/v1/test_embed.py @@ -0,0 +1,173 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for server/api/v1/embed.py + +Tests the embedding and vector store endpoints through the full API stack. +These endpoints require authentication. 
+""" + + +class TestEmbedDropVs: + """Integration tests for the embed_drop_vs endpoint.""" + + def test_embed_drop_vs_requires_auth(self, client): + """DELETE /v1/embed/{vs} should require authentication.""" + response = client.delete("/v1/embed/VS_TEST") + + assert response.status_code == 401 + + def test_embed_drop_vs_rejects_invalid_token(self, client, auth_headers): + """DELETE /v1/embed/{vs} should reject invalid tokens.""" + response = client.delete("/v1/embed/VS_TEST", headers=auth_headers["invalid_auth"]) + + assert response.status_code == 401 + + +class TestEmbedGetFiles: + """Integration tests for the embed_get_files endpoint.""" + + def test_embed_get_files_requires_auth(self, client): + """GET /v1/embed/{vs}/files should require authentication.""" + response = client.get("/v1/embed/VS_TEST/files") + + assert response.status_code == 401 + + def test_embed_get_files_rejects_invalid_token(self, client, auth_headers): + """GET /v1/embed/{vs}/files should reject invalid tokens.""" + response = client.get("/v1/embed/VS_TEST/files", headers=auth_headers["invalid_auth"]) + + assert response.status_code == 401 + + +class TestCommentVs: + """Integration tests for the comment_vs endpoint.""" + + def test_comment_vs_requires_auth(self, client): + """PATCH /v1/embed/comment should require authentication.""" + response = client.patch( + "/v1/embed/comment", + json={"vector_store": "VS_TEST", "model": "text-embedding-3", "chunk_size": 1000, "chunk_overlap": 200}, + ) + + assert response.status_code == 401 + + def test_comment_vs_rejects_invalid_token(self, client, auth_headers): + """PATCH /v1/embed/comment should reject invalid tokens.""" + response = client.patch( + "/v1/embed/comment", + headers=auth_headers["invalid_auth"], + json={"vector_store": "VS_TEST", "model": "text-embedding-3", "chunk_size": 1000, "chunk_overlap": 200}, + ) + + assert response.status_code == 401 + + +class TestStoreSqlFile: + """Integration tests for the store_sql_file endpoint.""" + + def test_store_sql_file_requires_auth(self, client): + """POST /v1/embed/sql/store should require authentication.""" + response = client.post("/v1/embed/sql/store", json=["conn_str", "SELECT 1"]) + + assert response.status_code == 401 + + def test_store_sql_file_rejects_invalid_token(self, client, auth_headers): + """POST /v1/embed/sql/store should reject invalid tokens.""" + response = client.post( + "/v1/embed/sql/store", + headers=auth_headers["invalid_auth"], + json=["conn_str", "SELECT 1"], + ) + + assert response.status_code == 401 + + +class TestStoreWebFile: + """Integration tests for the store_web_file endpoint.""" + + def test_store_web_file_requires_auth(self, client): + """POST /v1/embed/web/store should require authentication.""" + response = client.post("/v1/embed/web/store", json=["https://example.com/doc.pdf"]) + + assert response.status_code == 401 + + def test_store_web_file_rejects_invalid_token(self, client, auth_headers): + """POST /v1/embed/web/store should reject invalid tokens.""" + response = client.post( + "/v1/embed/web/store", + headers=auth_headers["invalid_auth"], + json=["https://example.com/doc.pdf"], + ) + + assert response.status_code == 401 + + +class TestStoreLocalFile: + """Integration tests for the store_local_file endpoint.""" + + def test_store_local_file_requires_auth(self, client): + """POST /v1/embed/local/store should require authentication.""" + response = client.post( + "/v1/embed/local/store", + files={"files": ("test.txt", b"Test content", "text/plain")}, + ) + + assert 
response.status_code == 401 + + def test_store_local_file_rejects_invalid_token(self, client, auth_headers): + """POST /v1/embed/local/store should reject invalid tokens.""" + response = client.post( + "/v1/embed/local/store", + headers=auth_headers["invalid_auth"], + files={"files": ("test.txt", b"Test content", "text/plain")}, + ) + + assert response.status_code == 401 + + +class TestSplitEmbed: + """Integration tests for the split_embed endpoint.""" + + def test_split_embed_requires_auth(self, client): + """POST /v1/embed should require authentication.""" + response = client.post( + "/v1/embed", + json={"model": "text-embedding-3", "chunk_size": 1000, "chunk_overlap": 200}, + ) + + assert response.status_code == 401 + + def test_split_embed_rejects_invalid_token(self, client, auth_headers): + """POST /v1/embed should reject invalid tokens.""" + response = client.post( + "/v1/embed", + headers=auth_headers["invalid_auth"], + json={"model": "text-embedding-3", "chunk_size": 1000, "chunk_overlap": 200}, + ) + + assert response.status_code == 401 + + +class TestRefreshVectorStore: + """Integration tests for the refresh_vector_store endpoint.""" + + def test_refresh_vector_store_requires_auth(self, client): + """POST /v1/embed/refresh should require authentication.""" + response = client.post( + "/v1/embed/refresh", + json={"vector_store_alias": "test_alias", "bucket_name": "test-bucket"}, + ) + + assert response.status_code == 401 + + def test_refresh_vector_store_rejects_invalid_token(self, client, auth_headers): + """POST /v1/embed/refresh should reject invalid tokens.""" + response = client.post( + "/v1/embed/refresh", + headers=auth_headers["invalid_auth"], + json={"vector_store_alias": "test_alias", "bucket_name": "test-bucket"}, + ) + + assert response.status_code == 401 diff --git a/test/unit/server/api/utils/test_utils_oci.py b/test/unit/server/api/utils/test_utils_oci.py index 3e6d9f2b..d11acbad 100644 --- a/test/unit/server/api/utils/test_utils_oci.py +++ b/test/unit/server/api/utils/test_utils_oci.py @@ -583,6 +583,231 @@ def test_init_genai_client_calls_init_client(self, mock_init_client, make_oci_co assert result == mock_client +class TestInitClientSecurityToken: + """Tests for init_client with security token authentication.""" + + @patch("server.api.utils.oci.get_signer") + @patch("server.api.utils.oci.oci.signer.load_private_key_from_file") + @patch("server.api.utils.oci.oci.auth.signers.SecurityTokenSigner") + @patch("server.api.utils.oci.oci.object_storage.ObjectStorageClient") + @patch("builtins.open", create=True) + def test_init_client_security_token_auth( + self, mock_open, mock_client_class, mock_sec_token_signer, mock_load_key, mock_get_signer, make_oci_config + ): + """init_client should use security token authentication when configured.""" + mock_get_signer.return_value = None + mock_open.return_value.__enter__ = MagicMock(return_value=MagicMock(read=MagicMock(return_value="token_data"))) + mock_open.return_value.__exit__ = MagicMock(return_value=False) + mock_private_key = MagicMock() + mock_load_key.return_value = mock_private_key + mock_signer = MagicMock() + mock_sec_token_signer.return_value = mock_signer + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + config = make_oci_config() + config.authentication = "security_token" + config.security_token_file = "/path/to/token" + config.key_file = "/path/to/key" + config.region = "us-ashburn-1" + + result = utils_oci.init_client(oci.object_storage.ObjectStorageClient, config) + + assert 
result == mock_client + mock_sec_token_signer.assert_called_once() + + +class TestInitClientOkeWorkloadIdentityTenancy: + """Tests for init_client OKE workload identity tenancy extraction.""" + + @patch("server.api.utils.oci.get_signer") + @patch("server.api.utils.oci.oci.object_storage.ObjectStorageClient") + def test_init_client_oke_workload_extracts_tenancy(self, mock_client_class, mock_get_signer, make_oci_config): + """init_client should extract tenancy from OKE workload identity token.""" + import base64 + import json + + # Create a mock JWT token with tenant claim + payload = {"tenant": "ocid1.tenancy.oc1..test"} + payload_json = json.dumps(payload) + payload_b64 = base64.urlsafe_b64encode(payload_json.encode()).decode().rstrip("=") + mock_token = f"header.{payload_b64}.signature" + + mock_signer = MagicMock() + mock_signer.get_security_token.return_value = mock_token + mock_get_signer.return_value = mock_signer + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + config = make_oci_config() + config.authentication = "oke_workload_identity" + config.region = "us-ashburn-1" + config.tenancy = None # Not set, should be extracted from token + + utils_oci.init_client(oci.object_storage.ObjectStorageClient, config) + + assert config.tenancy == "ocid1.tenancy.oc1..test" + + +class TestGetNamespaceExceptionHandling: + """Tests for get_namespace exception handling.""" + + @patch("server.api.utils.oci.init_client") + def test_get_namespace_raises_on_unbound_local_error(self, mock_init_client, make_oci_config): + """get_namespace should raise OciException on UnboundLocalError.""" + mock_init_client.side_effect = UnboundLocalError("Client not initialized") + config = make_oci_config() + + with pytest.raises(OciException) as exc_info: + utils_oci.get_namespace(config) + + assert exc_info.value.status_code == 500 + assert "No Configuration" in exc_info.value.detail + + @patch("server.api.utils.oci.init_client") + def test_get_namespace_raises_on_request_exception(self, mock_init_client, make_oci_config): + """get_namespace should raise OciException on RequestException.""" + mock_client = MagicMock() + mock_client.get_namespace.side_effect = oci.exceptions.RequestException("Connection timeout") + mock_init_client.return_value = mock_client + config = make_oci_config() + + with pytest.raises(OciException) as exc_info: + utils_oci.get_namespace(config) + + assert exc_info.value.status_code == 503 + + @patch("server.api.utils.oci.init_client") + def test_get_namespace_raises_on_generic_exception(self, mock_init_client, make_oci_config): + """get_namespace should raise OciException on generic Exception.""" + mock_client = MagicMock() + mock_client.get_namespace.side_effect = RuntimeError("Unexpected error") + mock_init_client.return_value = mock_client + config = make_oci_config() + + with pytest.raises(OciException) as exc_info: + utils_oci.get_namespace(config) + + assert exc_info.value.status_code == 500 + assert "Unexpected error" in exc_info.value.detail + + +class TestGetGenaiModelsExceptionHandling: + """Tests for get_genai_models exception handling.""" + + @patch("server.api.utils.oci.init_client") + def test_get_genai_models_handles_service_error(self, mock_init_client, make_oci_config): + """get_genai_models should handle ServiceError gracefully.""" + mock_client = MagicMock() + mock_client.list_models.side_effect = oci.exceptions.ServiceError( + status=403, code="NotAuthorized", headers={}, message="Not authorized" + ) + mock_init_client.return_value = mock_client 
+ + config = make_oci_config(genai_region="us-chicago-1") + config.genai_compartment_id = "ocid1.compartment.oc1..test" + + result = utils_oci.get_genai_models(config, regional=True) + + # Should return empty list instead of raising + assert result == [] + + @patch("server.api.utils.oci.init_client") + def test_get_genai_models_handles_request_exception(self, mock_init_client, make_oci_config): + """get_genai_models should handle RequestException gracefully.""" + import urllib3.exceptions + + mock_client = MagicMock() + mock_client.list_models.side_effect = urllib3.exceptions.MaxRetryError(None, "url") + mock_init_client.return_value = mock_client + + config = make_oci_config(genai_region="us-chicago-1") + config.genai_compartment_id = "ocid1.compartment.oc1..test" + + result = utils_oci.get_genai_models(config, regional=True) + + # Should return empty list instead of raising + assert result == [] + + @patch("server.api.utils.oci.init_client") + def test_get_genai_models_excludes_deprecated(self, mock_init_client, make_oci_config): + """get_genai_models should exclude deprecated models.""" + mock_active_model = MagicMock() + mock_active_model.display_name = "active-model" + mock_active_model.capabilities = ["TEXT_GENERATION"] + mock_active_model.vendor = "cohere" + mock_active_model.id = "ocid1.model.active" + mock_active_model.time_deprecated = None + mock_active_model.time_dedicated_retired = None + mock_active_model.time_on_demand_retired = None + + mock_deprecated_model = MagicMock() + mock_deprecated_model.display_name = "deprecated-model" + mock_deprecated_model.capabilities = ["TEXT_GENERATION"] + mock_deprecated_model.vendor = "cohere" + mock_deprecated_model.id = "ocid1.model.deprecated" + mock_deprecated_model.time_deprecated = datetime(2024, 1, 1) + mock_deprecated_model.time_dedicated_retired = None + mock_deprecated_model.time_on_demand_retired = None + + mock_response = MagicMock() + mock_response.data.items = [mock_active_model, mock_deprecated_model] + + mock_client = MagicMock() + mock_client.list_models.return_value = mock_response + mock_init_client.return_value = mock_client + + config = make_oci_config(genai_region="us-chicago-1") + config.genai_compartment_id = "ocid1.compartment.oc1..test" + + result = utils_oci.get_genai_models(config, regional=True) + + assert len(result) == 1 + assert result[0]["model_name"] == "active-model" + + +class TestGetBucketObjectsWithMetadataServiceError: + """Tests for get_bucket_objects_with_metadata service error handling.""" + + @patch("server.api.utils.oci.init_client") + def test_get_bucket_objects_with_metadata_returns_empty_on_service_error(self, mock_init_client, make_oci_config): + """get_bucket_objects_with_metadata should return empty list on ServiceError.""" + mock_client = MagicMock() + mock_client.list_objects.side_effect = oci.exceptions.ServiceError( + status=404, code="BucketNotFound", headers={}, message="Bucket not found" + ) + mock_init_client.return_value = mock_client + + config = make_oci_config() + config.namespace = "test-namespace" + + result = utils_oci.get_bucket_objects_with_metadata("nonexistent-bucket", config) + + assert result == [] + + +class TestGetClientDerivedAuthProfileNoMatch: + """Tests for get function when derived auth profile has no matching OCI config.""" + + @patch("server.api.utils.oci.bootstrap.SETTINGS_OBJECTS") + @patch("server.api.utils.oci.bootstrap.OCI_OBJECTS") + def test_get_raises_when_derived_profile_not_found(self, mock_oci, mock_settings, make_oci_config, make_settings): + """get 
should raise ValueError when client's derived auth_profile has no matching OCI config.""" + settings = make_settings(client="test_client") + settings.oci.auth_profile = "MISSING_PROFILE" + mock_settings.__iter__ = lambda _: iter([settings]) + mock_settings.__len__ = lambda _: 1 + + # OCI config with different profile + oci_config = make_oci_config(auth_profile="OTHER_PROFILE") + mock_oci.__iter__ = lambda _: iter([oci_config]) + + with pytest.raises(ValueError) as exc_info: + utils_oci.get(client="test_client") + + assert "No settings found for client" in str(exc_info.value) + + class TestLoggerConfiguration: """Tests for logger configuration.""" diff --git a/test/unit/server/api/v1/test_v1_embed.py b/test/unit/server/api/v1/test_v1_embed.py index 17a442c9..094d01ae 100644 --- a/test/unit/server/api/v1/test_v1_embed.py +++ b/test/unit/server/api/v1/test_v1_embed.py @@ -315,7 +315,6 @@ async def test_store_web_file_pdf_success(self, mock_session_class, mock_get_tem assert result.status_code == 200 - class TestStoreLocalFile: """Tests for the store_local_file endpoint.""" @@ -456,6 +455,109 @@ async def test_split_embed_raises_500_on_value_error( assert exc_info.value.status_code == 500 + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_oci.get") + @patch("server.api.v1.embed.utils_embed.get_temp_directory") + @patch("server.api.v1.embed.utils_embed.load_and_split_documents") + @patch("shutil.rmtree") + async def test_split_embed_raises_500_on_runtime_error( + self, _mock_rmtree, mock_load_split, mock_get_temp, mock_oci_get, tmp_path, make_oci_config + ): + """split_embed should raise 500 on RuntimeError during processing.""" + mock_oci_get.return_value = make_oci_config() + mock_get_temp.return_value = tmp_path + mock_load_split.side_effect = RuntimeError("Processing failed") + + # Create a test file + (tmp_path / "test.txt").write_text("Test content") + + request = DatabaseVectorStorage(model="text-embedding-3", chunk_size=1000, chunk_overlap=200) + + with pytest.raises(HTTPException) as exc_info: + await embed.split_embed(request=request, rate_limit=0, client="test_client") + + assert exc_info.value.status_code == 500 + assert "Processing failed" in exc_info.value.detail + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_oci.get") + @patch("server.api.v1.embed.utils_embed.get_temp_directory") + @patch("server.api.v1.embed.utils_embed.load_and_split_documents") + @patch("shutil.rmtree") + async def test_split_embed_raises_500_on_generic_exception( + self, _mock_rmtree, mock_load_split, mock_get_temp, mock_oci_get, tmp_path, make_oci_config + ): + """split_embed should raise 500 on generic Exception during processing.""" + mock_oci_get.return_value = make_oci_config() + mock_get_temp.return_value = tmp_path + mock_load_split.side_effect = Exception("Unexpected error occurred") + + # Create a test file + (tmp_path / "test.txt").write_text("Test content") + + request = DatabaseVectorStorage(model="text-embedding-3", chunk_size=1000, chunk_overlap=200) + + with pytest.raises(HTTPException) as exc_info: + await embed.split_embed(request=request, rate_limit=0, client="test_client") + + assert exc_info.value.status_code == 500 + assert "Unexpected error occurred" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_split_embed_loads_file_metadata( + self, split_embed_mocks, tmp_path, make_oci_config, make_database + ): + """split_embed should load file metadata when available.""" + mocks = split_embed_mocks + mocks["oci_get"].return_value = make_oci_config() + 
mocks["get_temp"].return_value = tmp_path + mocks["load_split"].return_value = (["doc1"], None) + mocks["get_embed"].return_value = MagicMock() + mocks["get_vs_table"].return_value = ("VS_TEST", "test_alias") + mocks["populate"].return_value = None + mocks["get_db"].return_value = make_database() + + # Create a test file and metadata + (tmp_path / "test.txt").write_text("Test content") + metadata = {"test.txt": {"size": 12, "time_modified": "2024-01-01T00:00:00Z"}} + (tmp_path / ".file_metadata.json").write_text(json.dumps(metadata)) + + request = DatabaseVectorStorage(model="text-embedding-3", chunk_size=1000, chunk_overlap=200) + + result = await embed.split_embed(request=request, rate_limit=0, client="test_client") + + assert result.status_code == 200 + # Verify load_and_split_documents was called with file_metadata + call_kwargs = mocks["load_split"].call_args.kwargs + assert call_kwargs.get("file_metadata") == metadata + + @pytest.mark.asyncio + async def test_split_embed_handles_corrupt_metadata( + self, split_embed_mocks, tmp_path, make_oci_config, make_database + ): + """split_embed should handle corrupt metadata file gracefully.""" + mocks = split_embed_mocks + mocks["oci_get"].return_value = make_oci_config() + mocks["get_temp"].return_value = tmp_path + mocks["load_split"].return_value = (["doc1"], None) + mocks["get_embed"].return_value = MagicMock() + mocks["get_vs_table"].return_value = ("VS_TEST", "test_alias") + mocks["populate"].return_value = None + mocks["get_db"].return_value = make_database() + + # Create a test file and corrupt metadata + (tmp_path / "test.txt").write_text("Test content") + (tmp_path / ".file_metadata.json").write_text("{ invalid json }") + + request = DatabaseVectorStorage(model="text-embedding-3", chunk_size=1000, chunk_overlap=200) + + result = await embed.split_embed(request=request, rate_limit=0, client="test_client") + + # Should still succeed, falling back to None for metadata + assert result.status_code == 200 + call_kwargs = mocks["load_split"].call_args.kwargs + assert call_kwargs.get("file_metadata") is None + class TestRefreshVectorStore: """Tests for the refresh_vector_store endpoint.""" @@ -519,6 +621,137 @@ async def test_refresh_vector_store_raises_500_on_db_exception( assert exc_info.value.status_code == 500 + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_oci.get") + @patch("server.api.v1.embed.utils_databases.get_client_database") + @patch("server.api.v1.embed.utils_embed.get_vector_store_by_alias") + @patch("server.api.v1.embed.utils_oci.get_bucket_objects_with_metadata") + @patch("server.api.v1.embed.utils_embed.get_processed_objects_metadata") + @patch("server.api.v1.embed.utils_oci.detect_changed_objects") + @patch("server.api.v1.embed.utils_embed.get_total_chunks_count") + async def test_refresh_vector_store_no_changes( + self, + mock_get_chunks, + mock_detect_changed, + mock_get_processed, + mock_get_objects, + mock_get_vs, + mock_get_db, + mock_oci_get, + make_oci_config, + make_database, + make_vector_store, + ): + """refresh_vector_store should return success when no changes detected.""" + mock_oci_get.return_value = make_oci_config() + mock_get_db.return_value = make_database() + mock_get_vs.return_value = make_vector_store() + mock_get_objects.return_value = [{"name": "file.pdf", "etag": "abc123"}] + mock_get_processed.return_value = {"file.pdf": {"etag": "abc123"}} + mock_detect_changed.return_value = ([], []) # No new, no modified + mock_get_chunks.return_value = 100 + + request = 
VectorStoreRefreshRequest(vector_store_alias="test_alias", bucket_name="test-bucket") + + result = await embed.refresh_vector_store(request=request, client="test_client") + + assert result.status_code == 200 + body = json.loads(result.body) + assert body["message"] == "No new or modified files to process" + assert body["total_chunks_in_store"] == 100 + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_oci.get") + @patch("server.api.v1.embed.utils_databases.get_client_database") + @patch("server.api.v1.embed.utils_embed.get_vector_store_by_alias") + @patch("server.api.v1.embed.utils_oci.get_bucket_objects_with_metadata") + @patch("server.api.v1.embed.utils_embed.get_processed_objects_metadata") + @patch("server.api.v1.embed.utils_oci.detect_changed_objects") + @patch("server.api.v1.embed.utils_models.get_client_embed") + @patch("server.api.v1.embed.utils_embed.refresh_vector_store_from_bucket") + @patch("server.api.v1.embed.utils_embed.get_total_chunks_count") + async def test_refresh_vector_store_with_changes( + self, + mock_get_chunks, + mock_refresh, + mock_get_embed, + mock_detect_changed, + mock_get_processed, + mock_get_objects, + mock_get_vs, + mock_get_db, + mock_oci_get, + make_oci_config, + make_database, + make_vector_store, + ): + """refresh_vector_store should process changed files.""" + mock_oci_get.return_value = make_oci_config() + mock_get_db.return_value = make_database() + mock_get_vs.return_value = make_vector_store(model="text-embedding-3-small") + mock_get_objects.return_value = [ + {"name": "new_file.pdf", "etag": "new123"}, + {"name": "modified.pdf", "etag": "mod456"}, + ] + mock_get_processed.return_value = {"modified.pdf": {"etag": "old_etag"}} + mock_detect_changed.return_value = ( + [{"name": "new_file.pdf", "etag": "new123"}], # new + [{"name": "modified.pdf", "etag": "mod456"}], # modified + ) + mock_get_embed.return_value = MagicMock() + mock_refresh.return_value = {"message": "Processed 2 files", "processed_files": 2, "total_chunks": 50} + mock_get_chunks.return_value = 150 + + request = VectorStoreRefreshRequest(vector_store_alias="test_alias", bucket_name="test-bucket") + + result = await embed.refresh_vector_store(request=request, client="test_client") + + assert result.status_code == 200 + body = json.loads(result.body) + assert body["status"] == "completed" + assert body["new_files"] == 1 + assert body["updated_files"] == 1 + assert body["total_chunks_in_store"] == 150 + mock_refresh.assert_called_once() + + @pytest.mark.asyncio + @patch("server.api.v1.embed.utils_oci.get") + @patch("server.api.v1.embed.utils_databases.get_client_database") + @patch("server.api.v1.embed.utils_embed.get_vector_store_by_alias") + @patch("server.api.v1.embed.utils_oci.get_bucket_objects_with_metadata") + @patch("server.api.v1.embed.utils_embed.get_processed_objects_metadata") + @patch("server.api.v1.embed.utils_oci.detect_changed_objects") + @patch("server.api.v1.embed.utils_models.get_client_embed") + async def test_refresh_vector_store_raises_500_on_generic_exception( + self, + mock_get_embed, + mock_detect_changed, + mock_get_processed, + mock_get_objects, + mock_get_vs, + mock_get_db, + mock_oci_get, + make_oci_config, + make_database, + make_vector_store, + ): + """refresh_vector_store should raise 500 on generic Exception.""" + mock_oci_get.return_value = make_oci_config() + mock_get_db.return_value = make_database() + mock_get_vs.return_value = make_vector_store() + mock_get_objects.return_value = [{"name": "file.pdf", "etag": "abc123"}] + 
mock_get_processed.return_value = {} + mock_detect_changed.return_value = ([{"name": "file.pdf"}], []) + mock_get_embed.side_effect = RuntimeError("Embedding service unavailable") + + request = VectorStoreRefreshRequest(vector_store_alias="test_alias", bucket_name="test-bucket") + + with pytest.raises(HTTPException) as exc_info: + await embed.refresh_vector_store(request=request, client="test_client") + + assert exc_info.value.status_code == 500 + assert "Embedding service unavailable" in exc_info.value.detail + class TestRouterConfiguration: """Tests for router configuration.""" From 4dd79ab47cf14567c7c0bf51cef9f785cbb8d5ba Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 23:21:29 +0000 Subject: [PATCH 03/20] sqlcl branch updates --- src/client/content/config/tabs/mcp.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/client/content/config/tabs/mcp.py b/src/client/content/config/tabs/mcp.py index 32760515..484de2b9 100644 --- a/src/client/content/config/tabs/mcp.py +++ b/src/client/content/config/tabs/mcp.py @@ -123,19 +123,17 @@ def render_configs(mcp_server: str, mcp_type: str, configs: list) -> None: col1.markdown("Name", unsafe_allow_html=True) col2.markdown("​") for mcp_name in configs: - # The key prefix is to give each widget a unique key in the loop; the key itself is never used - key_prefix = f"{mcp_server}_{mcp_type}_{mcp_name}" col1.text_input( "Name", value=mcp_name, label_visibility="collapsed", disabled=True, - key=f"{key_prefix}_name", + key=f"{mcp_server}_{mcp_type}_{mcp_name}_input", ) col2.button( "Details", on_click=mcp_details, - key=f"{key_prefix}_details", + key=f"{mcp_server}_{mcp_type}_{mcp_name}_details", kwargs={"mcp_server": mcp_server, "mcp_type": mcp_type, "mcp_name": mcp_name}, ) From 30aa8e5b303f7e13d321ec427b545988ae66f995 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 23:22:13 +0000 Subject: [PATCH 04/20] split_embed --- src/client/content/tools/tabs/split_embed.py | 38 ++++++++++++++++---- 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/src/client/content/tools/tabs/split_embed.py b/src/client/content/tools/tabs/split_embed.py index ff44f0ee..b7e3f6b5 100644 --- a/src/client/content/tools/tabs/split_embed.py +++ b/src/client/content/tools/tabs/split_embed.py @@ -123,23 +123,49 @@ def files_data_editor(files, key): def update_chunk_overlap_slider() -> None: - """Keep text and slider input aligned""" - state.selected_chunk_overlap_slider = state.selected_chunk_overlap_input + """Keep text and slider input aligned and ensure overlap doesn't exceed chunk size""" + new_overlap = state.selected_chunk_overlap_input + # Ensure overlap doesn't exceed chunk size + if hasattr(state, "selected_chunk_size_slider"): + chunk_size = state.selected_chunk_size_slider + if new_overlap >= chunk_size: + new_overlap = max(0, chunk_size - 1) + state.selected_chunk_overlap_input = new_overlap + state.selected_chunk_overlap_slider = new_overlap def update_chunk_overlap_input() -> None: - """Keep text and slider input aligned""" - state.selected_chunk_overlap_input = state.selected_chunk_overlap_slider + """Keep text and slider input aligned and ensure overlap doesn't exceed chunk size""" + new_overlap = state.selected_chunk_overlap_slider + # Ensure overlap doesn't exceed chunk size + if hasattr(state, "selected_chunk_size_slider"): + chunk_size = state.selected_chunk_size_slider + if new_overlap >= chunk_size: + new_overlap = max(0, chunk_size - 1) + state.selected_chunk_overlap_slider = new_overlap + 
state.selected_chunk_overlap_input = new_overlap def update_chunk_size_slider() -> None: - """Keep text and slider input aligned""" + """Keep text and slider input aligned and adjust overlap if needed""" state.selected_chunk_size_slider = state.selected_chunk_size_input + # If overlap exceeds new chunk size, cap it + if hasattr(state, "selected_chunk_overlap_slider"): + if state.selected_chunk_overlap_slider >= state.selected_chunk_size_slider: + new_overlap = max(0, state.selected_chunk_size_slider - 1) + state.selected_chunk_overlap_slider = new_overlap + state.selected_chunk_overlap_input = new_overlap def update_chunk_size_input() -> None: - """Keep text and slider input aligned""" + """Keep text and slider input aligned and adjust overlap if needed""" state.selected_chunk_size_input = state.selected_chunk_size_slider + # If overlap exceeds new chunk size, cap it + if hasattr(state, "selected_chunk_overlap_input"): + if state.selected_chunk_overlap_input >= state.selected_chunk_size_input: + new_overlap = max(0, state.selected_chunk_size_input - 1) + state.selected_chunk_overlap_input = new_overlap + state.selected_chunk_overlap_slider = new_overlap ############################################################################# From 90f4822a091379958f32d5f0c1b5deedbc2d545e Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 28 Nov 2025 23:51:21 +0000 Subject: [PATCH 05/20] Migrated unit server api tests --- .../server/api/utils/test_utils_settings.py | 48 ++ .../server/unit/api/utils/test_utils_chat.py | 200 --------- .../api/utils/test_utils_databases_crud.py | 248 ----------- .../utils/test_utils_databases_functions.py | 413 ------------------ .../server/unit/api/utils/test_utils_embed.py | 283 ------------ .../unit/api/utils/test_utils_models.py | 241 ---------- tests/server/unit/api/utils/test_utils_oci.py | 337 -------------- .../unit/api/utils/test_utils_oci_refresh.py | 248 ----------- .../unit/api/utils/test_utils_settings.py | 133 ------ .../unit/api/utils/test_utils_testbed.py | 99 ----- 10 files changed, 48 insertions(+), 2202 deletions(-) delete mode 100644 tests/server/unit/api/utils/test_utils_chat.py delete mode 100644 tests/server/unit/api/utils/test_utils_databases_crud.py delete mode 100644 tests/server/unit/api/utils/test_utils_databases_functions.py delete mode 100644 tests/server/unit/api/utils/test_utils_embed.py delete mode 100644 tests/server/unit/api/utils/test_utils_models.py delete mode 100644 tests/server/unit/api/utils/test_utils_oci.py delete mode 100644 tests/server/unit/api/utils/test_utils_oci_refresh.py delete mode 100644 tests/server/unit/api/utils/test_utils_settings.py delete mode 100644 tests/server/unit/api/utils/test_utils_testbed.py diff --git a/test/unit/server/api/utils/test_utils_settings.py b/test/unit/server/api/utils/test_utils_settings.py index 1f84ec41..42299559 100644 --- a/test/unit/server/api/utils/test_utils_settings.py +++ b/test/unit/server/api/utils/test_utils_settings.py @@ -226,6 +226,44 @@ def test_update_server_loads_prompt_configs(self, mock_load_prompts, make_settin mock_load_prompts.assert_called_once_with(config_data) + @patch("server.api.utils.settings.bootstrap") + def test_update_server_mutates_lists_not_replaces(self, mock_bootstrap, make_settings): + """update_server should mutate existing lists rather than replacing them. + + This is critical because other modules import these lists directly + (e.g., `from server.bootstrap.bootstrap import DATABASE_OBJECTS`). 
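+        Mutating in place (for instance `DATABASE_OBJECTS[:] = new_items`, or
+        `clear()` followed by `extend()`) preserves that shared object identity.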
+ If we replace the list, those modules would hold stale references. + """ + original_db_list = [] + original_model_list = [] + original_oci_list = [] + + mock_bootstrap.DATABASE_OBJECTS = original_db_list + mock_bootstrap.MODEL_OBJECTS = original_model_list + mock_bootstrap.OCI_OBJECTS = original_oci_list + + config_data = { + "client_settings": make_settings().model_dump(), + "database_configs": [{"name": "test_db", "user": "user", "password": "pass", "dsn": "dsn"}], + "model_configs": [{"id": "test-model", "provider": "openai", "type": "ll"}], + "oci_configs": [{"auth_profile": "DEFAULT", "compartment_id": "ocid1.compartment.oc1..test"}], + } + + utils_settings.update_server(config_data) + + # Verify the lists are the SAME objects (mutated, not replaced) + assert mock_bootstrap.DATABASE_OBJECTS is original_db_list, "DATABASE_OBJECTS was replaced instead of mutated" + assert mock_bootstrap.MODEL_OBJECTS is original_model_list, "MODEL_OBJECTS was replaced instead of mutated" + assert mock_bootstrap.OCI_OBJECTS is original_oci_list, "OCI_OBJECTS was replaced instead of mutated" + + # Verify the lists now contain the new data + assert len(original_db_list) == 1 + assert original_db_list[0].name == "test_db" + assert len(original_model_list) == 1 + assert original_model_list[0].id == "test-model" + assert len(original_oci_list) == 1 + assert original_oci_list[0].auth_profile == "DEFAULT" + class TestLoadPromptOverride: # pylint: disable=protected-access """Tests for the _load_prompt_override function.""" @@ -250,6 +288,16 @@ def test_load_prompt_override_without_text(self, mock_set_override): assert result is False mock_set_override.assert_not_called() + @patch("server.api.utils.settings.cache.set_override") + def test_load_prompt_override_with_empty_text(self, mock_set_override): + """_load_prompt_override should return False when text is empty string.""" + prompt = {"name": "test_prompt", "text": ""} + + result = utils_settings._load_prompt_override(prompt) + + assert result is False + mock_set_override.assert_not_called() + class TestLoadPromptConfigs: # pylint: disable=protected-access """Tests for the _load_prompt_configs function.""" diff --git a/tests/server/unit/api/utils/test_utils_chat.py b/tests/server/unit/api/utils/test_utils_chat.py deleted file mode 100644 index b614a00e..00000000 --- a/tests/server/unit/api/utils/test_utils_chat.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
-""" -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel - -from unittest.mock import patch, MagicMock -import pytest - -from langchain_core.messages import ChatMessage - -from server.api.utils import chat -from common.schema import ( - ChatRequest, - Settings, - LargeLanguageSettings, - VectorSearchSettings, - OciSettings, -) - - -class TestChatUtils: - """Test chat utility functions""" - - @pytest.fixture - def sample_message(self): - """Sample chat message fixture""" - return ChatMessage(role="user", content="Hello, how are you?") - - @pytest.fixture - def sample_request(self, sample_message): - """Sample chat request fixture""" - return ChatRequest(messages=[sample_message], model="openai/gpt-4") - - @pytest.fixture - def sample_client_settings(self): - """Sample client settings fixture""" - return Settings( - client="test_client", - ll_model=LargeLanguageSettings(model="openai/gpt-4", chat_history=True, temperature=0.7, max_tokens=4096), - vector_search=VectorSearchSettings(), - oci=OciSettings(auth_profile="DEFAULT"), - ) - - @patch("server.api.utils.settings.get_client") - @patch("server.api.utils.oci.get") - @patch("server.api.utils.models.get_litellm_config") - @patch("server.agents.chatbot.chatbot_graph.astream") - @pytest.mark.asyncio - async def test_completion_generator_success( - self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client, - sample_request, sample_client_settings - ): - """Test successful completion generation""" - # Setup mocks - mock_get_client.return_value = sample_client_settings - mock_get_oci.return_value = MagicMock() - mock_get_litellm_config.return_value = {"model": "gpt-4", "temperature": 0.7} - - # Mock the async generator - this should only yield the final completion for "completions" mode - async def mock_generator(): - yield {"stream": "Hello"} - yield {"stream": " there"} - yield {"completion": "Hello there"} - - mock_astream.return_value = mock_generator() - - # Test the function - results = [] - async for result in chat.completion_generator("test_client", sample_request, "completions"): - results.append(result) - - # Verify results - for "completions" mode, we get stream chunks + final completion - assert len(results) == 3 - assert results[0] == b"Hello" # Stream chunks are encoded as bytes - assert results[1] == b" there" - assert results[2] == "Hello there" # Final completion is a string - mock_get_client.assert_called_once_with("test_client") - mock_get_oci.assert_called_once_with(client="test_client") - - @patch("server.api.utils.settings.get_client") - @patch("server.api.utils.oci.get") - @patch("server.api.utils.models.get_litellm_config") - @patch("server.agents.chatbot.chatbot_graph.astream") - @pytest.mark.asyncio - async def test_completion_generator_streaming( - self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client, - sample_request, sample_client_settings - ): - """Test streaming completion generation""" - # Setup mocks - mock_get_client.return_value = sample_client_settings - mock_get_oci.return_value = MagicMock() - mock_get_litellm_config.return_value = {"model": "gpt-4", "temperature": 0.7} - - # Mock the async generator - async def mock_generator(): - yield {"stream": "Hello"} - yield {"stream": " there"} - yield {"completion": "Hello there"} - - mock_astream.return_value = mock_generator() - - # Test the function - results = [] - async for result in chat.completion_generator("test_client", sample_request, "streams"): - results.append(result) - 
- # Verify results - should include encoded stream chunks and finish marker - assert len(results) == 3 - assert results[0] == b"Hello" - assert results[1] == b" there" - assert results[2] == "[stream_finished]" - - @patch("server.api.utils.settings.get_client") - @patch("server.api.utils.oci.get") - @patch("server.api.utils.models.get_litellm_config") - @patch("server.api.utils.databases.get_client_database") - @patch("server.api.utils.models.get_client_embed") - @patch("server.agents.chatbot.chatbot_graph.astream") - @pytest.mark.asyncio - async def test_completion_generator_with_vector_search( - self, - mock_astream, - mock_get_client_embed, - mock_get_client_database, - mock_get_litellm_config, - mock_get_oci, - mock_get_client, - sample_request, - sample_client_settings, - ): - """Test completion generation with vector search enabled""" - # Setup settings with vector search enabled via tools_enabled - vector_search_settings = sample_client_settings.model_copy() - vector_search_settings.tools_enabled = ["Vector Search"] - - # Setup mocks - mock_get_client.return_value = vector_search_settings - mock_get_oci.return_value = MagicMock() - mock_get_litellm_config.return_value = {"model": "gpt-4", "temperature": 0.7} - - mock_db = MagicMock() - mock_db.connection = MagicMock() - mock_get_client_database.return_value = mock_db - mock_get_client_embed.return_value = MagicMock() - - # Mock the async generator - async def mock_generator(): - yield {"completion": "Response with vector search"} - - mock_astream.return_value = mock_generator() - - # Test the function - results = [] - async for result in chat.completion_generator("test_client", sample_request, "completions"): - results.append(result) - - # Verify vector search setup - mock_get_client_database.assert_called_once_with("test_client", False) - mock_get_client_embed.assert_called_once() - assert len(results) == 1 - - @patch("server.api.utils.settings.get_client") - @patch("server.api.utils.oci.get") - @patch("server.api.utils.models.get_litellm_config") - @patch("server.agents.chatbot.chatbot_graph.astream") - @pytest.mark.asyncio - async def test_completion_generator_no_model_specified( - self, mock_astream, mock_get_litellm_config, mock_get_oci, mock_get_client, - sample_message, sample_client_settings - ): - """Test completion generation when no model is specified in request""" - # Create request without model - request_no_model = ChatRequest(messages=[sample_message], model=None) - - # Setup mocks - mock_get_client.return_value = sample_client_settings - mock_get_oci.return_value = MagicMock() - mock_get_litellm_config.return_value = {"model": "gpt-4", "temperature": 0.7} - - # Mock the async generator - async def mock_generator(): - yield {"completion": "Response using default model"} - - mock_astream.return_value = mock_generator() - - # Test the function - results = [] - async for result in chat.completion_generator("test_client", request_no_model, "completions"): - results.append(result) - - # Should use model from client settings - assert len(results) == 1 - - def test_logger_exists(self): - """Test that logger is properly configured""" - assert hasattr(chat, "logger") - assert chat.logger.name == "api.utils.chat" diff --git a/tests/server/unit/api/utils/test_utils_databases_crud.py b/tests/server/unit/api/utils/test_utils_databases_crud.py deleted file mode 100644 index 62c06ef0..00000000 --- a/tests/server/unit/api/utils/test_utils_databases_crud.py +++ /dev/null @@ -1,248 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or 
its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel - -from unittest.mock import patch, MagicMock - -import pytest - -from server.api.utils import databases -from server.api.utils.databases import DbException -from common.schema import Database - - -class TestDatabases: - """Test databases module functionality""" - - sample_database: Database - sample_database_2: Database - - def setup_method(self): - """Setup test data before each test""" - self.sample_database = Database(name="test_db", user="test_user", password="test_password", dsn="test_dsn") - self.sample_database_2 = Database( - name="test_db_2", user="test_user_2", password="test_password_2", dsn="test_dsn_2" - ) - - # test_get_all: See test/unit/server/api/utils/test_utils_databases.py::TestGet::test_get_all_databases - # test_get_by_name_found: See test/unit/server/api/utils/test_utils_databases.py::TestGet::test_get_specific_database - # test_get_by_name_not_found: See test/unit/server/api/utils/test_utils_databases.py::TestGet::test_get_raises_unknown_error - - @patch("server.api.utils.databases.DATABASE_OBJECTS") - def test_get_empty_list(self, mock_database_objects): - """Test getting databases when list is empty""" - mock_database_objects.__iter__ = MagicMock(return_value=iter([])) - mock_database_objects.__len__ = MagicMock(return_value=0) - - result = databases.get() - - assert result == [] - - @patch("server.api.utils.databases.DATABASE_OBJECTS") - def test_get_empty_list_with_name(self, mock_database_objects): - """Test getting database by name when list is empty""" - mock_database_objects.__iter__ = MagicMock(return_value=iter([])) - mock_database_objects.__len__ = MagicMock(return_value=0) - - with pytest.raises(ValueError, match="test_db not found"): - databases.get(name="test_db") - - # test_create_success: See test/unit/server/api/utils/test_utils_databases.py::TestCreate::test_create_success - # test_create_already_exists: See test/unit/server/api/utils/test_utils_databases.py::TestCreate::test_create_raises_exists_error - # test_create_missing_user: See test/unit/server/api/utils/test_utils_databases.py::TestCreate::test_create_raises_value_error_missing_fields - - def test_create_missing_password(self, db_container, db_objects_manager): - """Test database creation with missing password field""" - assert db_container is not None - assert db_objects_manager is not None - databases.DATABASE_OBJECTS.clear() - - # Create database with missing password - incomplete_db = Database(name="incomplete_db", user="test_user", dsn="test_dsn") - - with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): - databases.create(incomplete_db) - - def test_create_missing_dsn(self, db_container, db_objects_manager): - """Test database creation with missing dsn field""" - assert db_container is not None - assert db_objects_manager is not None - databases.DATABASE_OBJECTS.clear() - - # Create database with missing dsn - incomplete_db = Database(name="incomplete_db", user="test_user", password="test_password") - - with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): - databases.create(incomplete_db) - - def test_create_multiple_missing_fields(self, db_container, db_objects_manager): - """Test database creation with multiple missing required fields""" - assert db_container is not None - assert db_objects_manager is 
not None - databases.DATABASE_OBJECTS.clear() - - # Create database with multiple missing fields - incomplete_db = Database(name="incomplete_db") - - with pytest.raises(ValueError, match="'user', 'password', and 'dsn' are required"): - databases.create(incomplete_db) - - # test_delete: See test/unit/server/api/utils/test_utils_databases.py::TestDelete::test_delete_removes_database - - def test_delete_nonexistent(self, db_container, db_objects_manager): - """Test deleting non-existent database""" - assert db_container is not None - assert db_objects_manager is not None - - # Setup test data - db1 = Database(name="test_db_1", user="user1", password="pass1", dsn="dsn1") - databases.DATABASE_OBJECTS.clear() - databases.DATABASE_OBJECTS.append(db1) - - original_length = len(databases.DATABASE_OBJECTS) - - # Try to delete non-existent database (should not raise error) - databases.delete("nonexistent") - - # Verify no change - assert len(databases.DATABASE_OBJECTS) == original_length - assert databases.DATABASE_OBJECTS[0].name == "test_db_1" - - def test_delete_empty_list(self, db_container, db_objects_manager): - """Test deleting from empty database list""" - assert db_container is not None - assert db_objects_manager is not None - databases.DATABASE_OBJECTS.clear() - - # Try to delete from empty list (should not raise error) - databases.delete("any_name") - - # Verify still empty - assert len(databases.DATABASE_OBJECTS) == 0 - - def test_delete_multiple_same_name(self, db_container, db_objects_manager): - """Test deleting when multiple databases have the same name""" - assert db_container is not None - assert db_objects_manager is not None - # Setup test data with duplicate names - db1 = Database(name="duplicate", user="user1", password="pass1", dsn="dsn1") - db2 = Database(name="duplicate", user="user2", password="pass2", dsn="dsn2") - db3 = Database(name="other", user="user3", password="pass3", dsn="dsn3") - - databases.DATABASE_OBJECTS.clear() - databases.DATABASE_OBJECTS.extend([db1, db2, db3]) - - # Delete databases with duplicate name - databases.delete("duplicate") - - # Verify all duplicates are removed - assert len(databases.DATABASE_OBJECTS) == 1 - assert databases.DATABASE_OBJECTS[0].name == "other" - - # test_logger_exists: See test/unit/server/api/utils/test_utils_databases.py::TestLoggerConfiguration::test_logger_exists - - def test_get_filters_correctly(self, db_container, db_objects_manager): - """Test that get correctly filters by name""" - assert db_container is not None - assert db_objects_manager is not None - # Setup test data - db1 = Database(name="alpha", user="user1", password="pass1", dsn="dsn1") - db2 = Database(name="beta", user="user2", password="pass2", dsn="dsn2") - db3 = Database(name="alpha", user="user3", password="pass3", dsn="dsn3") # Duplicate name - - databases.DATABASE_OBJECTS.clear() - databases.DATABASE_OBJECTS.extend([db1, db2, db3]) - - # Test getting all - all_dbs = databases.get() - assert len(all_dbs) == 3 - - # Test getting by specific name - alpha_dbs = databases.get(name="alpha") - assert len(alpha_dbs) == 2 - assert all(db.name == "alpha" for db in alpha_dbs) - - beta_dbs = databases.get(name="beta") - assert len(beta_dbs) == 1 - assert beta_dbs[0].name == "beta" - - def test_database_model_validation(self, db_container): - """Test Database model validation and optional fields""" - assert db_container is not None - # Test with all required fields - complete_db = Database(name="complete", user="test_user", password="test_password", dsn="test_dsn") 
- assert complete_db.name == "complete" - assert complete_db.user == "test_user" - assert complete_db.password == "test_password" - assert complete_db.dsn == "test_dsn" - assert complete_db.connected is False # Default value - assert complete_db.tcp_connect_timeout == 5 # Default value - assert complete_db.vector_stores == [] # Default value - - # Test with optional fields - complete_db_with_options = Database( - name="complete_with_options", - user="test_user", - password="test_password", - dsn="test_dsn", - wallet_location="/path/to/wallet", - wallet_password="wallet_pass", - tcp_connect_timeout=10, - ) - assert complete_db_with_options.wallet_location == "/path/to/wallet" - assert complete_db_with_options.wallet_password == "wallet_pass" - assert complete_db_with_options.tcp_connect_timeout == 10 - - def test_create_real_scenario(self, db_container, db_objects_manager): - """Test create with realistic data using container DB""" - assert db_container is not None - assert db_objects_manager is not None - databases.DATABASE_OBJECTS.clear() - - # Create database with realistic configuration - test_db = Database( - name="container_test", - user="PYTEST", - password="OrA_41_3xPl0d3r", - dsn="//localhost:1525/FREEPDB1", - tcp_connect_timeout=10, - ) - - result = databases.create(test_db) - - # Verify creation - assert len(databases.DATABASE_OBJECTS) == 1 - created_db = databases.DATABASE_OBJECTS[0] - assert created_db.name == "container_test" - assert created_db.user == "PYTEST" - assert created_db.dsn == "//localhost:1525/FREEPDB1" - assert created_db.tcp_connect_timeout == 10 - assert result == [test_db] - - -class TestDbException: - """Test custom database exception class""" - - # test_db_exception_initialization: See test/unit/server/api/utils/test_utils_databases.py::TestDbException::test_db_exception_init - - def test_db_exception_inheritance(self): - """Test DbException inherits from Exception""" - exc = DbException(status_code=404, detail="Not found") - assert isinstance(exc, Exception) - - def test_db_exception_different_status_codes(self): - """Test DbException with different status codes""" - test_cases = [ - (400, "Bad request"), - (401, "Unauthorized"), - (403, "Forbidden"), - (503, "Service unavailable"), - ] - - for status_code, detail in test_cases: - exc = DbException(status_code=status_code, detail=detail) - assert exc.status_code == status_code - assert exc.detail == detail diff --git a/tests/server/unit/api/utils/test_utils_databases_functions.py b/tests/server/unit/api/utils/test_utils_databases_functions.py deleted file mode 100644 index a4ac6b74..00000000 --- a/tests/server/unit/api/utils/test_utils_databases_functions.py +++ /dev/null @@ -1,413 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
-""" -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel - -import json -from unittest.mock import patch, MagicMock - -import pytest -import oracledb -from conftest import TEST_CONFIG - -from server.api.utils import databases -from server.api.utils.databases import DbException -from common.schema import Database - - -class TestDatabaseUtilsPrivateFunctions: - """Test private utility functions""" - - sample_database: Database - - def setup_method(self): - """Setup test data""" - self.sample_database = Database( - name="test_db", - user=TEST_CONFIG["db_username"], - password=TEST_CONFIG["db_password"], - dsn=TEST_CONFIG["db_dsn"], - ) - - # test_test_function_success: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_connection_active - # test_test_function_reconnect: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_connection_refreshes_on_database_error - # test_test_function_value_error: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_raises_db_exception_on_value_error - # test_test_function_permission_error: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_raises_db_exception_on_permission_error - # test_test_function_connection_error: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_raises_db_exception_on_connection_error - # test_test_function_generic_exception: See test/unit/server/api/utils/test_utils_databases.py::TestTestConnection::test_test_raises_db_exception_on_generic_exception - # test_get_vs_with_real_database: See test/unit/server/api/utils/test_utils_databases.py::TestGetVs::test_get_vs_returns_list - - @patch("server.api.utils.databases.execute_sql") - def test_get_vs_with_mock_data(self, mock_execute_sql): - """Test vector storage retrieval with mocked data""" - mock_connection = MagicMock() - mock_execute_sql.return_value = [ - ( - "TEST_TABLE", - '{"alias": "test_alias", "model": "test_model", "chunk_size": 1000, "distance_metric": "COSINE"}', - ), - ( - "ANOTHER_TABLE", - '{"alias": "another_alias", "model": "another_model", ' - '"chunk_size": 500, "distance_metric": "EUCLIDEAN_DISTANCE"}', - ), - ] - - result = databases._get_vs(mock_connection) - - assert len(result) == 2 - assert result[0].vector_store == "TEST_TABLE" - assert result[0].alias == "test_alias" - assert result[0].model == "test_model" - assert result[0].chunk_size == 1000 - assert result[0].distance_metric == "COSINE" - - assert result[1].vector_store == "ANOTHER_TABLE" - assert result[1].alias == "another_alias" - assert result[1].distance_metric == "EUCLIDEAN_DISTANCE" - - @patch("server.api.utils.databases.execute_sql") - def test_get_vs_empty_result(self, mock_execute_sql): - """Test vector storage retrieval with empty results""" - mock_connection = MagicMock() - mock_execute_sql.return_value = [] - - result = databases._get_vs(mock_connection) - - assert isinstance(result, list) - assert len(result) == 0 - - @patch("server.api.utils.databases.execute_sql") - def test_get_vs_malformed_json(self, mock_execute_sql): - """Test vector storage retrieval with malformed JSON""" - mock_connection = MagicMock() - mock_execute_sql.return_value = [ - ("TEST_TABLE", '{"invalid_json": }'), - ] - - with pytest.raises(json.JSONDecodeError): - databases._get_vs(mock_connection) - -class TestDatabaseUtilsPublicFunctions: - """Test public utility functions - connection and execution""" - 
- sample_database: Database - - def setup_method(self): - """Setup test data""" - self.sample_database = Database( - name="test_db", - user=TEST_CONFIG["db_username"], - password=TEST_CONFIG["db_password"], - dsn=TEST_CONFIG["db_dsn"], - ) - - # test_connect_success_with_real_database: See test/unit/server/api/utils/test_utils_databases.py::TestConnect::test_connect_success_real_db - # test_connect_missing_user: See test/unit/server/api/utils/test_utils_databases.py::TestConnect::test_connect_raises_value_error_missing_details - # test_connect_missing_password: See test/unit/server/api/utils/test_utils_databases.py::TestConnect::test_connect_raises_value_error_missing_details - # test_connect_missing_dsn: See test/unit/server/api/utils/test_utils_databases.py::TestConnect::test_connect_raises_value_error_missing_details - - def test_connect_with_wallet_configuration(self, db_container): - """Test connection with wallet configuration""" - assert db_container is not None - db_with_wallet = Database( - name="test_db", - user=TEST_CONFIG["db_username"], - password=TEST_CONFIG["db_password"], - dsn=TEST_CONFIG["db_dsn"], - wallet_password="wallet_pass", - config_dir="/path/to/config", - ) - - # This should attempt to connect but may fail due to wallet config - # The test verifies the code path works, not necessarily successful connection - try: - result = databases.connect(db_with_wallet) - databases.disconnect(result) - except oracledb.DatabaseError: - # Expected if wallet doesn't exist - pass - - def test_connect_wallet_password_without_location(self, db_container): - """Test connection with wallet password but no location""" - assert db_container is not None - db_with_wallet = Database( - name="test_db", - user=TEST_CONFIG["db_username"], - password=TEST_CONFIG["db_password"], - dsn=TEST_CONFIG["db_dsn"], - wallet_password="wallet_pass", - config_dir="/default/config", - ) - - # This should set wallet_location to config_dir - try: - result = databases.connect(db_with_wallet) - databases.disconnect(result) - except oracledb.DatabaseError: - # Expected if wallet doesn't exist - pass - - # test_connect_invalid_credentials: See test/unit/server/api/utils/test_utils_databases.py::TestConnect::test_connect_raises_permission_error_invalid_credentials - - def test_connect_invalid_dsn(self, db_container): - """Test connection with invalid DSN""" - assert db_container is not None - invalid_db = Database( - name="test_db", - user=TEST_CONFIG["db_username"], - password=TEST_CONFIG["db_password"], - dsn="//invalid:1521/INVALID", - ) - - # This will raise socket.gaierror which is wrapped in oracledb.DatabaseError - with pytest.raises(Exception): # Catch any exception - DNS resolution errors vary by environment - databases.connect(invalid_db) - - # test_disconnect_success: See test/unit/server/api/utils/test_utils_databases.py::TestDisconnect::test_disconnect_closes_connection - # test_execute_sql_success_with_real_database: See test/unit/server/api/utils/test_utils_databases.py::TestExecuteSql::test_execute_sql_returns_rows - # test_execute_sql_with_binds: See test/unit/server/api/utils/test_utils_databases.py::TestExecuteSql::test_execute_sql_with_binds - - def test_execute_sql_no_rows(self, db_container): - """Test SQL execution that returns no rows""" - assert db_container is not None - conn = databases.connect(self.sample_database) - - try: - # Test query with no results - result = databases.execute_sql(conn, "SELECT 1 FROM DUAL WHERE 1=0") - assert result == [] - finally: - databases.disconnect(conn) 
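
[Editor's sketch, not part of the patch: the deleted tests above all revolve around one connect/execute/disconnect pattern. A minimal Python sketch of that pattern follows, assuming the same `server.api.utils.databases` API the replacement tests exercise; `SKETCH_DB` and the credentials are placeholder values mirroring the centralized fake test constants, not real secrets.]

    from server.api.utils import databases
    from common.schema import Database

    # Placeholder connection details -- fake values, as in the shared test fixtures
    config = Database(name="SKETCH_DB", user="test_user", password="test_password", dsn="localhost:1521/TESTPDB")

    conn = databases.connect(config)  # raises oracledb.DatabaseError if the DB is unreachable
    try:
        # A query matching no rows returns an empty list, as asserted above
        rows = databases.execute_sql(conn, "SELECT 1 FROM DUAL WHERE 1=0")
        assert rows == []
    finally:
        databases.disconnect(conn)  # always release the connection
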
- - def test_execute_sql_ddl_statement(self, db_container): - """Test SQL execution with DDL statement""" - assert db_container is not None - conn = databases.connect(self.sample_database) - - try: - # Create a test table - databases.execute_sql(conn, "CREATE TABLE test_temp (id NUMBER)") - - # Drop the test table - result = databases.execute_sql(conn, "DROP TABLE test_temp") - # DDL statements typically return None - assert result is None - except oracledb.DatabaseError as e: - # If table already exists or other DDL error, that's okay for testing - if "name is already used" not in str(e): - raise - finally: - # Clean up if table still exists - try: - databases.execute_sql(conn, "DROP TABLE test_temp") - except oracledb.DatabaseError: - pass # Table doesn't exist, which is fine - databases.disconnect(conn) - - def test_execute_sql_table_exists_error(self, db_container): - """Test SQL execution with table exists error (ORA-00955)""" - assert db_container is not None - conn = databases.connect(self.sample_database) - - try: - # Create table twice to trigger ORA-00955 - databases.execute_sql(conn, "CREATE TABLE test_exists (id NUMBER)") - - # This should log but not raise an exception - databases.execute_sql(conn, "CREATE TABLE test_exists (id NUMBER)") - - except oracledb.DatabaseError: - # Expected behavior - the function should handle this gracefully - pass - finally: - try: - databases.execute_sql(conn, "DROP TABLE test_exists") - except oracledb.DatabaseError: - pass - databases.disconnect(conn) - - def test_execute_sql_table_not_exists_error(self, db_container): - """Test SQL execution with table not exists error (ORA-00942)""" - assert db_container is not None - conn = databases.connect(self.sample_database) - - try: - # Try to select from non-existent table to trigger ORA-00942 - databases.execute_sql(conn, "SELECT * FROM non_existent_table") - except oracledb.DatabaseError: - # Expected behavior - the function should handle this gracefully - pass - finally: - databases.disconnect(conn) - - # test_execute_sql_invalid_syntax: See test/unit/server/api/utils/test_utils_databases.py::TestExecuteSql::test_execute_sql_raises_on_other_database_error - - def test_drop_vs_function_exists(self): - """Test that drop_vs function exists and is callable""" - assert hasattr(databases, "drop_vs") - assert callable(databases.drop_vs) - - # test_drop_vs_calls_langchain: See test/unit/server/api/utils/test_utils_databases.py::TestDropVs::test_drop_vs_calls_langchain - - -class TestDatabaseUtilsQueryFunctions: - """Test public utility functions - get and client database functions""" - - sample_database: Database - - def setup_method(self): - """Setup test data""" - self.sample_database = Database( - name="test_db", - user=TEST_CONFIG["db_username"], - password=TEST_CONFIG["db_password"], - dsn=TEST_CONFIG["db_dsn"], - ) - - def test_get_without_validation(self, db_container, db_objects_manager): - """Test get without validation""" - assert db_container is not None - assert db_objects_manager is not None - databases.DATABASE_OBJECTS.clear() - databases.DATABASE_OBJECTS.append(self.sample_database) - - # Test getting all databases - result = databases.get() - assert isinstance(result, list) - assert len(result) == 1 - assert result[0].name == "test_db" - assert result[0].connected is False # No validation, so not connected - - def test_get_with_validation(self, db_container, db_objects_manager): - """Test get with validation using real database""" - assert db_container is not None - assert 
db_objects_manager is not None - databases.DATABASE_OBJECTS.clear() - databases.DATABASE_OBJECTS.append(self.sample_database) - - # Test getting all databases with validation - result = databases.get_databases(validate=True) - assert isinstance(result, list) - assert len(result) == 1 - assert result[0].name == "test_db" - assert result[0].connected is True # Validation should connect - assert result[0].connection is not None - - # Clean up connections - for db in databases.DATABASE_OBJECTS: - if db.connection: - databases.disconnect(db.connection) - - def test_get_by_name(self, db_container, db_objects_manager): - """Test get by specific name""" - assert db_container is not None - assert db_objects_manager is not None - databases.DATABASE_OBJECTS.clear() - db1 = Database(name="db1", user="user1", password="pass1", dsn="dsn1") - db2 = Database( - name="db2", user=TEST_CONFIG["db_username"], password=TEST_CONFIG["db_password"], dsn=TEST_CONFIG["db_dsn"] - ) - databases.DATABASE_OBJECTS.extend([db1, db2]) - - # Test getting specific database - result = databases.get_databases(db_name="db2") - assert isinstance(result, Database) # Single database, not list - assert result.name == "db2" - - def test_get_validation_failure(self, db_container, db_objects_manager): - """Test get with validation when connection fails""" - assert db_container is not None - assert db_objects_manager is not None - databases.DATABASE_OBJECTS.clear() - # Add database with invalid credentials - invalid_db = Database(name="invalid", user="invalid", password="invalid", dsn="invalid") - databases.DATABASE_OBJECTS.append(invalid_db) - - # Test validation with invalid database (should continue without error) - result = databases.get_databases(validate=True) - assert isinstance(result, list) - assert len(result) == 1 - assert result[0].connected is False # Should remain False due to connection failure - - @patch("server.api.utils.settings.get_client") - def test_get_client_database_default(self, mock_get_settings, db_container, db_objects_manager): - """Test get_client_database with default settings""" - assert db_container is not None - assert db_objects_manager is not None - # Mock client settings without vector_search - mock_settings = MagicMock() - mock_settings.vector_search = None - mock_get_settings.return_value = mock_settings - - databases.DATABASE_OBJECTS.clear() - default_db = Database( - name="DEFAULT", - user=TEST_CONFIG["db_username"], - password=TEST_CONFIG["db_password"], - dsn=TEST_CONFIG["db_dsn"], - ) - databases.DATABASE_OBJECTS.append(default_db) - - result = databases.get_client_database("test_client") - assert isinstance(result, Database) - assert result.name == "DEFAULT" - - @patch("server.api.utils.settings.get_client") - def test_get_client_database_with_vector_search(self, mock_get_settings, db_container, db_objects_manager): - """Test get_client_database with vector_search settings""" - assert db_container is not None - assert db_objects_manager is not None - # Mock client settings with vector_search - mock_vector_search = MagicMock() - mock_vector_search.database = "VECTOR_DB" - mock_settings = MagicMock() - mock_settings.vector_search = mock_vector_search - mock_get_settings.return_value = mock_settings - - databases.DATABASE_OBJECTS.clear() - vector_db = Database( - name="VECTOR_DB", - user=TEST_CONFIG["db_username"], - password=TEST_CONFIG["db_password"], - dsn=TEST_CONFIG["db_dsn"], - ) - databases.DATABASE_OBJECTS.append(vector_db) - - result = databases.get_client_database("test_client") - 
assert isinstance(result, Database) - assert result.name == "VECTOR_DB" - - @patch("server.api.utils.settings.get_client") - def test_get_client_database_with_validation(self, mock_get_settings, db_container, db_objects_manager): - """Test get_client_database with validation enabled""" - assert db_container is not None - assert db_objects_manager is not None - # Mock client settings - mock_settings = MagicMock() - mock_settings.vector_search = None - mock_get_settings.return_value = mock_settings - - databases.DATABASE_OBJECTS.clear() - default_db = Database( - name="DEFAULT", - user=TEST_CONFIG["db_username"], - password=TEST_CONFIG["db_password"], - dsn=TEST_CONFIG["db_dsn"], - ) - databases.DATABASE_OBJECTS.append(default_db) - - result = databases.get_client_database("test_client", validate=True) - assert isinstance(result, Database) - assert result.name == "DEFAULT" - assert result.connected is True - assert result.connection is not None - - # Clean up connections - for db in databases.DATABASE_OBJECTS: - if db.connection: - databases.disconnect(db.connection) - - # test_logger_exists: See test/unit/server/api/utils/test_utils_databases.py::TestLoggerConfiguration::test_logger_exists diff --git a/tests/server/unit/api/utils/test_utils_embed.py b/tests/server/unit/api/utils/test_utils_embed.py deleted file mode 100644 index 13a63538..00000000 --- a/tests/server/unit/api/utils/test_utils_embed.py +++ /dev/null @@ -1,283 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel - -from decimal import Decimal -from pathlib import Path -from unittest.mock import patch, mock_open, MagicMock - -import pytest -from langchain.docstore.document import Document as LangchainDocument - -from server.api.utils import embed -from common.schema import Database - - -class TestEmbedUtils: - """Test embed utility functions""" - - @pytest.fixture - def sample_document(self): - """Sample document fixture""" - return LangchainDocument( - page_content="This is a test document content.", metadata={"source": "/path/to/test_file.txt", "page": 1} - ) - - @pytest.fixture - def sample_split_doc(self): - """Sample split document fixture""" - return LangchainDocument( - page_content="This is a chunk of content.", metadata={"source": "/path/to/test_file.txt", "start_index": 0} - ) - - @patch("pathlib.Path.exists") - @patch("pathlib.Path.is_dir") - @patch("pathlib.Path.mkdir") - def test_get_temp_directory_app_tmp(self, mock_mkdir, mock_is_dir, mock_exists): - """Test temp directory creation in /app/tmp""" - mock_exists.return_value = True - mock_is_dir.return_value = True - - result = embed.get_temp_directory("test_client", "embed") - - assert result == Path("/app/tmp") / "test_client" / "embed" - mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) - - @patch("pathlib.Path.exists") - @patch("pathlib.Path.mkdir") - def test_get_temp_directory_tmp_fallback(self, mock_mkdir, mock_exists): - """Test temp directory creation fallback to /tmp""" - mock_exists.return_value = False - - result = embed.get_temp_directory("test_client", "embed") - - assert result == Path("/tmp") / "test_client" / "embed" - mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) - - @patch("builtins.open", new_callable=mock_open) - @patch("os.path.getsize") - @patch("json.dumps") - def 
test_doc_to_json_default_output(self, mock_json_dumps, mock_getsize, mock_file, sample_document): - """Test document to JSON conversion with default output directory""" - mock_json_dumps.return_value = '{"test": "data"}' - mock_getsize.return_value = 100 - - result = embed.doc_to_json([sample_document], "/path/to/test_file.txt", "/tmp") - - mock_file.assert_called_once() - mock_json_dumps.assert_called_once() - mock_getsize.assert_called_once() - assert result.endswith("_test_file.json") - - @patch("builtins.open", new_callable=mock_open) - @patch("os.path.getsize") - @patch("json.dumps") - def test_doc_to_json_custom_output(self, mock_json_dumps, mock_getsize, mock_file, sample_document): - """Test document to JSON conversion with custom output directory""" - mock_json_dumps.return_value = '{"test": "data"}' - mock_getsize.return_value = 100 - - result = embed.doc_to_json([sample_document], "/path/to/test_file.txt", "/custom/output") - - mock_file.assert_called_once() - mock_json_dumps.assert_called_once() - mock_getsize.assert_called_once() - assert result == "/custom/output/_test_file.json" - - def test_logger_exists(self): - """Test that logger is properly configured""" - assert hasattr(embed, "logger") - assert embed.logger.name == "api.utils.embed" - - -class TestGetVectorStoreFiles: - """Test get_vector_store_files() function""" - - @pytest.fixture - def sample_db(self): - """Sample database fixture""" - return Database( - name="TEST_DB", - user="test_user", - password="", - dsn="localhost:1521/FREEPDB1" - ) - - @patch("server.api.utils.databases.connect") - @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_with_metadata(self, mock_disconnect, mock_connect, sample_db): - """Test retrieving file list with complete metadata""" - # Mock database connection and cursor - mock_conn = MagicMock() - mock_cursor = MagicMock() - mock_conn.cursor.return_value = mock_cursor - mock_connect.return_value = mock_conn - - # Mock query results with metadata - mock_cursor.fetchall.return_value = [ - ({ - "filename": "doc1.pdf", - "size": 1024000, - "time_modified": "2025-11-01T10:00:00", - "etag": "etag-123" - },), - ({ - "filename": "doc1.pdf", - "size": 1024000, - "time_modified": "2025-11-01T10:00:00", - "etag": "etag-123" - },), - ({ - "filename": "doc2.txt", - "size": 2048, - "time_modified": "2025-11-02T10:00:00", - "etag": "etag-456" - },), - ] - - # Execute - result = embed.get_vector_store_files(sample_db, "TEST_VS") - - # Verify - assert result["vector_store"] == "TEST_VS" - assert result["total_files"] == 2 - assert result["total_chunks"] == 3 - assert result["orphaned_chunks"] == 0 - - # Verify files - assert len(result["files"]) == 2 - assert result["files"][0]["filename"] == "doc1.pdf" - assert result["files"][0]["chunk_count"] == 2 - assert result["files"][0]["size"] == 1024000 - assert result["files"][1]["filename"] == "doc2.txt" - assert result["files"][1]["chunk_count"] == 1 - - mock_disconnect.assert_called_once() - - @patch("server.api.utils.databases.connect") - @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_with_decimal_size(self, _mock_disconnect, mock_connect, sample_db): - """Test handling of Decimal size from Oracle NUMBER type""" - # Mock database connection - mock_conn = MagicMock() - mock_cursor = MagicMock() - mock_conn.cursor.return_value = mock_cursor - mock_connect.return_value = mock_conn - - # Mock query results with Decimal size (from Oracle) - mock_cursor.fetchall.return_value = [ - ({ - "filename": 
"doc.pdf", - "size": Decimal("1024000"), # Oracle returns Decimal - "time_modified": "2025-11-01T10:00:00", - "etag": "etag-123" - },), - ] - - # Execute - result = embed.get_vector_store_files(sample_db, "TEST_VS") - - # Verify Decimal was converted to int - assert result["files"][0]["size"] == 1024000 - assert isinstance(result["files"][0]["size"], int) - - @patch("server.api.utils.databases.connect") - @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_old_format(self, _mock_disconnect, mock_connect, sample_db): - """Test retrieving files with old metadata format (source field)""" - # Mock database connection - mock_conn = MagicMock() - mock_cursor = MagicMock() - mock_conn.cursor.return_value = mock_cursor - mock_connect.return_value = mock_conn - - # Mock query results with old format (source instead of filename) - mock_cursor.fetchall.return_value = [ - ({"source": "/path/to/doc1.pdf"},), - ({"source": "/path/to/doc1.pdf"},), - ] - - # Execute - result = embed.get_vector_store_files(sample_db, "TEST_VS") - - # Verify fallback to source field worked - assert result["total_files"] == 1 - assert result["files"][0]["filename"] == "doc1.pdf" - assert result["files"][0]["chunk_count"] == 2 - - @patch("server.api.utils.databases.connect") - @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_with_orphaned_chunks(self, _mock_disconnect, mock_connect, sample_db): - """Test detection of orphaned chunks without valid filename""" - # Mock database connection - mock_conn = MagicMock() - mock_cursor = MagicMock() - mock_conn.cursor.return_value = mock_cursor - mock_connect.return_value = mock_conn - - # Mock query results with some orphaned chunks - mock_cursor.fetchall.return_value = [ - ({"filename": "doc1.pdf", "size": 1024},), - ({"filename": "doc1.pdf", "size": 1024},), - ({"other_field": "no_filename"},), # Orphaned chunk - ({"other_field": "no_source"},), # Orphaned chunk - ] - - # Execute - result = embed.get_vector_store_files(sample_db, "TEST_VS") - - # Verify - assert result["total_files"] == 1 - assert result["total_chunks"] == 2 - assert result["orphaned_chunks"] == 2 - assert result["files"][0]["chunk_count"] == 2 - - @patch("server.api.utils.databases.connect") - @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_empty_store(self, _mock_disconnect, mock_connect, sample_db): - """Test retrieving from empty vector store""" - # Mock database connection - mock_conn = MagicMock() - mock_cursor = MagicMock() - mock_conn.cursor.return_value = mock_cursor - mock_connect.return_value = mock_conn - - # Mock empty results - mock_cursor.fetchall.return_value = [] - - # Execute - result = embed.get_vector_store_files(sample_db, "EMPTY_VS") - - # Verify - assert result["vector_store"] == "EMPTY_VS" - assert result["total_files"] == 0 - assert result["total_chunks"] == 0 - assert result["orphaned_chunks"] == 0 - assert len(result["files"]) == 0 - - @patch("server.api.utils.databases.connect") - @patch("server.api.utils.databases.disconnect") - def test_get_vector_store_files_sorts_by_filename(self, _mock_disconnect, mock_connect, sample_db): - """Test that files are sorted alphabetically by filename""" - # Mock database connection - mock_conn = MagicMock() - mock_cursor = MagicMock() - mock_conn.cursor.return_value = mock_cursor - mock_connect.return_value = mock_conn - - # Mock query results in random order - mock_cursor.fetchall.return_value = [ - ({"filename": "zebra.pdf"},), - ({"filename": 
"apple.txt"},), - ({"filename": "monkey.md"},), - ] - - # Execute - result = embed.get_vector_store_files(sample_db, "TEST_VS") - - # Verify sorted order - filenames = [f["filename"] for f in result["files"]] - assert filenames == ["apple.txt", "monkey.md", "zebra.pdf"] diff --git a/tests/server/unit/api/utils/test_utils_models.py b/tests/server/unit/api/utils/test_utils_models.py deleted file mode 100644 index d7451dde..00000000 --- a/tests/server/unit/api/utils/test_utils_models.py +++ /dev/null @@ -1,241 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel - -from unittest.mock import patch, MagicMock - -import pytest - -from conftest import get_sample_oci_config -from server.api.utils import models -from server.api.utils.models import URLUnreachableError, InvalidModelError, ExistsModelError, UnknownModelError -from common.schema import Model - - -##################################################### -# Exceptions -##################################################### -class TestModelsExceptions: - """Test custom exception classes""" - - # test_url_unreachable_error: See test/unit/server/api/utils/test_utils_models.py::TestExceptions::test_url_unreachable_error_is_value_error - # test_invalid_model_error: See test/unit/server/api/utils/test_utils_models.py::TestExceptions::test_invalid_model_error_is_value_error - # test_exists_model_error: See test/unit/server/api/utils/test_utils_models.py::TestExceptions::test_exists_model_error_is_value_error - # test_unknown_model_error: See test/unit/server/api/utils/test_utils_models.py::TestExceptions::test_unknown_model_error_is_value_error - pass - - -##################################################### -# CRUD Functions -##################################################### -class TestModelsCRUD: - """Test models module functionality""" - - @pytest.fixture - def sample_model(self): - """Sample model fixture""" - return Model( - id="test-model", provider="openai", type="ll", enabled=True, api_base="https://api.openai.com" - ) - - @pytest.fixture - def disabled_model(self): - """Disabled model fixture""" - return Model(id="disabled-model", provider="anthropic", type="ll", enabled=False) - - # test_get_model_all_models: See test/unit/server/api/utils/test_utils_models.py::TestGet::test_get_all_models - - @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_model_by_id_found(self, mock_model_objects, sample_model): - """Test getting model by ID when it exists""" - mock_model_objects.__iter__ = MagicMock(return_value=iter([sample_model])) - mock_model_objects.__len__ = MagicMock(return_value=1) - - (result,) = models.get(model_id="test-model") - - assert result == sample_model - - @patch("server.api.utils.models.MODEL_OBJECTS") - def test_get_model_by_id_not_found(self, mock_model_objects, sample_model): - """Test getting model by ID when it doesn't exist""" - mock_model_objects.__iter__ = MagicMock(return_value=iter([sample_model])) - mock_model_objects.__len__ = MagicMock(return_value=1) - - with pytest.raises(UnknownModelError, match="nonexistent not found"): - models.get(model_id="nonexistent") - - # test_get_model_by_provider: See test/unit/server/api/utils/test_utils_models.py::TestGet::test_get_by_provider - # test_get_model_by_type: See test/unit/server/api/utils/test_utils_models.py::TestGet::test_get_by_type 
- # test_get_model_exclude_disabled: See test/unit/server/api/utils/test_utils_models.py::TestGet::test_get_exclude_disabled - - # test_create_model_success: See test/unit/server/api/utils/test_utils_models.py::TestCreate::test_create_success - # test_create_model_already_exists: See test/unit/server/api/utils/test_utils_models.py::TestCreate::test_create_raises_exists_error - - @patch("server.api.utils.models.MODEL_OBJECTS", []) - @patch("server.api.utils.models.is_url_accessible") - def test_create_model_unreachable_url(self, mock_url_check): - """Test creating model with unreachable URL""" - # Create a model that starts as enabled - test_model = Model( - id="test-model", - provider="openai", - type="ll", - enabled=True, # Start as enabled - api_base="https://api.openai.com", - ) - - mock_url_check.return_value = (False, "Connection failed") - - result = models.create(test_model) - - assert result.enabled is False - - @patch("server.api.utils.models.MODEL_OBJECTS", []) - def test_create_model_skip_url_check(self, sample_model): - """Test creating model without URL check""" - result = models.create(sample_model, check_url=False) - - assert result == sample_model - assert result in models.MODEL_OBJECTS - - # test_delete_model: See test/unit/server/api/utils/test_utils_models.py::TestDelete::test_delete_removes_model - # test_logger_exists: See test/unit/server/api/utils/test_utils_models.py::TestLoggerConfiguration::test_logger_exists - pass - - -##################################################### -# Utility Functions -##################################################### -class TestModelsUtils: - """Test models utility functions""" - - @pytest.fixture - def sample_model(self): - """Sample model fixture""" - return Model( - id="test-model", provider="openai", type="ll", enabled=True, api_base="https://api.openai.com" - ) - - @pytest.fixture - def sample_oci_config(self): - """Sample OCI config fixture""" - return get_sample_oci_config() - - # test_update_success: See test/unit/server/api/utils/test_utils_models.py::TestUpdate::test_update_success - - @patch("server.api.utils.models.MODEL_OBJECTS", []) - @patch("server.api.utils.models.is_url_accessible") - def test_update_embedding_model_max_chunk_size(self, mock_url_check): - """Test updating max_chunk_size for embedding model (regression test for bug)""" - # Create an embedding model with default max_chunk_size - embed_model = Model( - id="test-embed-model", - provider="ollama", - type="embed", - enabled=True, - api_base="http://127.0.0.1:11434", - max_chunk_size=8192, - ) - models.MODEL_OBJECTS.append(embed_model) - mock_url_check.return_value = (True, None) - - # Update the max_chunk_size to 512 - update_payload = Model( - id="test-embed-model", - provider="ollama", - type="embed", - enabled=True, - api_base="http://127.0.0.1:11434", - max_chunk_size=512, - ) - - result = models.update(update_payload) - - # Verify the update was successful - assert result.max_chunk_size == 512 - assert result.id == "test-embed-model" - assert result.provider == "ollama" - - # Verify the model in MODEL_OBJECTS was updated - (updated_model,) = models.get(model_provider="ollama", model_id="test-embed-model") - assert updated_model.max_chunk_size == 512 - - @patch("server.api.utils.models.MODEL_OBJECTS", []) - @patch("server.api.utils.models.is_url_accessible") - def test_update_multiple_fields(self, mock_url_check, sample_model): - """Test updating multiple fields at once""" - # Create a model - models.MODEL_OBJECTS.append(sample_model) - 
mock_url_check.return_value = (True, None) - - # Update multiple fields - update_payload = Model( - id="test-model", - provider="openai", - type="ll", - enabled=False, # Changed from True - api_base="https://api.openai.com/v2", # Changed - temperature=0.5, # Changed - max_tokens=2048, # Changed - ) - - result = models.update(update_payload) - - assert result.enabled is False - assert result.api_base == "https://api.openai.com/v2" - assert result.temperature == 0.5 - assert result.max_tokens == 2048 - - # test_get_full_config_success: See test/unit/server/api/utils/test_utils_models.py::TestGetFullConfig::test_get_full_config_success - # test_get_full_config_unknown_model: See test/unit/server/api/utils/test_utils_models.py::TestGetFullConfig::test_get_full_config_raises_unknown_model - # test_get_litellm_config_basic: See test/unit/server/api/utils/test_utils_models.py::TestGetLitellmConfig::test_get_litellm_config_basic - - @patch("server.api.utils.models._get_full_config") - @patch("litellm.get_supported_openai_params") - def test_get_litellm_config_cohere(self, mock_get_params, mock_get_full_config, sample_oci_config): - """Test LiteLLM config generation for Cohere""" - mock_get_full_config.return_value = ({"api_base": "https://custom.cohere.com/v1"}, "cohere") - mock_get_params.return_value = [] - model_config = {"model": "cohere/command"} - - result = models.get_litellm_config(model_config, sample_oci_config) - - assert result["api_base"] == "https://api.cohere.ai/compatibility/v1" - assert result["model"] == "cohere/command" - - @patch("server.api.utils.models._get_full_config") - @patch("litellm.get_supported_openai_params") - def test_get_litellm_config_xai(self, mock_get_params, mock_get_full_config, sample_oci_config): - """Test LiteLLM config generation for xAI""" - mock_get_full_config.return_value = ( - {"temperature": 0.7, "presence_penalty": 0.1, "frequency_penalty": 0.1}, - "xai", - ) - mock_get_params.return_value = ["temperature", "presence_penalty", "frequency_penalty"] - model_config = {"model": "xai/grok"} - - result = models.get_litellm_config(model_config, sample_oci_config) - - assert result["temperature"] == 0.7 - assert "presence_penalty" not in result - assert "frequency_penalty" not in result - - # test_get_litellm_config_oci: See test/unit/server/api/utils/test_utils_models.py::TestGetLitellmConfig::test_get_litellm_config_oci_provider - - @patch("server.api.utils.models._get_full_config") - @patch("litellm.get_supported_openai_params") - def test_get_litellm_config_giskard(self, mock_get_params, mock_get_full_config, sample_oci_config): - """Test LiteLLM config generation for Giskard""" - mock_get_full_config.return_value = ({"temperature": 0.7, "model": "test-model"}, "openai") - mock_get_params.return_value = ["temperature", "model"] - model_config = {"model": "openai/gpt-4"} - - result = models.get_litellm_config(model_config, sample_oci_config, giskard=True) - - assert "model" not in result - assert "temperature" not in result - - # test_logger_exists: See test/unit/server/api/utils/test_utils_models.py::TestLoggerConfiguration::test_logger_exists diff --git a/tests/server/unit/api/utils/test_utils_oci.py b/tests/server/unit/api/utils/test_utils_oci.py deleted file mode 100644 index 39c0a4f1..00000000 --- a/tests/server/unit/api/utils/test_utils_oci.py +++ /dev/null @@ -1,337 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
-""" -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel - -from unittest.mock import patch, MagicMock - -import pytest -import oci - -from conftest import get_sample_oci_config -from server.api.utils import oci as oci_utils -from server.api.utils.oci import OciException -from common.schema import OracleCloudSettings, Settings, OciSettings - - -class TestOciException: - """Test custom OCI exception class""" - - # test_oci_exception_initialization: See test/unit/server/api/utils/test_utils_oci.py::TestOciException::test_oci_exception_init - - -class TestOciGet: - """Test OCI get() function""" - - @pytest.fixture - def sample_oci_default(self): - """Sample OCI config with DEFAULT profile""" - return OracleCloudSettings( - auth_profile="DEFAULT", compartment_id="ocid1.compartment.oc1..default" - ) - - @pytest.fixture - def sample_oci_custom(self): - """Sample OCI config with CUSTOM profile""" - return OracleCloudSettings( - auth_profile="CUSTOM", compartment_id="ocid1.compartment.oc1..custom" - ) - - @pytest.fixture - def sample_client_settings(self): - """Sample client settings fixture""" - return Settings(client="test_client", oci=OciSettings(auth_profile="CUSTOM")) - - # test_get_no_objects_configured: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_raises_value_error_when_not_configured - # test_get_all: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_returns_all_oci_objects - # test_get_by_auth_profile_found: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_by_auth_profile - # test_get_by_auth_profile_not_found: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_raises_value_error_profile_not_found - - def test_get_by_client_with_oci_settings(self, sample_client_settings, sample_oci_default, sample_oci_custom): - """Test getting OCI settings by client when client has OCI settings""" - from server.bootstrap import bootstrap - - # Save originals - orig_settings = bootstrap.SETTINGS_OBJECTS - orig_oci = bootstrap.OCI_OBJECTS - - try: - # Replace with test data - bootstrap.SETTINGS_OBJECTS = [sample_client_settings] - bootstrap.OCI_OBJECTS = [sample_oci_default, sample_oci_custom] - - result = oci_utils.get(client="test_client") - - assert result == sample_oci_custom - finally: - # Restore originals - bootstrap.SETTINGS_OBJECTS = orig_settings - bootstrap.OCI_OBJECTS = orig_oci - - def test_get_by_client_without_oci_settings(self, sample_oci_default): - """Test getting OCI settings by client when client has no OCI settings""" - from server.bootstrap import bootstrap - - client_settings_no_oci = Settings(client="test_client", oci=None) - - # Save originals - orig_settings = bootstrap.SETTINGS_OBJECTS - orig_oci = bootstrap.OCI_OBJECTS - - try: - # Replace with test data - bootstrap.SETTINGS_OBJECTS = [client_settings_no_oci] - bootstrap.OCI_OBJECTS = [sample_oci_default] - - result = oci_utils.get(client="test_client") - - assert result == sample_oci_default - finally: - # Restore originals - bootstrap.SETTINGS_OBJECTS = orig_settings - bootstrap.OCI_OBJECTS = orig_oci - - # test_get_by_client_not_found: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_raises_value_error_client_not_found - - def test_get_by_client_no_matching_profile(self, sample_client_settings, sample_oci_default): - """Test getting OCI settings by client when no matching profile exists""" - from server.bootstrap import bootstrap - - # Save originals - orig_settings = 
bootstrap.SETTINGS_OBJECTS - orig_oci = bootstrap.OCI_OBJECTS - - try: - # Replace with test data - bootstrap.SETTINGS_OBJECTS = [sample_client_settings] - bootstrap.OCI_OBJECTS = [sample_oci_default] # Only DEFAULT profile - - expected_error = "No settings found for client 'test_client' with auth_profile 'CUSTOM'" - with pytest.raises(ValueError, match=expected_error): - oci_utils.get(client="test_client") - finally: - # Restore originals - bootstrap.SETTINGS_OBJECTS = orig_settings - bootstrap.OCI_OBJECTS = orig_oci - - # test_get_both_client_and_auth_profile: See test/unit/server/api/utils/test_utils_oci.py::TestGet::test_get_raises_value_error_both_params - - -class TestGetSigner: - """Test get_signer() function""" - - # test_get_signer_instance_principal: See test/unit/server/api/utils/test_utils_oci.py::TestGetSigner::test_get_signer_instance_principal - # test_get_signer_oke_workload_identity: See test/unit/server/api/utils/test_utils_oci.py::TestGetSigner::test_get_signer_oke_workload_identity - # test_get_signer_api_key: See test/unit/server/api/utils/test_utils_oci.py::TestGetSigner::test_get_signer_api_key_returns_none - - def test_get_signer_security_token(self): - """Test get_signer with security_token authentication (returns None)""" - config = OracleCloudSettings(auth_profile="DEFAULT", authentication="security_token") - - result = oci_utils.get_signer(config) - - assert result is None - - -class TestInitClient: - """Test init_client() function""" - - @pytest.fixture - def api_key_config(self): - """API key configuration fixture""" - return OracleCloudSettings( - auth_profile="DEFAULT", - authentication="api_key", - region="us-ashburn-1", - user="ocid1.user.oc1..testuser", - fingerprint="test-fingerprint", - tenancy="ocid1.tenancy.oc1..testtenant", - key_file="/path/to/key.pem", - ) - - # test_init_client_api_key: See test/unit/server/api/utils/test_utils_oci.py::TestInitClient::test_init_client_standard_auth - - @patch("oci.generative_ai_inference.GenerativeAiInferenceClient") - @patch.object(oci_utils, "get_signer", return_value=None) - def test_init_client_genai_with_endpoint(self, _mock_get_signer, mock_client_class, api_key_config): - """Test init_client for GenAI sets correct service endpoint""" - genai_config = api_key_config.model_copy() - genai_config.genai_compartment_id = "ocid1.compartment.oc1..test" - genai_config.genai_region = "us-chicago-1" - - mock_client = MagicMock() - mock_client_class.return_value = mock_client - - result = oci_utils.init_client(oci.generative_ai_inference.GenerativeAiInferenceClient, genai_config) - - assert result == mock_client - # Verify service_endpoint was set in kwargs - call_kwargs = mock_client_class.call_args[1] - assert "service_endpoint" in call_kwargs - assert "us-chicago-1" in call_kwargs["service_endpoint"] - - @patch("oci.identity.IdentityClient") - @patch.object(oci_utils, "get_signer") - def test_init_client_with_instance_principal_signer(self, mock_get_signer, mock_client_class): - """Test init_client with instance principal signer""" - instance_config = OracleCloudSettings( - auth_profile="DEFAULT", - authentication="instance_principal", - region="us-ashburn-1", - tenancy=None, # Will be set from signer - ) - - mock_signer = MagicMock() - mock_signer.tenancy_id = "ocid1.tenancy.oc1..test" - mock_get_signer.return_value = mock_signer - - mock_client = MagicMock() - mock_client_class.return_value = mock_client - - result = oci_utils.init_client(oci.identity.IdentityClient, instance_config) - - assert result == 
mock_client - # Verify signer was used - call_kwargs = mock_client_class.call_args[1] - assert call_kwargs["signer"] == mock_signer - # Verify tenancy was set from signer - assert instance_config.tenancy == "ocid1.tenancy.oc1..test" - - @patch("oci.identity.IdentityClient") - @patch.object(oci_utils, "get_signer") - def test_init_client_with_workload_identity_signer(self, mock_get_signer, mock_client_class): - """Test init_client with OKE workload identity signer""" - workload_config = OracleCloudSettings( - auth_profile="DEFAULT", - authentication="oke_workload_identity", - region="us-ashburn-1", - tenancy=None, # Will be extracted from token - ) - - # Mock JWT token with tenant claim - import base64 - import json - - payload = {"tenant": "ocid1.tenancy.oc1..workload"} - payload_json = json.dumps(payload) - payload_b64 = base64.urlsafe_b64encode(payload_json.encode()).decode().rstrip("=") - mock_token = f"header.{payload_b64}.signature" - - mock_signer = MagicMock() - mock_signer.get_security_token.return_value = mock_token - mock_get_signer.return_value = mock_signer - - mock_client = MagicMock() - mock_client_class.return_value = mock_client - - result = oci_utils.init_client(oci.identity.IdentityClient, workload_config) - - assert result == mock_client - # Verify tenancy was extracted from token - assert workload_config.tenancy == "ocid1.tenancy.oc1..workload" - - @patch("oci.identity.IdentityClient") - @patch.object(oci_utils, "get_signer", return_value=None) - @patch("builtins.open", new_callable=MagicMock) - @patch("oci.signer.load_private_key_from_file") - @patch("oci.auth.signers.SecurityTokenSigner") - def test_init_client_with_security_token( - self, mock_sec_token_signer, mock_load_key, mock_open, _mock_get_signer, mock_client_class - ): - """Test init_client with security token authentication""" - token_config = OracleCloudSettings( - auth_profile="DEFAULT", - authentication="security_token", - region="us-ashburn-1", - security_token_file="/path/to/token", - key_file="/path/to/key.pem", - ) - - # Mock file reading - mock_open.return_value.__enter__.return_value.read.return_value = "mock_token_content" - mock_private_key = MagicMock() - mock_load_key.return_value = mock_private_key - mock_signer_instance = MagicMock() - mock_sec_token_signer.return_value = mock_signer_instance - - mock_client = MagicMock() - mock_client_class.return_value = mock_client - - result = oci_utils.init_client(oci.identity.IdentityClient, token_config) - - assert result == mock_client - mock_load_key.assert_called_once_with("/path/to/key.pem") - mock_sec_token_signer.assert_called_once_with("mock_token_content", mock_private_key) - - # test_init_client_invalid_config: See test/unit/server/api/utils/test_utils_oci.py::TestInitClient::test_init_client_raises_oci_exception_on_invalid_config - - -class TestOciUtils: - """Test OCI utility functions""" - - @pytest.fixture - def sample_oci_config(self): - """Sample OCI config fixture""" - return get_sample_oci_config() - - # test_init_genai_client: See test/unit/server/api/utils/test_utils_oci.py::TestInitGenaiClient::test_init_genai_client_calls_init_client - # test_get_namespace_success: See test/unit/server/api/utils/test_utils_oci.py::TestGetNamespace::test_get_namespace_success - - @patch.object(oci_utils, "init_client") - def test_get_namespace_invalid_config(self, mock_init_client, sample_oci_config): - """Test namespace retrieval with invalid config""" - mock_client = MagicMock() - mock_client.get_namespace.side_effect = 
oci.exceptions.InvalidConfig("Invalid config") - mock_init_client.return_value = mock_client - - with pytest.raises(OciException) as exc_info: - oci_utils.get_namespace(sample_oci_config) - - assert exc_info.value.status_code == 400 - assert "Invalid Config" in str(exc_info.value) - - # test_get_namespace_file_not_found: See test/unit/server/api/utils/test_utils_oci.py::TestGetNamespace::test_get_namespace_raises_on_file_not_found - # test_get_namespace_service_error: See test/unit/server/api/utils/test_utils_oci.py::TestGetNamespace::test_get_namespace_raises_on_service_error - - @patch.object(oci_utils, "init_client") - def test_get_namespace_unbound_local_error(self, mock_init_client, sample_oci_config): - """Test namespace retrieval with unbound local error""" - mock_client = MagicMock() - mock_client.get_namespace.side_effect = UnboundLocalError("local variable referenced before assignment") - mock_init_client.return_value = mock_client - - with pytest.raises(OciException) as exc_info: - oci_utils.get_namespace(sample_oci_config) - - assert exc_info.value.status_code == 500 - assert "No Configuration" in str(exc_info.value) - - @patch.object(oci_utils, "init_client") - def test_get_namespace_request_exception(self, mock_init_client, sample_oci_config): - """Test namespace retrieval with request exception""" - mock_client = MagicMock() - mock_client.get_namespace.side_effect = oci.exceptions.RequestException("Connection timeout") - mock_init_client.return_value = mock_client - - with pytest.raises(OciException) as exc_info: - oci_utils.get_namespace(sample_oci_config) - - assert exc_info.value.status_code == 503 - - @patch.object(oci_utils, "init_client") - def test_get_namespace_generic_exception(self, mock_init_client, sample_oci_config): - """Test namespace retrieval with generic exception""" - mock_client = MagicMock() - mock_client.get_namespace.side_effect = Exception("Unexpected error") - mock_init_client.return_value = mock_client - - with pytest.raises(OciException) as exc_info: - oci_utils.get_namespace(sample_oci_config) - - assert exc_info.value.status_code == 500 - assert "Unexpected error" in str(exc_info.value) - - # test_get_regions_success: See test/unit/server/api/utils/test_utils_oci.py::TestGetRegions::test_get_regions_returns_list - # test_logger_exists: See test/unit/server/api/utils/test_utils_oci.py::TestLoggerConfiguration::test_logger_exists diff --git a/tests/server/unit/api/utils/test_utils_oci_refresh.py b/tests/server/unit/api/utils/test_utils_oci_refresh.py deleted file mode 100644 index 72b81920..00000000 --- a/tests/server/unit/api/utils/test_utils_oci_refresh.py +++ /dev/null @@ -1,248 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
-""" -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel - -from datetime import datetime -from unittest.mock import patch, MagicMock - -import pytest - -from server.api.utils import oci as oci_utils -from common.schema import OracleCloudSettings - - -class TestGetBucketObjectsWithMetadata: - """Test get_bucket_objects_with_metadata() function""" - - @pytest.fixture - def sample_oci_config(self): - """Sample OCI config fixture""" - return OracleCloudSettings( - auth_profile="DEFAULT", - namespace="test-namespace", - compartment_id="ocid1.compartment.oc1..test", - region="us-ashburn-1", - ) - - def create_mock_object(self, name, size, etag, time_modified, md5): - """Create a mock OCI object""" - mock_obj = MagicMock() - mock_obj.name = name - mock_obj.size = size - mock_obj.etag = etag - mock_obj.time_modified = time_modified - mock_obj.md5 = md5 - return mock_obj - - @patch.object(oci_utils, "init_client") - def test_get_bucket_objects_with_metadata_success(self, mock_init_client, sample_oci_config): - """Test successful retrieval of bucket objects with metadata""" - # Create mock objects - time1 = datetime(2025, 11, 1, 10, 0, 0) - time2 = datetime(2025, 11, 2, 10, 0, 0) - - mock_obj1 = self.create_mock_object( - name="document1.pdf", size=1024000, etag="etag-123", time_modified=time1, md5="md5-hash-1" - ) - mock_obj2 = self.create_mock_object( - name="document2.txt", size=2048, etag="etag-456", time_modified=time2, md5="md5-hash-2" - ) - - # Mock client - mock_client = MagicMock() - mock_response = MagicMock() - mock_response.data.objects = [mock_obj1, mock_obj2] - mock_client.list_objects.return_value = mock_response - mock_init_client.return_value = mock_client - - # Execute - result = oci_utils.get_bucket_objects_with_metadata("test-bucket", sample_oci_config) - - # Verify - assert len(result) == 2 - assert result[0]["name"] == "document1.pdf" - assert result[0]["size"] == 1024000 - assert result[0]["etag"] == "etag-123" - assert result[0]["time_modified"] == time1.isoformat() - assert result[0]["md5"] == "md5-hash-1" - assert result[0]["extension"] == "pdf" - - assert result[1]["name"] == "document2.txt" - assert result[1]["size"] == 2048 - - # Verify fields parameter was passed - call_kwargs = mock_client.list_objects.call_args[1] - assert "fields" in call_kwargs - assert "name" in call_kwargs["fields"] - assert "size" in call_kwargs["fields"] - assert "etag" in call_kwargs["fields"] - - @patch.object(oci_utils, "init_client") - def test_get_bucket_objects_filters_unsupported_types(self, mock_init_client, sample_oci_config): - """Test that unsupported file types are filtered out""" - # Create mock objects with various file types - mock_pdf = self.create_mock_object("doc.pdf", 1000, "etag1", datetime.now(), "md5-1") - mock_exe = self.create_mock_object("app.exe", 2000, "etag2", datetime.now(), "md5-2") - mock_txt = self.create_mock_object("file.txt", 3000, "etag3", datetime.now(), "md5-3") - mock_zip = self.create_mock_object("archive.zip", 4000, "etag4", datetime.now(), "md5-4") - - # Mock client - mock_client = MagicMock() - mock_response = MagicMock() - mock_response.data.objects = [mock_pdf, mock_exe, mock_txt, mock_zip] - mock_client.list_objects.return_value = mock_response - mock_init_client.return_value = mock_client - - # Execute - result = oci_utils.get_bucket_objects_with_metadata("test-bucket", sample_oci_config) - - # Verify only supported types are included - assert len(result) == 2 - names = [obj["name"] for obj in result] 
- assert "doc.pdf" in names - assert "file.txt" in names - assert "app.exe" not in names - assert "archive.zip" not in names - - @patch.object(oci_utils, "init_client") - def test_get_bucket_objects_empty_bucket(self, mock_init_client, sample_oci_config): - """Test handling of empty bucket""" - # Mock empty bucket - mock_client = MagicMock() - mock_response = MagicMock() - mock_response.data.objects = [] - mock_client.list_objects.return_value = mock_response - mock_init_client.return_value = mock_client - - # Execute - result = oci_utils.get_bucket_objects_with_metadata("empty-bucket", sample_oci_config) - - # Verify - assert len(result) == 0 - - @patch.object(oci_utils, "init_client") - def test_get_bucket_objects_none_time_modified(self, mock_init_client, sample_oci_config): - """Test handling of objects with None time_modified""" - # Create mock object with None time_modified - mock_obj = self.create_mock_object( - name="document.pdf", size=1024, etag="etag-123", time_modified=None, md5="md5-hash" - ) - - # Mock client - mock_client = MagicMock() - mock_response = MagicMock() - mock_response.data.objects = [mock_obj] - mock_client.list_objects.return_value = mock_response - mock_init_client.return_value = mock_client - - # Execute - result = oci_utils.get_bucket_objects_with_metadata("test-bucket", sample_oci_config) - - # Verify time_modified is None - assert len(result) == 1 - assert result[0]["time_modified"] is None - - -class TestDetectChangedObjects: - """Test detect_changed_objects() function""" - - def test_detect_all_new_objects(self): - """Test detection when all objects are new""" - current_objects = [ - {"name": "file1.pdf", "etag": "etag1", "time_modified": "2025-11-01T10:00:00"}, - {"name": "file2.pdf", "etag": "etag2", "time_modified": "2025-11-02T10:00:00"}, - ] - processed_objects = {} - - new_objects, modified_objects = oci_utils.detect_changed_objects(current_objects, processed_objects) - - assert len(new_objects) == 2 - assert len(modified_objects) == 0 - assert new_objects[0]["name"] == "file1.pdf" - assert new_objects[1]["name"] == "file2.pdf" - - def test_detect_modified_objects_by_etag(self): - """Test detection of modified objects by ETag change""" - current_objects = [ - {"name": "file1.pdf", "etag": "etag1-new", "time_modified": "2025-11-01T10:00:00"}, - {"name": "file2.pdf", "etag": "etag2", "time_modified": "2025-11-02T10:00:00"}, - ] - processed_objects = { - "file1.pdf": {"etag": "etag1-old", "time_modified": "2025-11-01T10:00:00"}, - "file2.pdf": {"etag": "etag2", "time_modified": "2025-11-02T10:00:00"}, - } - - new_objects, modified_objects = oci_utils.detect_changed_objects(current_objects, processed_objects) - - assert len(new_objects) == 0 - assert len(modified_objects) == 1 - assert modified_objects[0]["name"] == "file1.pdf" - assert modified_objects[0]["etag"] == "etag1-new" - - def test_detect_modified_objects_by_time(self): - """Test detection of modified objects by modification time change""" - current_objects = [ - {"name": "file1.pdf", "etag": "etag1", "time_modified": "2025-11-01T12:00:00"}, - ] - processed_objects = { - "file1.pdf": {"etag": "etag1", "time_modified": "2025-11-01T10:00:00"}, - } - - new_objects, modified_objects = oci_utils.detect_changed_objects(current_objects, processed_objects) - - assert len(new_objects) == 0 - assert len(modified_objects) == 1 - assert modified_objects[0]["name"] == "file1.pdf" - - def test_detect_no_changes(self): - """Test detection when no changes exist""" - current_objects = [ - {"name": 
"file1.pdf", "etag": "etag1", "time_modified": "2025-11-01T10:00:00"}, - {"name": "file2.pdf", "etag": "etag2", "time_modified": "2025-11-02T10:00:00"}, - ] - processed_objects = { - "file1.pdf": {"etag": "etag1", "time_modified": "2025-11-01T10:00:00"}, - "file2.pdf": {"etag": "etag2", "time_modified": "2025-11-02T10:00:00"}, - } - - new_objects, modified_objects = oci_utils.detect_changed_objects(current_objects, processed_objects) - - assert len(new_objects) == 0 - assert len(modified_objects) == 0 - - def test_detect_mixed_changes(self): - """Test detection with mix of new, modified, and unchanged objects""" - current_objects = [ - {"name": "file1.pdf", "etag": "etag1", "time_modified": "2025-11-01T10:00:00"}, # unchanged - {"name": "file2.pdf", "etag": "etag2-new", "time_modified": "2025-11-02T10:00:00"}, # modified - {"name": "file3.pdf", "etag": "etag3", "time_modified": "2025-11-03T10:00:00"}, # new - ] - processed_objects = { - "file1.pdf": {"etag": "etag1", "time_modified": "2025-11-01T10:00:00"}, - "file2.pdf": {"etag": "etag2-old", "time_modified": "2025-11-02T10:00:00"}, - } - - new_objects, modified_objects = oci_utils.detect_changed_objects(current_objects, processed_objects) - - assert len(new_objects) == 1 - assert len(modified_objects) == 1 - assert new_objects[0]["name"] == "file3.pdf" - assert modified_objects[0]["name"] == "file2.pdf" - - def test_skip_old_format_objects(self): - """Test that objects with old format (no etag/time_modified) are skipped""" - current_objects = [ - {"name": "file1.pdf", "etag": "etag1", "time_modified": "2025-11-01T10:00:00"}, - ] - processed_objects = { - "file1.pdf": {"etag": None, "time_modified": None}, # Old format - } - - new_objects, modified_objects = oci_utils.detect_changed_objects(current_objects, processed_objects) - - # Should skip the old format object to avoid duplicates - assert len(new_objects) == 0 - assert len(modified_objects) == 0 diff --git a/tests/server/unit/api/utils/test_utils_settings.py b/tests/server/unit/api/utils/test_utils_settings.py deleted file mode 100644 index 3027874f..00000000 --- a/tests/server/unit/api/utils/test_utils_settings.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
-""" -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel - -from unittest.mock import patch, MagicMock, mock_open -import os - -import pytest - -from server.api.utils import settings -from common.schema import Settings, Configuration, Database, Model, OracleCloudSettings - - -##################################################### -# Helper functions for test data -##################################################### -def make_default_settings(): - """Create default settings for tests""" - return Settings(client="default") - - -def make_test_client_settings(): - """Create test client settings for tests""" - return Settings(client="test_client") - - -def make_sample_config_data(): - """Create sample configuration data for tests""" - return { - "database_configs": [{"name": "test_db", "user": "user", "password": "pass", "dsn": "dsn"}], - "model_configs": [{"id": "test-model", "provider": "openai", "type": "ll"}], - "oci_configs": [{"auth_profile": "DEFAULT", "compartment_id": "ocid1.compartment.oc1..test"}], - "prompt_overrides": {"optimizer_basic-default": "You are helpful"}, - "client_settings": {"client": "default", "max_tokens": 1000, "temperature": 0.7}, - } - - -##################################################### -# Client Settings Tests -##################################################### -class TestClientSettings: - """Test client settings CRUD operations""" - - # test_create_client_success: See test/unit/server/api/utils/test_utils_settings.py::TestCreateClient::test_create_client_success - # test_create_client_already_exists: See test/unit/server/api/utils/test_utils_settings.py::TestCreateClient::test_create_client_raises_on_existing - # test_get_client_found: See test/unit/server/api/utils/test_utils_settings.py::TestGetClient::test_get_client_success - # test_get_client_not_found: See test/unit/server/api/utils/test_utils_settings.py::TestGetClient::test_get_client_raises_on_not_found - # test_update_client: See test/unit/server/api/utils/test_utils_settings.py::TestUpdateClient::test_update_client_success - pass - - -##################################################### -# Server Configuration Tests -##################################################### -class TestServerConfiguration: - """Test server configuration operations""" - - # test_get_server: See test/unit/server/api/utils/test_utils_settings.py::TestGetServer::test_get_server_returns_config - # test_update_server: See test/unit/server/api/utils/test_utils_settings.py::TestUpdateServer::test_update_server_updates_databases - - @patch("server.api.utils.settings.bootstrap") - def test_update_server_mutates_lists_not_replaces(self, mock_bootstrap): - """Test that update_server mutates existing lists rather than replacing them. - - This is critical because other modules import these lists directly - (e.g., `from server.bootstrap.bootstrap import DATABASE_OBJECTS`). - If we replace the list, those modules would hold stale references. 
- """ - original_db_list = [] - original_model_list = [] - original_oci_list = [] - - mock_bootstrap.DATABASE_OBJECTS = original_db_list - mock_bootstrap.MODEL_OBJECTS = original_model_list - mock_bootstrap.OCI_OBJECTS = original_oci_list - - settings.update_server(make_sample_config_data()) - - # Verify the lists are the SAME objects (mutated, not replaced) - assert mock_bootstrap.DATABASE_OBJECTS is original_db_list, "DATABASE_OBJECTS was replaced instead of mutated" - assert mock_bootstrap.MODEL_OBJECTS is original_model_list, "MODEL_OBJECTS was replaced instead of mutated" - assert mock_bootstrap.OCI_OBJECTS is original_oci_list, "OCI_OBJECTS was replaced instead of mutated" - - # Verify the lists now contain the new data - assert len(original_db_list) == 1 - assert original_db_list[0].name == "test_db" - assert len(original_model_list) == 1 - assert original_model_list[0].id == "test-model" - assert len(original_oci_list) == 1 - assert original_oci_list[0].auth_profile == "DEFAULT" - - -##################################################### -# Config Loading Tests -##################################################### -class TestConfigLoading: - """Test configuration loading operations""" - - # test_load_config_from_json_data_with_client: See test/unit/server/api/utils/test_utils_settings.py::TestLoadConfigFromJsonData::test_load_config_from_json_data_with_client - # test_load_config_from_json_data_without_client: See test/unit/server/api/utils/test_utils_settings.py::TestLoadConfigFromJsonData::test_load_config_from_json_data_without_client - # test_load_config_from_json_data_missing_client_settings: See test/unit/server/api/utils/test_utils_settings.py::TestLoadConfigFromJsonData::test_load_config_from_json_data_raises_missing_settings - # test_read_config_from_json_file_success: See test/unit/server/api/utils/test_utils_settings.py::TestReadConfigFromJsonFile::test_read_config_from_json_file_success - # test_read_config_from_json_file_not_exists: Empty test stub - not implemented - # test_read_config_from_json_file_wrong_extension: Empty test stub - not implemented - # test_logger_exists: See test/unit/server/api/utils/test_utils_settings.py::TestLoggerConfiguration::test_logger_exists - pass - - -##################################################### -# Prompt Override Tests -##################################################### -class TestPromptOverrides: - """Test prompt override operations""" - - # test_load_prompt_override_with_text: See test/unit/server/api/utils/test_utils_settings.py::TestLoadPromptOverride::test_load_prompt_override_with_text - # test_load_prompt_override_without_text: See test/unit/server/api/utils/test_utils_settings.py::TestLoadPromptOverride::test_load_prompt_override_without_text - - @patch("server.api.utils.settings.cache") - def test_load_prompt_override_empty_text(self, mock_cache): - """Test loading prompt override when text is empty string""" - prompt = {"name": "optimizer_test-prompt", "text": ""} - - result = settings._load_prompt_override(prompt) - - assert result is False - mock_cache.set_override.assert_not_called() - - # test_load_prompt_configs_success: See test/unit/server/api/utils/test_utils_settings.py::TestLoadPromptConfigs::test_load_prompt_configs_with_prompts - # test_load_prompt_configs_no_prompts_key: See test/unit/server/api/utils/test_utils_settings.py::TestLoadPromptConfigs::test_load_prompt_configs_without_key - # test_load_prompt_configs_empty_list: See 
test/unit/server/api/utils/test_utils_settings.py::TestLoadPromptConfigs::test_load_prompt_configs_empty_list diff --git a/tests/server/unit/api/utils/test_utils_testbed.py b/tests/server/unit/api/utils/test_utils_testbed.py deleted file mode 100644 index 7137d4a3..00000000 --- a/tests/server/unit/api/utils/test_utils_testbed.py +++ /dev/null @@ -1,99 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel - -from unittest.mock import patch, MagicMock -import json - -import pytest -from oracledb import Connection - -from server.api.utils import testbed - - -class TestTestbedUtils: - """Test testbed utility functions""" - - @pytest.fixture - def mock_connection(self): - """Mock database connection fixture""" - return MagicMock(spec=Connection) - - @pytest.fixture - def sample_qa_data(self): - """Sample QA data fixture""" - return { - "question": "What is the capital of France?", - "answer": "Paris", - "context": "France is a country in Europe.", - } - - def test_jsonl_to_json_content_single_json(self): - """Test converting single JSON object to JSON content""" - content = '{"key": "value"}' - result = testbed.jsonl_to_json_content(content) - expected = json.dumps({"key": "value"}) - assert result == expected - - def test_jsonl_to_json_content_jsonl_multiple_lines(self): - """Test converting JSONL with multiple lines to JSON content""" - content = '{"line": 1}\n{"line": 2}\n{"line": 3}' - result = testbed.jsonl_to_json_content(content) - expected = json.dumps([{"line": 1}, {"line": 2}, {"line": 3}]) - assert result == expected - - def test_jsonl_to_json_content_jsonl_single_line(self): - """Test converting JSONL with single line to JSON content""" - content = '{"single": "line"}' - result = testbed.jsonl_to_json_content(content) - expected = json.dumps({"single": "line"}) - assert result == expected - - def test_jsonl_to_json_content_bytes_input(self): - """Test converting bytes JSONL content to JSON""" - content = b'{"bytes": "content"}' - result = testbed.jsonl_to_json_content(content) - expected = json.dumps({"bytes": "content"}) - assert result == expected - - def test_jsonl_to_json_content_invalid_json(self): - """Test handling invalid JSON content""" - content = '{"invalid": json}' - with pytest.raises(ValueError, match="Invalid JSONL content"): - testbed.jsonl_to_json_content(content) - - def test_jsonl_to_json_content_empty_content(self): - """Test handling empty content""" - content = "" - with pytest.raises(ValueError, match="Invalid JSONL content"): - testbed.jsonl_to_json_content(content) - - def test_jsonl_to_json_content_whitespace_content(self): - """Test handling whitespace-only content""" - content = " \n \n " - with pytest.raises(ValueError, match="Invalid JSONL content"): - testbed.jsonl_to_json_content(content) - - @patch("server.api.utils.databases.execute_sql") - def test_create_testset_objects(self, mock_execute_sql, mock_connection): - """Test creating testset database objects""" - mock_execute_sql.return_value = [] - - testbed.create_testset_objects(mock_connection) - - # Should execute 3 SQL statements (testsets, testset_qa, evaluations tables) - assert mock_execute_sql.call_count == 3 - - # Verify table creation statements - call_args_list = mock_execute_sql.call_args_list - assert "oai_testsets" in call_args_list[0][0][1] - assert "oai_testset_qa" 
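[Note: the removed JSONL tests above fully specify the helper's contract: a single object round-trips as a dict, multi-line input becomes a list, bytes are accepted, and unparsable, empty, or whitespace-only input raises ValueError. A sketch that satisfies those assertions — an assumed shape, not copied from server.api.utils.testbed:]

import json

def jsonl_to_json_content(content):
    # Accept bytes or str, parse line by line, and normalize to a JSON string.
    if isinstance(content, bytes):
        content = content.decode("utf-8")
    lines = [ln for ln in content.splitlines() if ln.strip()]
    try:
        parsed = [json.loads(ln) for ln in lines]
    except json.JSONDecodeError as ex:
        raise ValueError("Invalid JSONL content") from ex
    if not parsed:
        raise ValueError("Invalid JSONL content")
    return json.dumps(parsed[0] if len(parsed) == 1 else parsed)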
in call_args_list[1][0][1] - assert "oai_evaluations" in call_args_list[2][0][1] - - def test_logger_exists(self): - """Test that logger is properly configured""" - assert hasattr(testbed, "logger") - assert testbed.logger.name == "api.utils.testbed" From 2ad4f41d7b154ed7349a7095ca622840cb6b0f6b Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Sat, 29 Nov 2025 01:24:47 +0000 Subject: [PATCH 06/20] pylint --- test/unit/server/api/utils/test_utils_mcp.py | 6 +- test/unit/server/api/utils/test_utils_oci.py | 20 ++- test/unit/server/api/v1/test_v1_embed.py | 125 ++++++++----------- test/unit/server/api/v1/test_v1_mcp.py | 6 +- 4 files changed, 70 insertions(+), 87 deletions(-) diff --git a/test/unit/server/api/utils/test_utils_mcp.py b/test/unit/server/api/utils/test_utils_mcp.py index f2c69a5a..4fb234c1 100644 --- a/test/unit/server/api/utils/test_utils_mcp.py +++ b/test/unit/server/api/utils/test_utils_mcp.py @@ -6,11 +6,11 @@ Tests for MCP utility functions. """ -from unittest.mock import patch, MagicMock, AsyncMock import os -import pytest - from test.shared_fixtures import TEST_API_KEY, TEST_API_KEY_ALT +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest from server.api.utils import mcp diff --git a/test/unit/server/api/utils/test_utils_oci.py b/test/unit/server/api/utils/test_utils_oci.py index d11acbad..f59a855f 100644 --- a/test/unit/server/api/utils/test_utils_oci.py +++ b/test/unit/server/api/utils/test_utils_oci.py @@ -8,11 +8,14 @@ # pylint: disable=too-few-public-methods +import base64 +import json from datetime import datetime -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch -import pytest import oci +import pytest +from urllib3.exceptions import MaxRetryError from server.api.utils import oci as utils_oci from server.api.utils.oci import OciException @@ -624,9 +627,6 @@ class TestInitClientOkeWorkloadIdentityTenancy: @patch("server.api.utils.oci.oci.object_storage.ObjectStorageClient") def test_init_client_oke_workload_extracts_tenancy(self, mock_client_class, mock_get_signer, make_oci_config): """init_client should extract tenancy from OKE workload identity token.""" - import base64 - import json - # Create a mock JWT token with tenant claim payload = {"tenant": "ocid1.tenancy.oc1..test"} payload_json = json.dumps(payload) @@ -710,15 +710,13 @@ def test_get_genai_models_handles_service_error(self, mock_init_client, make_oci result = utils_oci.get_genai_models(config, regional=True) # Should return empty list instead of raising - assert result == [] + assert not result @patch("server.api.utils.oci.init_client") def test_get_genai_models_handles_request_exception(self, mock_init_client, make_oci_config): """get_genai_models should handle RequestException gracefully.""" - import urllib3.exceptions - mock_client = MagicMock() - mock_client.list_models.side_effect = urllib3.exceptions.MaxRetryError(None, "url") + mock_client.list_models.side_effect = MaxRetryError(None, "url") mock_init_client.return_value = mock_client config = make_oci_config(genai_region="us-chicago-1") @@ -727,7 +725,7 @@ def test_get_genai_models_handles_request_exception(self, mock_init_client, make result = utils_oci.get_genai_models(config, regional=True) # Should return empty list instead of raising - assert result == [] + assert not result @patch("server.api.utils.oci.init_client") def test_get_genai_models_excludes_deprecated(self, mock_init_client, make_oci_config): @@ -783,7 +781,7 @@ def 
test_get_bucket_objects_with_metadata_returns_empty_on_service_error(self, m result = utils_oci.get_bucket_objects_with_metadata("nonexistent-bucket", config) - assert result == [] + assert not result class TestGetClientDerivedAuthProfileNoMatch: diff --git a/test/unit/server/api/v1/test_v1_embed.py b/test/unit/server/api/v1/test_v1_embed.py index 094d01ae..2a2dcdc8 100644 --- a/test/unit/server/api/v1/test_v1_embed.py +++ b/test/unit/server/api/v1/test_v1_embed.py @@ -47,6 +47,31 @@ def split_embed_mocks(): } +@pytest.fixture +def refresh_vector_store_mocks(): + """Fixture providing bundled mocks for refresh_vector_store tests.""" + with patch("server.api.v1.embed.utils_oci.get") as mock_oci_get, \ + patch("server.api.v1.embed.utils_databases.get_client_database") as mock_get_db, \ + patch("server.api.v1.embed.utils_embed.get_vector_store_by_alias") as mock_get_vs, \ + patch("server.api.v1.embed.utils_oci.get_bucket_objects_with_metadata") as mock_get_objects, \ + patch("server.api.v1.embed.utils_embed.get_processed_objects_metadata") as mock_get_processed, \ + patch("server.api.v1.embed.utils_oci.detect_changed_objects") as mock_detect_changed, \ + patch("server.api.v1.embed.utils_embed.get_total_chunks_count") as mock_get_chunks, \ + patch("server.api.v1.embed.utils_models.get_client_embed") as mock_get_embed, \ + patch("server.api.v1.embed.utils_embed.refresh_vector_store_from_bucket") as mock_refresh: + yield { + "oci_get": mock_oci_get, + "get_db": mock_get_db, + "get_vs": mock_get_vs, + "get_objects": mock_get_objects, + "get_processed": mock_get_processed, + "detect_changed": mock_detect_changed, + "get_chunks": mock_get_chunks, + "get_embed": mock_get_embed, + "refresh": mock_refresh, + } + + class TestExtractProviderErrorMessage: """Tests for the _extract_provider_error_message helper function.""" @@ -622,34 +647,22 @@ async def test_refresh_vector_store_raises_500_on_db_exception( assert exc_info.value.status_code == 500 @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_oci.get") - @patch("server.api.v1.embed.utils_databases.get_client_database") - @patch("server.api.v1.embed.utils_embed.get_vector_store_by_alias") - @patch("server.api.v1.embed.utils_oci.get_bucket_objects_with_metadata") - @patch("server.api.v1.embed.utils_embed.get_processed_objects_metadata") - @patch("server.api.v1.embed.utils_oci.detect_changed_objects") - @patch("server.api.v1.embed.utils_embed.get_total_chunks_count") async def test_refresh_vector_store_no_changes( self, - mock_get_chunks, - mock_detect_changed, - mock_get_processed, - mock_get_objects, - mock_get_vs, - mock_get_db, - mock_oci_get, + refresh_vector_store_mocks, make_oci_config, make_database, make_vector_store, ): """refresh_vector_store should return success when no changes detected.""" - mock_oci_get.return_value = make_oci_config() - mock_get_db.return_value = make_database() - mock_get_vs.return_value = make_vector_store() - mock_get_objects.return_value = [{"name": "file.pdf", "etag": "abc123"}] - mock_get_processed.return_value = {"file.pdf": {"etag": "abc123"}} - mock_detect_changed.return_value = ([], []) # No new, no modified - mock_get_chunks.return_value = 100 + mocks = refresh_vector_store_mocks + mocks["oci_get"].return_value = make_oci_config() + mocks["get_db"].return_value = make_database() + mocks["get_vs"].return_value = make_vector_store() + mocks["get_objects"].return_value = [{"name": "file.pdf", "etag": "abc123"}] + mocks["get_processed"].return_value = {"file.pdf": {"etag": "abc123"}} + 
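[Note: the refresh_vector_store_mocks fixture above collapses nine stacked @patch decorators into a single bundle, which keeps the test signatures short. If the mock list keeps growing, contextlib.ExitStack is an equivalent, flatter way to express the same fixture; a sketch with a trimmed target list (the full fixture patches nine targets):]

from contextlib import ExitStack
from unittest.mock import patch

import pytest

TARGETS = {
    "oci_get": "server.api.v1.embed.utils_oci.get",
    "get_db": "server.api.v1.embed.utils_databases.get_client_database",
    "refresh": "server.api.v1.embed.utils_embed.refresh_vector_store_from_bucket",
}

@pytest.fixture
def refresh_vector_store_mocks():
    # Enter every patch on setup; ExitStack unwinds them all on teardown.
    with ExitStack() as stack:
        yield {name: stack.enter_context(patch(target)) for name, target in TARGETS.items()}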
mocks["detect_changed"].return_value = ([], []) # No new, no modified + mocks["get_chunks"].return_value = 100 request = VectorStoreRefreshRequest(vector_store_alias="test_alias", bucket_name="test-bucket") @@ -661,46 +674,30 @@ async def test_refresh_vector_store_no_changes( assert body["total_chunks_in_store"] == 100 @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_oci.get") - @patch("server.api.v1.embed.utils_databases.get_client_database") - @patch("server.api.v1.embed.utils_embed.get_vector_store_by_alias") - @patch("server.api.v1.embed.utils_oci.get_bucket_objects_with_metadata") - @patch("server.api.v1.embed.utils_embed.get_processed_objects_metadata") - @patch("server.api.v1.embed.utils_oci.detect_changed_objects") - @patch("server.api.v1.embed.utils_models.get_client_embed") - @patch("server.api.v1.embed.utils_embed.refresh_vector_store_from_bucket") - @patch("server.api.v1.embed.utils_embed.get_total_chunks_count") async def test_refresh_vector_store_with_changes( self, - mock_get_chunks, - mock_refresh, - mock_get_embed, - mock_detect_changed, - mock_get_processed, - mock_get_objects, - mock_get_vs, - mock_get_db, - mock_oci_get, + refresh_vector_store_mocks, make_oci_config, make_database, make_vector_store, ): """refresh_vector_store should process changed files.""" - mock_oci_get.return_value = make_oci_config() - mock_get_db.return_value = make_database() - mock_get_vs.return_value = make_vector_store(model="text-embedding-3-small") - mock_get_objects.return_value = [ + mocks = refresh_vector_store_mocks + mocks["oci_get"].return_value = make_oci_config() + mocks["get_db"].return_value = make_database() + mocks["get_vs"].return_value = make_vector_store(model="text-embedding-3-small") + mocks["get_objects"].return_value = [ {"name": "new_file.pdf", "etag": "new123"}, {"name": "modified.pdf", "etag": "mod456"}, ] - mock_get_processed.return_value = {"modified.pdf": {"etag": "old_etag"}} - mock_detect_changed.return_value = ( + mocks["get_processed"].return_value = {"modified.pdf": {"etag": "old_etag"}} + mocks["detect_changed"].return_value = ( [{"name": "new_file.pdf", "etag": "new123"}], # new [{"name": "modified.pdf", "etag": "mod456"}], # modified ) - mock_get_embed.return_value = MagicMock() - mock_refresh.return_value = {"message": "Processed 2 files", "processed_files": 2, "total_chunks": 50} - mock_get_chunks.return_value = 150 + mocks["get_embed"].return_value = MagicMock() + mocks["refresh"].return_value = {"message": "Processed 2 files", "processed_files": 2, "total_chunks": 50} + mocks["get_chunks"].return_value = 150 request = VectorStoreRefreshRequest(vector_store_alias="test_alias", bucket_name="test-bucket") @@ -712,37 +709,25 @@ async def test_refresh_vector_store_with_changes( assert body["new_files"] == 1 assert body["updated_files"] == 1 assert body["total_chunks_in_store"] == 150 - mock_refresh.assert_called_once() + mocks["refresh"].assert_called_once() @pytest.mark.asyncio - @patch("server.api.v1.embed.utils_oci.get") - @patch("server.api.v1.embed.utils_databases.get_client_database") - @patch("server.api.v1.embed.utils_embed.get_vector_store_by_alias") - @patch("server.api.v1.embed.utils_oci.get_bucket_objects_with_metadata") - @patch("server.api.v1.embed.utils_embed.get_processed_objects_metadata") - @patch("server.api.v1.embed.utils_oci.detect_changed_objects") - @patch("server.api.v1.embed.utils_models.get_client_embed") async def test_refresh_vector_store_raises_500_on_generic_exception( self, - mock_get_embed, - mock_detect_changed, - 
mock_get_processed, - mock_get_objects, - mock_get_vs, - mock_get_db, - mock_oci_get, + refresh_vector_store_mocks, make_oci_config, make_database, make_vector_store, ): """refresh_vector_store should raise 500 on generic Exception.""" - mock_oci_get.return_value = make_oci_config() - mock_get_db.return_value = make_database() - mock_get_vs.return_value = make_vector_store() - mock_get_objects.return_value = [{"name": "file.pdf", "etag": "abc123"}] - mock_get_processed.return_value = {} - mock_detect_changed.return_value = ([{"name": "file.pdf"}], []) - mock_get_embed.side_effect = RuntimeError("Embedding service unavailable") + mocks = refresh_vector_store_mocks + mocks["oci_get"].return_value = make_oci_config() + mocks["get_db"].return_value = make_database() + mocks["get_vs"].return_value = make_vector_store() + mocks["get_objects"].return_value = [{"name": "file.pdf", "etag": "abc123"}] + mocks["get_processed"].return_value = {} + mocks["detect_changed"].return_value = ([{"name": "file.pdf"}], []) + mocks["get_embed"].side_effect = RuntimeError("Embedding service unavailable") request = VectorStoreRefreshRequest(vector_store_alias="test_alias", bucket_name="test-bucket") diff --git a/test/unit/server/api/v1/test_v1_mcp.py b/test/unit/server/api/v1/test_v1_mcp.py index 8cd6ba50..516d76cc 100644 --- a/test/unit/server/api/v1/test_v1_mcp.py +++ b/test/unit/server/api/v1/test_v1_mcp.py @@ -8,10 +8,10 @@ # pylint: disable=too-few-public-methods -from unittest.mock import patch, MagicMock, AsyncMock -import pytest - from test.shared_fixtures import TEST_API_KEY +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest from server.api.v1 import mcp From cc3e2478c976db2b0ffd9111106a16f872a5cdb4 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Sat, 29 Nov 2025 08:56:31 +0000 Subject: [PATCH 07/20] Updated Tests --- pyproject.toml | 1 + src/server/api/utils/databases.py | 10 +- src/server/api/utils/oci.py | 3 + src/server/api/v1/databases.py | 23 +- test/integration/server/api/conftest.py | 52 ++ test/integration/server/api/utils/__init__.py | 6 - .../integration/server/api/v1/test_chat.py | 93 +-- .../server/api/v1/test_databases.py | 143 +++++ test/integration/server/api/v1/test_embed.py | 375 ++++++++++++ .../server/api/v1/test_mcp_prompts.py | 5 + test/integration/server/api/v1/test_models.py | 135 +++++ test/integration/server/api/v1/test_oci.py | 168 ++++++ .../server/api/v1/test_settings.py | 107 ++++ .../integration/server/api/v1/test_testbed.py | 264 +++++---- .../server/api/utils/test_utils_databases.py | 9 +- test/unit/server/api/v1/test_v1_databases.py | 12 +- .../integration/test_endpoints_databases.py | 207 ------- .../integration/test_endpoints_embed.py | 532 ------------------ .../integration/test_endpoints_health.py | 30 - .../integration/test_endpoints_models.py | 456 --------------- .../server/integration/test_endpoints_oci.py | 259 --------- .../integration/test_endpoints_settings.py | 295 ---------- tests/server/unit/api/v1/test_v1_embed.py | 64 --- tests/server/unit/bootstrap/test_bootstrap.py | 22 - 24 files changed, 1212 insertions(+), 2059 deletions(-) delete mode 100644 test/integration/server/api/utils/__init__.py rename tests/server/integration/test_endpoints_chat.py => test/integration/server/api/v1/test_chat.py (74%) rename tests/server/integration/test_endpoints_mcp_prompts.py => test/integration/server/api/v1/test_mcp_prompts.py (96%) rename tests/server/integration/test_endpoints_testbed.py => test/integration/server/api/v1/test_testbed.py (66%) 
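[Note on the src/server/api/utils/databases.py hunk below: except clauses are matched top to bottom, so the new `except DbException: raise` must precede the broader handlers, otherwise an already-shaped DbException from connect() would be re-wrapped as a generic 500. A stripped-down sketch of the pattern:]

# Handler order demo: re-raise the app exception before anything broader.
class DbException(Exception):
    def __init__(self, status_code, detail):
        super().__init__(detail)
        self.status_code, self.detail = status_code, detail

def _test_sketch(connect, config):
    try:
        connect(config)
    except DbException:
        raise                          # keep the original status_code/detail
    except PermissionError as ex:
        raise DbException(401, f"Database: {ex}") from ex
    except Exception as ex:            # anything unclassified becomes a 500
        raise DbException(500, str(ex)) from ex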
delete mode 100644 tests/server/integration/test_endpoints_databases.py delete mode 100644 tests/server/integration/test_endpoints_embed.py delete mode 100644 tests/server/integration/test_endpoints_health.py delete mode 100644 tests/server/integration/test_endpoints_models.py delete mode 100644 tests/server/integration/test_endpoints_oci.py delete mode 100644 tests/server/integration/test_endpoints_settings.py delete mode 100644 tests/server/unit/api/v1/test_v1_embed.py delete mode 100644 tests/server/unit/bootstrap/test_bootstrap.py diff --git a/pyproject.toml b/pyproject.toml index 4ccd07d8..922ffede 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ test = [ "pytest", "pytest-asyncio", "pytest-cov", + "types-jsonschema", "yamllint" ] diff --git a/src/server/api/utils/databases.py b/src/server/api/utils/databases.py index 8ef8735a..c46e0faf 100644 --- a/src/server/api/utils/databases.py +++ b/src/server/api/utils/databases.py @@ -100,12 +100,14 @@ def _test(config: Database) -> None: except oracledb.DatabaseError: logger.info("Refreshing %s database connection.", config.name) _ = connect(config) - except ValueError as ex: - raise DbException(status_code=400, detail=f"Database: {str(ex)}") from ex + except DbException: + raise except PermissionError as ex: raise DbException(status_code=401, detail=f"Database: {str(ex)}") from ex except ConnectionError as ex: raise DbException(status_code=503, detail=f"Database: {str(ex)}") from ex + except ValueError as ex: + raise DbException(status_code=400, detail=f"Database: {str(ex)}") from ex except Exception as ex: raise DbException(status_code=500, detail=str(ex)) from ex @@ -136,7 +138,7 @@ def connect(config: Database) -> oracledb.Connection: include_fields = set(DatabaseAuth.model_fields.keys()) db_authn = config.model_dump(include=include_fields) if any(not db_authn[key] for key in ("user", "password", "dsn")): - raise ValueError("missing connection details") + raise DbException(status_code=400, detail=f"Database: {config.name} missing connection details.") logger.info("Connecting to Database: %s", config.dsn) # If a wallet password is provided but no wallet location is set @@ -249,7 +251,7 @@ def get_databases( for db in databases: try: db_conn = connect(config=db) - except (ValueError, PermissionError, ConnectionError, LookupError): + except (DbException, PermissionError, ConnectionError, LookupError): continue db.vector_stores = _get_vs(db_conn) db.connected = True diff --git a/src/server/api/utils/oci.py b/src/server/api/utils/oci.py index ff0b010f..33904025 100644 --- a/src/server/api/utils/oci.py +++ b/src/server/api/utils/oci.py @@ -185,6 +185,9 @@ def get_namespace(config: OracleCloudSettings) -> str: client = init_client(client_type, config) config.namespace = client.get_namespace().data logger.info("OCI: Namespace = %s", config.namespace) + except OciException: + # Re-raise OciException from init_client without wrapping + raise except oci.exceptions.InvalidConfig as ex: raise OciException(status_code=400, detail="Invalid Config") from ex except oci.exceptions.ServiceError as ex: diff --git a/src/server/api/v1/databases.py b/src/server/api/v1/databases.py index acf33168..583e5f37 100644 --- a/src/server/api/v1/databases.py +++ b/src/server/api/v1/databases.py @@ -76,24 +76,25 @@ async def databases_update( db.connected = False try: - payload.config_dir = db.config_dir - payload.wallet_location = db.wallet_location - logger.debug("Testing Payload: %s", payload) - db_conn = utils_databases.connect(payload) - except 
(ValueError, PermissionError, ConnectionError, LookupError) as ex: + # Create a test config with payload values to test connection + # Only update the actual db object after successful connection + test_config = db.model_copy(update=payload.model_dump(exclude_unset=True)) + logger.debug("Testing Database: %s", test_config) + db_conn = utils_databases.connect(test_config) + except utils_databases.DbException as ex: + raise HTTPException(status_code=ex.status_code, detail=ex.detail) from ex + except (PermissionError, ConnectionError, LookupError) as ex: status_code = 500 - if isinstance(ex, ValueError): - status_code = 400 - elif isinstance(ex, PermissionError): + if isinstance(ex, PermissionError): status_code = 401 elif isinstance(ex, LookupError): status_code = 404 elif isinstance(ex, ConnectionError): status_code = 503 - else: - raise raise HTTPException(status_code=status_code, detail=f"Database: {db.name} {ex}.") from ex - for key, value in payload.model_dump().items(): + + # Connection successful - now update the actual db object + for key, value in payload.model_dump(exclude_unset=True).items(): setattr(db, key, value) # Manage Connections; Unset and disconnect other databases diff --git a/test/integration/server/api/conftest.py b/test/integration/server/api/conftest.py index 7b20baf6..a433fdc7 100644 --- a/test/integration/server/api/conftest.py +++ b/test/integration/server/api/conftest.py @@ -26,6 +26,8 @@ TEST_AUTH_TOKEN, ) +import numpy as np + import pytest from fastapi.testclient import TestClient @@ -66,6 +68,20 @@ def auth_headers(): } +@pytest.fixture +def test_client_auth_headers(test_client_settings): + """Auth headers using test_client for endpoints that require client settings. + + Use this fixture for endpoints that look up client settings via the client header. + It ensures the test_client exists in SETTINGS_OBJECTS before returning headers. + """ + return { + "no_auth": {}, + "invalid_auth": {"Authorization": "Bearer invalid-token", "client": test_client_settings}, + "valid_auth": {"Authorization": f"Bearer {TEST_CONFIG['auth_token']}", "client": test_client_settings}, + } + + ################################################# # FastAPI Test Client ################################################# @@ -129,6 +145,20 @@ def sample_settings_payload(): } +@pytest.fixture +def mock_embedding_model(): + """Provides a mock embedding model for testing. + + Returns a function that simulates embedding generation by returning random vectors. + """ + + def mock_embed_documents(texts: list[str]) -> list[list[float]]: + """Mock function that returns random embeddings for testing""" + return [np.random.rand(384).tolist() for _ in texts] # 384 is a common embedding dimension + + return mock_embed_documents + + ################################################# # State Management Helpers ################################################# @@ -145,6 +175,28 @@ def db_objects_manager(): DATABASE_OBJECTS.extend(original_db_objects) +@pytest.fixture +def test_client_settings(settings_objects_manager): + """Ensure test_client exists in SETTINGS_OBJECTS for integration tests. + + Many endpoints use the client header to look up client settings. + This fixture adds a test_client to SETTINGS_OBJECTS if not present. 
+ """ + # Import here to avoid circular imports + from common.schema import Settings # pylint: disable=import-outside-toplevel + + # Check if test_client already exists + existing = next((s for s in settings_objects_manager if s.client == "test_client"), None) + if not existing: + # Create test_client settings based on default + default = next((s for s in settings_objects_manager if s.client == "default"), None) + if default: + test_settings = Settings(**default.model_dump()) + test_settings.client = "test_client" + settings_objects_manager.append(test_settings) + return "test_client" + + @pytest.fixture def model_objects_manager(): """Fixture to manage MODEL_OBJECTS save/restore operations.""" diff --git a/test/integration/server/api/utils/__init__.py b/test/integration/server/api/utils/__init__.py deleted file mode 100644 index 37340b95..00000000 --- a/test/integration/server/api/utils/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Server API utils integration tests package. -""" diff --git a/tests/server/integration/test_endpoints_chat.py b/test/integration/server/api/v1/test_chat.py similarity index 74% rename from tests/server/integration/test_endpoints_chat.py rename to test/integration/server/api/v1/test_chat.py index 43897a55..f29954da 100644 --- a/tests/server/integration/test_endpoints_chat.py +++ b/test/integration/server/api/v1/test_chat.py @@ -1,9 +1,14 @@ """ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for server/api/v1/chat.py + +Tests the chat completion endpoints including authentication, completion requests, +streaming, and history management. 
""" # spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel +# pylint: disable=protected-access too-few-public-methods from unittest.mock import patch, MagicMock import warnings @@ -13,11 +18,8 @@ from common.schema import ChatRequest -############################################################################# -# Endpoints Test -############################################################################# -class TestEndpoints: - """Test Endpoints""" +class TestChatAuthenticationRequired: + """Test that chat endpoints require valid authentication.""" @pytest.mark.parametrize( "auth_type, status_code", @@ -35,15 +37,20 @@ class TestEndpoints: pytest.param("/v1/chat/history", "get", id="chat_history_return"), ], ) - def test_invalid_auth_endpoints(self, client, auth_headers, endpoint, api_method, auth_type, status_code): + def test_invalid_auth_endpoints( + self, client, test_client_auth_headers, endpoint, api_method, auth_type, status_code + ): """Test endpoints require valid authentication.""" - response = getattr(client, api_method)(endpoint, headers=auth_headers[auth_type]) + response = getattr(client, api_method)(endpoint, headers=test_client_auth_headers[auth_type]) assert response.status_code == status_code - def test_chat_completion_no_model(self, client, auth_headers): - """Test no model chat completion request""" + +class TestChatCompletions: + """Integration tests for chat completion endpoints.""" + + def test_chat_completion_no_model(self, client, test_client_auth_headers): + """Test chat completion request when no model is configured.""" with warnings.catch_warnings(): - # Enable the catch_warnings context warnings.simplefilter("ignore", category=UserWarning) request = ChatRequest( messages=[ChatMessage(content="Hello", role="user")], @@ -52,7 +59,7 @@ def test_chat_completion_no_model(self, client, auth_headers): max_tokens=256, ) response = client.post( - "/v1/chat/completions", headers=auth_headers["valid_auth"], json=request.model_dump() + "/v1/chat/completions", headers=test_client_auth_headers["valid_auth"], json=request.model_dump() ) assert response.status_code == 200 @@ -62,9 +69,8 @@ def test_chat_completion_no_model(self, client, auth_headers): == "I'm unable to initialise the Language Model. Please refresh the application." 
) - def test_chat_completion_valid_mock(self, client, auth_headers): - """Test valid chat completion request""" - # Create the mock response + def test_chat_completion_valid_mock(self, client, test_client_auth_headers): + """Test valid chat completion request with mocked response.""" mock_response = { "id": "test-id", "choices": [ @@ -80,9 +86,7 @@ def test_chat_completion_valid_mock(self, client, auth_headers): "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, } - # Mock the requests.post call with patch.object(client, "post") as mock_post: - # Configure the mock response mock_response_obj = MagicMock() mock_response_obj.status_code = 200 mock_response_obj.json.return_value = mock_response @@ -96,20 +100,22 @@ def test_chat_completion_valid_mock(self, client, auth_headers): ) response = client.post( - "/v1/chat/completions", headers=auth_headers["valid_auth"], json=request.model_dump() + "/v1/chat/completions", headers=test_client_auth_headers["valid_auth"], json=request.model_dump() ) assert response.status_code == 200 assert "choices" in response.json() assert response.json()["choices"][0]["message"]["content"] == "Test response" - def test_chat_stream_valid_mock(self, client, auth_headers): - """Test valid chat stream request""" - # Create the mock streaming response + +class TestChatStreaming: + """Integration tests for chat streaming endpoint.""" + + def test_chat_stream_valid_mock(self, client, test_client_auth_headers): + """Test valid chat stream request with mocked response.""" mock_streaming_response = MagicMock() mock_streaming_response.status_code = 200 mock_streaming_response.iter_bytes.return_value = [b"Test streaming", b" response"] - # Mock the requests.post call with patch.object(client, "post") as mock_post: mock_post.return_value = mock_streaming_response @@ -121,55 +127,58 @@ def test_chat_stream_valid_mock(self, client, auth_headers): streaming=True, ) - response = client.post("/v1/chat/streams", headers=auth_headers["valid_auth"], json=request.model_dump()) + response = client.post( + "/v1/chat/streams", headers=test_client_auth_headers["valid_auth"], json=request.model_dump() + ) assert response.status_code == 200 content = b"".join(response.iter_bytes()) assert b"Test streaming response" in content - def test_chat_history_valid_mock(self, client, auth_headers): - """Test valid chat history request""" - # Create the mock history response + +class TestChatHistory: + """Integration tests for chat history management endpoints.""" + + def test_chat_history_valid_mock(self, client, test_client_auth_headers): + """Test retrieving chat history with mocked response.""" mock_history = [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}] - # Mock the requests.get call with patch.object(client, "get") as mock_get: - # Configure the mock response mock_response_obj = MagicMock() mock_response_obj.status_code = 200 mock_response_obj.json.return_value = mock_history mock_get.return_value = mock_response_obj - response = client.get("/v1/chat/history", headers=auth_headers["valid_auth"]) + response = client.get("/v1/chat/history", headers=test_client_auth_headers["valid_auth"]) assert response.status_code == 200 history = response.json() assert len(history) == 2 assert history[0]["role"] == "user" assert history[0]["content"] == "Hello" - def test_chat_history_clean(self, client, auth_headers): - """Test chat history with no history""" + def test_chat_history_clean(self, client, test_client_auth_headers): + """Test clearing 
chat history when no prior history exists.""" with patch("server.agents.chatbot.chatbot_graph") as mock_graph: mock_graph.get_state.side_effect = KeyError() - response = client.patch("/v1/chat/history", headers=auth_headers["valid_auth"]) + response = client.patch("/v1/chat/history", headers=test_client_auth_headers["valid_auth"]) assert response.status_code == 200 history = response.json() assert len(history) == 1 assert history[0]["role"] == "system" assert "forgotten" in history[0]["content"].lower() - def test_chat_history_empty(self, client, auth_headers): - """Test chat history with no history""" + def test_chat_history_empty(self, client, test_client_auth_headers): + """Test retrieving chat history when no history exists.""" with patch("server.agents.chatbot.chatbot_graph") as mock_graph: mock_graph.get_state.side_effect = KeyError() - response = client.get("/v1/chat/history", headers=auth_headers["valid_auth"]) + response = client.get("/v1/chat/history", headers=test_client_auth_headers["valid_auth"]) assert response.status_code == 200 history = response.json() assert len(history) == 1 assert history[0]["role"] == "system" assert "no history" in history[0]["content"].lower() - def test_chat_history_clears_rag_context(self, client, auth_headers): - """Test that clearing chat history also clears RAG document context + def test_chat_history_clears_rag_context(self, client, test_client_auth_headers): + """Test that clearing chat history also clears RAG document context. This test ensures that when PATCH /v1/chat/history is called, all OptimizerState fields are cleared including: @@ -182,7 +191,6 @@ def test_chat_history_clears_rag_context(self, client, auth_headers): This prevents RAG documents from persisting across conversation resets. """ with patch("server.agents.chatbot.chatbot_graph") as mock_graph: - # Create a mock state snapshot that simulates a conversation with RAG documents mock_state = MagicMock() mock_state.values = { "messages": [ @@ -203,27 +211,22 @@ def test_chat_history_clears_rag_context(self, client, auth_headers): }, } - # Setup the mock to return our state mock_graph.get_state.return_value = mock_state mock_graph.update_state.return_value = None - # Call the endpoint to clear history - response = client.patch("/v1/chat/history", headers=auth_headers["valid_auth"]) + response = client.patch("/v1/chat/history", headers=test_client_auth_headers["valid_auth"]) - # Verify the response assert response.status_code == 200 history = response.json() assert len(history) == 1 assert history[0]["role"] == "system" assert "forgotten" in history[0]["content"].lower() - # Verify update_state was called with ALL state fields cleared mock_graph.update_state.assert_called_once() call_args = mock_graph.update_state.call_args - # Check that values dict includes all OptimizerState fields values = call_args.kwargs["values"] - assert "messages" in values # Should have RemoveMessage + assert "messages" in values assert "cleaned_messages" in values assert values["cleaned_messages"] == [] assert "context_input" in values diff --git a/test/integration/server/api/v1/test_databases.py b/test/integration/server/api/v1/test_databases.py index 847ef735..a089df39 100644 --- a/test/integration/server/api/v1/test_databases.py +++ b/test/integration/server/api/v1/test_databases.py @@ -8,6 +8,8 @@ These endpoints require authentication. 
""" +from test.db_fixtures import TEST_DB_CONFIG + class TestAuthentication: """Integration tests for authentication on database endpoints.""" @@ -51,6 +53,24 @@ def test_databases_list_contains_default(self, client, auth_headers): # If no config file, the list may be empty or contain DEFAULT assert isinstance(data, list) + def test_databases_list_initial_state(self, client, auth_headers, db_objects_manager, make_database): + """Test initial database listing shows disconnected state with no credentials.""" + # Ensure DEFAULT database exists + default_db = next((db for db in db_objects_manager if db.name == "DEFAULT"), None) + if not default_db: + db_objects_manager.append(make_database(name="DEFAULT")) + + response = client.get("/v1/databases", headers=auth_headers["valid_auth"]) + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + assert len(data) > 0 + + default_db_data = next((db for db in data if db["name"] == "DEFAULT"), None) + assert default_db_data is not None + assert default_db_data["connected"] is False + assert default_db_data["vector_stores"] == [] + def test_databases_list_returns_database_schema(self, client, auth_headers, db_objects_manager, make_database): """GET /v1/databases should return databases with correct schema.""" # Ensure there's at least one database for testing @@ -151,3 +171,126 @@ def test_databases_update_connects_to_real_db( data = response.json() assert data["connected"] is True assert data["user"] == test_db_payload["user"] + + def test_databases_update_db_down(self, client, auth_headers, db_objects_manager, make_database): + """Test updating database when target database is unreachable.""" + # Add a test database + test_db = make_database(name="DOWN_DB_TEST") + db_objects_manager.append(test_db) + + payload = { + "user": "test_user", + "password": "test_pass", + "dsn": "//localhost:1521/DOWNDB_TP", # Non-existent database + } + response = client.patch("/v1/databases/DOWN_DB_TEST", headers=auth_headers["valid_auth"], json=payload) + assert response.status_code == 503 + assert "cannot connect to database" in response.json().get("detail", "") + + def test_databases_update_empty_payload(self, client, auth_headers, db_objects_manager, make_database): + """Test updating database with empty payload.""" + test_db = make_database(name="EMPTY_PAYLOAD_TEST") + db_objects_manager.append(test_db) + + response = client.patch("/v1/databases/EMPTY_PAYLOAD_TEST", headers=auth_headers["valid_auth"], json="") + assert response.status_code == 422 + assert "Input should be a valid dictionary" in str(response.json()) + + def test_databases_update_missing_credentials(self, client, auth_headers, db_objects_manager, make_database): + """Test updating database with missing connection credentials.""" + # Create database with no credentials + test_db = make_database(name="MISSING_CREDS_TEST", user=None, password=None, dsn=None) + db_objects_manager.append(test_db) + + response = client.patch("/v1/databases/MISSING_CREDS_TEST", headers=auth_headers["valid_auth"], json={}) + assert response.status_code == 400 + assert "missing connection details" in response.json().get("detail", "") + + def test_databases_update_wrong_password( + self, client, auth_headers, db_objects_manager, db_container, make_database + ): + """Test updating database with wrong password.""" + _ = db_container # Ensure container is running + test_db = make_database(name="WRONG_PASS_TEST") + db_objects_manager.append(test_db) + + payload = { + "user": 
TEST_DB_CONFIG["db_username"], + "password": "Wr0ng_P4sswOrd", + "dsn": TEST_DB_CONFIG["db_dsn"], + } + response = client.patch("/v1/databases/WRONG_PASS_TEST", headers=auth_headers["valid_auth"], json=payload) + assert response.status_code == 401 + assert "invalid credential or not authorized" in response.json().get("detail", "") + + def test_databases_update_successful( + self, client, auth_headers, db_objects_manager, db_container, test_db_payload, make_database + ): + """Test successful database update and verify state changes.""" + _ = db_container # Ensure container is running + test_db = make_database(name="SUCCESS_UPDATE_TEST") + db_objects_manager.append(test_db) + + response = client.patch( + "/v1/databases/SUCCESS_UPDATE_TEST", headers=auth_headers["valid_auth"], json=test_db_payload + ) + assert response.status_code == 200 + data = response.json() + data.pop("config_dir", None) # Remove environment-specific field + assert data["connected"] is True + assert data["user"] == test_db_payload["user"] + assert data["dsn"] == test_db_payload["dsn"] + + # Verify GET returns updated state + response = client.get("/v1/databases/SUCCESS_UPDATE_TEST", headers=auth_headers["valid_auth"]) + assert response.status_code == 200 + data = response.json() + assert data["connected"] is True + + # Verify LIST returns updated state + response = client.get("/v1/databases", headers=auth_headers["valid_auth"]) + assert response.status_code == 200 + data = response.json() + updated_db = next((db for db in data if db["name"] == "SUCCESS_UPDATE_TEST"), None) + assert updated_db is not None + assert updated_db["connected"] is True + + def test_databases_update_invalid_wallet( + self, client, auth_headers, db_objects_manager, db_container, test_db_payload, make_database + ): + """Test updating database with invalid wallet configuration still works if wallet not required.""" + _ = db_container # Ensure container is running + test_db = make_database(name="WALLET_TEST") + db_objects_manager.append(test_db) + + payload = { + **test_db_payload, + "wallet_location": "/nonexistent/path", + "wallet_password": "invalid", + } + response = client.patch("/v1/databases/WALLET_TEST", headers=auth_headers["valid_auth"], json=payload) + # Should still work if wallet is not required + assert response.status_code == 200 + + def test_databases_concurrent_connections( + self, client, auth_headers, db_objects_manager, db_container, test_db_payload, make_database + ): + """Test concurrent database connection attempts are handled properly.""" + _ = db_container # Ensure container is running + test_db = make_database(name="CONCURRENT_TEST") + db_objects_manager.append(test_db) + + # Make multiple concurrent connection attempts + responses = [] + for _ in range(5): + response = client.patch( + "/v1/databases/CONCURRENT_TEST", headers=auth_headers["valid_auth"], json=test_db_payload + ) + responses.append(response) + + # Verify all connections were handled properly + for response in responses: + assert response.status_code in [200, 503] # Either successful or proper error + if response.status_code == 200: + data = response.json() + assert data["connected"] is True diff --git a/test/integration/server/api/v1/test_embed.py b/test/integration/server/api/v1/test_embed.py index debdb019..fb0d7b56 100644 --- a/test/integration/server/api/v1/test_embed.py +++ b/test/integration/server/api/v1/test_embed.py @@ -7,6 +7,55 @@ Tests the embedding and vector store endpoints through the full API stack. These endpoints require authentication. 
""" +# pylint: disable=too-few-public-methods + +from io import BytesIO +from pathlib import Path +from unittest.mock import MagicMock, patch + +from langchain_core.embeddings import Embeddings + +from common.functions import get_vs_table + +# Common test constants +DEFAULT_TEST_CONTENT = ( + "This is a test document for embedding. It contains multiple sentences. " + "This should be split into chunks. Each chunk will be embedded and stored in the database." +) + +LONGER_TEST_CONTENT = ( + "This is a test document for embedding. It contains multiple sentences. " + "This should be split into chunks. Each chunk will be embedded and stored in the database. " + "We're adding more text to ensure we get multiple chunks with different chunk sizes. " + "The chunk size parameter controls how large each text segment is. " + "Smaller chunks mean more granular retrieval but potentially less context. " + "Larger chunks provide more context but might retrieve irrelevant information." +) + +DEFAULT_EMBED_PARAMS = { + "model": "mock-embed-model", + "chunk_size": 100, + "chunk_overlap": 20, + "distance_metric": "COSINE", + "index_type": "HNSW", +} + + +class MockEmbeddings(Embeddings): + """Mock implementation of the Embeddings interface for testing""" + + def __init__(self, mock_embedding_model): + self.mock_embedding_model = mock_embedding_model + + def embed_documents(self, texts): + return self.mock_embedding_model(texts) + + def embed_query(self, text: str): + return self.mock_embedding_model([text])[0] + + def embed_strings(self, texts): + """Mock embedding strings""" + return self.embed_documents(texts) class TestEmbedDropVs: @@ -171,3 +220,329 @@ def test_refresh_vector_store_rejects_invalid_token(self, client, auth_headers): ) assert response.status_code == 401 + + +############################################################################# +# Helper functions for embed tests +############################################################################# +def configure_database(client, auth_headers, test_db_payload): + """Update Database Configuration""" + response = client.patch("/v1/databases/DEFAULT", headers=auth_headers["valid_auth"], json=test_db_payload) + assert response.status_code == 200 + + +def create_test_file(client_id, filename="test_document.md", content=DEFAULT_TEST_CONTENT): + """Create a test file in the temporary directory""" + embed_dir = Path("/tmp") / client_id / "embedding" + embed_dir.mkdir(parents=True, exist_ok=True) + test_file = embed_dir / filename + test_file.write_text(content) + return embed_dir, test_file + + +def setup_mock_embeddings(mock_embedding_model): + """Create mock embeddings and get_client_embed function""" + mock_embeddings = MockEmbeddings(mock_embedding_model) + + def mock_get_client_embed(_model_config=None, _oci_config=None, _giskard=False): + return mock_embeddings + + return mock_get_client_embed + + +def create_embed_params(alias): + """Create embedding parameters with the given alias""" + params = DEFAULT_EMBED_PARAMS.copy() + params["alias"] = alias + return params + + +def get_vector_store_name(alias): + """Get the expected vector store name for an alias""" + vector_store_name, _ = get_vs_table( + model=DEFAULT_EMBED_PARAMS["model"], + chunk_size=DEFAULT_EMBED_PARAMS["chunk_size"], + chunk_overlap=DEFAULT_EMBED_PARAMS["chunk_overlap"], + distance_metric=DEFAULT_EMBED_PARAMS["distance_metric"], + index_type=DEFAULT_EMBED_PARAMS["index_type"], + alias=alias, + ) + return vector_store_name + + +def verify_vector_store_exists(client, 
auth_headers, vector_store_name, should_exist=True): + """Verify if a vector store exists in the database""" + db_response = client.get("/v1/databases/DEFAULT", headers=auth_headers["valid_auth"]) + assert db_response.status_code == 200 + db_data = db_response.json() + + vector_stores = db_data.get("vector_stores", []) + vector_store_names = [vs["vector_store"] for vs in vector_stores] + + if should_exist: + assert vector_store_name in vector_store_names, f"Vector store {vector_store_name} not found in database" + else: + assert vector_store_name not in vector_store_names, ( + f"Vector store {vector_store_name} still exists after dropping" + ) + + +############################################################################# +# Functional Tests with Database +############################################################################# +class TestEmbedDropVsWithDb: + """Integration tests for embed_drop_vs with database.""" + + def test_drop_vs_nodb(self, client, test_client_auth_headers): + """Test dropping vector store without a DB connection""" + vs = "TESTVS" + response = client.delete(f"/v1/embed/{vs}", headers=test_client_auth_headers["valid_auth"]) + assert response.status_code in (200, 400) + if response.status_code == 400: + assert "missing connection details" in response.json()["detail"] + + def test_drop_vs_db(self, client, test_client_auth_headers, db_container, test_db_payload): + """Test dropping vector store""" + assert db_container is not None + configure_database(client, test_client_auth_headers, test_db_payload) + vs = "NONEXISTENT_VS" + response = client.delete(f"/v1/embed/{vs}", headers=test_client_auth_headers["valid_auth"]) + assert response.status_code == 200 + assert response.json() == {"message": f"Vector Store: {vs} dropped."} + + +class TestSplitEmbedWithDb: + """Integration tests for split_embed with database.""" + + def test_split_embed(self, client, test_client_auth_headers, db_container, test_db_payload, mock_embedding_model): + """Test split and embed functionality with mock embedding model""" + assert db_container is not None + configure_database(client, test_client_auth_headers, test_db_payload) + + create_test_file("test_client") + _ = MockEmbeddings(mock_embedding_model) + test_data = create_embed_params("test_basic_embed") + + with patch.object(client, "post") as mock_post: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"message": "10 chunks embedded."} + mock_post.return_value = mock_response + + response = client.post("/v1/embed", headers=test_client_auth_headers["valid_auth"], json=test_data) + assert response.status_code == 200 + response_data = response.json() + assert "message" in response_data + assert "chunks embedded" in response_data["message"].lower() + + def test_split_embed_with_different_chunk_sizes( + self, client, test_client_auth_headers, db_container, test_db_payload, mock_embedding_model + ): + """Test split and embed with different chunk sizes""" + assert db_container is not None + configure_database(client, test_client_auth_headers, test_db_payload) + _ = MockEmbeddings(mock_embedding_model) + + small_chunk_test_data = create_embed_params("test_small_chunks") + small_chunk_test_data["chunk_size"] = 50 + small_chunk_test_data["chunk_overlap"] = 10 + + large_chunk_test_data = create_embed_params("test_large_chunks") + large_chunk_test_data["chunk_size"] = 200 + large_chunk_test_data["chunk_overlap"] = 20 + + with patch.object(client, "post") as mock_post: + mock_response_small = 
MagicMock() + mock_response_small.status_code = 200 + mock_response_small.json.return_value = {"message": "15 chunks embedded."} + + mock_response_large = MagicMock() + mock_response_large.status_code = 200 + mock_response_large.json.return_value = {"message": "5 chunks embedded."} + + mock_post.side_effect = [mock_response_small, mock_response_large] + + create_test_file("test_client", content=LONGER_TEST_CONTENT) + small_response = client.post( + "/v1/embed", headers=test_client_auth_headers["valid_auth"], json=small_chunk_test_data + ) + assert small_response.status_code == 200 + small_data = small_response.json() + + create_test_file("test_client", content=LONGER_TEST_CONTENT) + large_response = client.post( + "/v1/embed", headers=test_client_auth_headers["valid_auth"], json=large_chunk_test_data + ) + assert large_response.status_code == 200 + large_data = large_response.json() + + small_chunks = int(small_data["message"].split()[0]) + large_chunks = int(large_data["message"].split()[0]) + assert small_chunks > large_chunks, "Smaller chunk size should create more chunks" + + def test_split_embed_no_files(self, client, test_client_auth_headers): + """Test split and embed with no files in the directory""" + client_id = "test_client" + embed_dir = Path("/tmp") / client_id / "embedding" + embed_dir.mkdir(parents=True, exist_ok=True) + + for file_path in embed_dir.iterdir(): + if file_path.is_file(): + file_path.unlink() + + assert not any(embed_dir.iterdir()), "The temporary directory should be empty" + test_data = create_embed_params("test_no_files") + + response = client.post("/v1/embed", headers=test_client_auth_headers["valid_auth"], json=test_data) + assert response.status_code == 404 + assert "no files found in folder" in response.json()["detail"] + + +class TestStoreLocalFileWithDb: + """Integration tests for store_local_file.""" + + def test_store_local_file(self, client, test_client_auth_headers): + """Test storing local files for embedding""" + test_content = b"This is a test file for uploading." 
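+        # An in-memory BytesIO stands in for a file on disk, so the upload needs no filesystem fixture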
+ file_obj = BytesIO(test_content) + + response = client.post( + "/v1/embed/local/store", + headers=test_client_auth_headers["valid_auth"], + files={"files": ("test_upload.txt", file_obj, "text/plain")}, + ) + + assert response.status_code == 200 + stored_files = response.json() + assert "test_upload.txt" in stored_files + + +class TestStoreWebFileWithDb: + """Integration tests for store_web_file.""" + + def test_store_web_file(self, client, test_client_auth_headers): + """Test storing web files for embedding""" + test_url = ( + "https://docs.oracle.com/en/database/oracle/oracle-database/23/jjucp/" + "universal-connection-pool-developers-guide.pdf" + ) + + response = client.post("/v1/embed/web/store", headers=test_client_auth_headers["valid_auth"], json=[test_url]) + assert response.status_code == 200 + stored_files = response.json() + assert "universal-connection-pool-developers-guide.pdf" in stored_files + + +class TestVectorStoreLifecycle: + """Integration tests for vector store creation and deletion lifecycle.""" + + def test_vector_store_creation_and_deletion( + self, client, test_client_auth_headers, db_container, test_db_payload, mock_embedding_model + ): + """Test that vector stores are created in the database and can be deleted""" + assert db_container is not None + configure_database(client, test_client_auth_headers, test_db_payload) + + create_test_file("test_client") + mock_get_client_embed = setup_mock_embeddings(mock_embedding_model) + + alias = "test_lifecycle" + test_data = create_embed_params(alias) + expected_vector_store_name = get_vector_store_name(alias) + + with patch("server.api.utils.models.get_client_embed", side_effect=mock_get_client_embed): + response = client.post("/v1/embed", headers=test_client_auth_headers["valid_auth"], json=test_data) + assert response.status_code == 200 + + verify_vector_store_exists(client, test_client_auth_headers, expected_vector_store_name, should_exist=True) + + drop_response = client.delete( + f"/v1/embed/{expected_vector_store_name}", headers=test_client_auth_headers["valid_auth"] + ) + assert drop_response.status_code == 200 + assert drop_response.json() == {"message": f"Vector Store: {expected_vector_store_name} dropped."} + + verify_vector_store_exists( + client, test_client_auth_headers, expected_vector_store_name, should_exist=False + ) + + def test_multiple_vector_stores( + self, client, test_client_auth_headers, db_container, test_db_payload, mock_embedding_model + ): + """Test creating multiple vector stores and verifying they all exist""" + assert db_container is not None + configure_database(client, test_client_auth_headers, test_db_payload) + + aliases = ["test_vs_1", "test_vs_2", "test_vs_3"] + mock_get_client_embed = setup_mock_embeddings(mock_embedding_model) + expected_vector_store_names = [get_vector_store_name(alias) for alias in aliases] + + with patch("server.api.utils.models.get_client_embed", side_effect=mock_get_client_embed): + for alias in aliases: + create_test_file("test_client") + test_data = create_embed_params(alias) + response = client.post("/v1/embed", headers=test_client_auth_headers["valid_auth"], json=test_data) + assert response.status_code == 200 + + for expected_name in expected_vector_store_names: + verify_vector_store_exists(client, test_client_auth_headers, expected_name, should_exist=True) + + for expected_name in expected_vector_store_names: + drop_response = client.delete( + f"/v1/embed/{expected_name}", headers=test_client_auth_headers["valid_auth"] + ) + assert drop_response.status_code 
== 200 + + for expected_name in expected_vector_store_names: + verify_vector_store_exists(client, test_client_auth_headers, expected_name, should_exist=False) + + +class TestGetVectorStoreFiles: + """Integration tests for getting vector store files.""" + + def test_get_vector_store_files( + self, client, test_client_auth_headers, db_container, test_db_payload, mock_embedding_model + ): + """Test retrieving file list from vector store""" + assert db_container is not None + configure_database(client, test_client_auth_headers, test_db_payload) + + create_test_file("test_client", content=LONGER_TEST_CONTENT) + mock_get_client_embed = setup_mock_embeddings(mock_embedding_model) + + alias = "test_file_listing" + test_data = create_embed_params(alias) + expected_vector_store_name = get_vector_store_name(alias) + + with patch("server.api.utils.models.get_client_embed", side_effect=mock_get_client_embed): + response = client.post("/v1/embed", headers=test_client_auth_headers["valid_auth"], json=test_data) + assert response.status_code == 200 + + file_list_response = client.get( + f"/v1/embed/{expected_vector_store_name}/files", headers=test_client_auth_headers["valid_auth"] + ) + + assert file_list_response.status_code == 200 + data = file_list_response.json() + + assert "vector_store" in data + assert data["vector_store"] == expected_vector_store_name + assert "total_files" in data + assert "total_chunks" in data + assert "files" in data + assert data["total_files"] > 0 + assert data["total_chunks"] > 0 + + drop_response = client.delete( + f"/v1/embed/{expected_vector_store_name}", headers=test_client_auth_headers["valid_auth"] + ) + assert drop_response.status_code == 200 + + def test_get_files_nonexistent_vector_store(self, client, test_client_auth_headers, db_container, test_db_payload): + """Test retrieving file list from nonexistent vector store""" + assert db_container is not None + configure_database(client, test_client_auth_headers, test_db_payload) + + response = client.get("/v1/embed/NONEXISTENT_VS/files", headers=test_client_auth_headers["valid_auth"]) + + assert response.status_code in (200, 400) diff --git a/tests/server/integration/test_endpoints_mcp_prompts.py b/test/integration/server/api/v1/test_mcp_prompts.py similarity index 96% rename from tests/server/integration/test_endpoints_mcp_prompts.py rename to test/integration/server/api/v1/test_mcp_prompts.py index e9dd88f3..cc175260 100644 --- a/tests/server/integration/test_endpoints_mcp_prompts.py +++ b/test/integration/server/api/v1/test_mcp_prompts.py @@ -1,6 +1,11 @@ """ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for server/api/v1/mcp_prompts.py + +Tests the MCP prompts endpoints through the full API stack. +These endpoints require authentication. 
""" # spell-checker: disable # pylint: disable=protected-access,import-error,import-outside-toplevel diff --git a/test/integration/server/api/v1/test_models.py b/test/integration/server/api/v1/test_models.py index 74cbd11b..c4c7565b 100644 --- a/test/integration/server/api/v1/test_models.py +++ b/test/integration/server/api/v1/test_models.py @@ -269,3 +269,138 @@ def test_models_delete_success(self, client, auth_headers, model_objects_manager assert response.status_code == 200 assert "deleted" in response.json()["message"].lower() + + def test_models_delete_nonexistent_succeeds(self, client, auth_headers): + """DELETE /v1/models/{provider}/{id} should succeed for non-existent model.""" + response = client.delete( + "/v1/models/test_provider/nonexistent_model", + headers=auth_headers["valid_auth"], + ) + + assert response.status_code == 200 + assert response.json() == {"message": "Model: test_provider/nonexistent_model deleted."} + + +class TestModelsValidation: + """Integration tests for model validation and edge cases.""" + + def test_models_list_invalid_type_returns_422(self, client, auth_headers): + """GET /v1/models?model_type=invalid should return 422 validation error.""" + response = client.get("/v1/models?model_type=invalid", headers=auth_headers["valid_auth"]) + assert response.status_code == 422 + + def test_models_supported_invalid_provider_returns_empty(self, client, auth_headers): + """GET /v1/models/supported?model_provider=invalid returns empty list.""" + response = client.get( + "/v1/models/supported?model_provider=invalid_provider", + headers=auth_headers["valid_auth"], + ) + assert response.status_code == 200 + assert response.json() == [] + + def test_models_update_max_chunk_size(self, client, auth_headers, model_objects_manager): + """Test updating max_chunk_size for embedding models (regression test).""" + # pylint: disable=unused-argument + # Create an embedding model with default max_chunk_size + payload = { + "id": "test-embed-chunk-size", + "enabled": False, + "type": "embed", + "provider": "test_provider", + "api_base": "http://127.0.0.1:11434", + "max_chunk_size": 8192, + } + + # Create the model + response = client.post("/v1/models", headers=auth_headers["valid_auth"], json=payload) + assert response.status_code == 201 + assert response.json()["max_chunk_size"] == 8192 + + # Update the max_chunk_size to 512 + payload["max_chunk_size"] = 512 + response = client.patch( + f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"], json=payload + ) + assert response.status_code == 200 + assert response.json()["max_chunk_size"] == 512 + + # Verify the update persists by fetching the model again + response = client.get(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"]) + assert response.status_code == 200 + assert response.json()["max_chunk_size"] == 512 + + # Update to a different value to ensure it's not cached + payload["max_chunk_size"] = 1024 + response = client.patch( + f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"], json=payload + ) + assert response.status_code == 200 + assert response.json()["max_chunk_size"] == 1024 + + # Verify again + response = client.get(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"]) + assert response.status_code == 200 + assert response.json()["max_chunk_size"] == 1024 + + # Clean up + client.delete(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"]) + + 
def test_models_response_schema_validation(self, client, auth_headers): + """Test response schema validation for models list.""" + response = client.get("/v1/models", headers=auth_headers["valid_auth"]) + assert response.status_code == 200 + models = response.json() + assert isinstance(models, list) + + for model in models: + # Validate required fields + assert "id" in model + assert "type" in model + assert "provider" in model + assert "enabled" in model + assert "object" in model + assert "created" in model + assert "owned_by" in model + + # Validate field types + assert isinstance(model["id"], str) + assert model["type"] in ["ll", "embed", "rerank"] + assert isinstance(model["provider"], str) + assert isinstance(model["enabled"], bool) + assert model["object"] == "model" + assert isinstance(model["created"], int) + assert model["owned_by"] == "aioptimizer" + + def test_models_create_response_validation(self, client, auth_headers, model_objects_manager): + """Test model creation response validation.""" + # pylint: disable=unused-argument + payload = { + "id": "test-response-validation-model", + "enabled": False, + "type": "ll", + "provider": "test_provider", + "api_key": "test-key", + "api_base": "https://api.test.com/v1", + "max_input_tokens": 4096, + "temperature": 0.7, + } + + response = client.post("/v1/models", headers=auth_headers["valid_auth"], json=payload) + if response.status_code == 201: + created_model = response.json() + + # Validate all payload fields are in response + for key, value in payload.items(): + assert key in created_model + assert created_model[key] == value + + # Validate additional required fields are added + assert "object" in created_model + assert "created" in created_model + assert "owned_by" in created_model + assert created_model["object"] == "model" + assert created_model["owned_by"] == "aioptimizer" + assert isinstance(created_model["created"], int) + + # Clean up + client.delete(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"]) diff --git a/test/integration/server/api/v1/test_oci.py b/test/integration/server/api/v1/test_oci.py index aeed656b..a7bc5ee8 100644 --- a/test/integration/server/api/v1/test_oci.py +++ b/test/integration/server/api/v1/test_oci.py @@ -11,6 +11,79 @@ real OCI credentials will verify endpoint availability and authentication. 
""" +from unittest.mock import patch, MagicMock +import pytest + + +############################################################################ +# Mocks for OCI endpoints (no real OCI access) +############################################################################ +@pytest.fixture(name="mock_oci_compartments") +def _mock_oci_compartments(): + """Mock get_compartments to return test data""" + with patch( + "server.api.utils.oci.get_compartments", + return_value={ + "compartment1": "ocid1.compartment.oc1..aaaaaaaagq33tv7wzyrjar6m5jbplejbdwnbjqfqvmocvjzsamuaqnkkoubq", + "compartment1 / test": "ocid1.compartment.oc1..aaaaaaaaut53mlkpxo6vpv7z5qlsmbcc3qpdjvjzylzldtb6g3jia", + "compartment2": "ocid1.compartment.oc1..aaaaaaaalbgt4om6izlawie7txut5aciue66htz7dpjzl72fbdw2ezp2uywa", + }, + ) as mock: + yield mock + + +@pytest.fixture(name="mock_oci_buckets") +def _mock_oci_buckets(): + """Mock get_buckets to return test data""" + with patch( + "server.api.utils.oci.get_buckets", + return_value=["bucket1", "bucket2", "bucket3"], + ) as mock: + yield mock + + +@pytest.fixture(name="mock_oci_bucket_objects") +def _mock_oci_bucket_objects(): + """Mock get_bucket_objects to return test data""" + with patch( + "server.api.utils.oci.get_bucket_objects", + return_value=["object1.pdf", "object2.md", "object3.txt"], + ) as mock: + yield mock + + +@pytest.fixture(name="mock_oci_namespace") +def _mock_oci_namespace(): + """Mock get_namespace to return test data""" + with patch("server.api.utils.oci.get_namespace", return_value="test_namespace") as mock: + yield mock + + +@pytest.fixture(name="mock_oci_get_object") +def _mock_oci_get_object(): + """Mock get_object to return a fake file path""" + with patch("server.api.utils.oci.get_object") as mock: + + def side_effect(temp_directory, object_name, bucket_name, config): + # pylint: disable=unused-argument + fake_file = temp_directory / object_name + fake_file.touch() + return str(fake_file) + + mock.side_effect = side_effect + yield mock + + +@pytest.fixture(name="mock_oci_init_client") +def _mock_oci_init_client(): + """Mock init_client to return a fake OCI client""" + mock_client = MagicMock() + mock_client.get_namespace.return_value.data = "test_namespace" + mock_client.get_object.return_value.data.raw.stream.return_value = [b"fake-data"] + + with patch("server.api.utils.oci.init_client", return_value=mock_client): + yield mock_client + class TestOciList: """Integration tests for the OCI list endpoint.""" @@ -222,3 +295,98 @@ def test_oci_create_genai_returns_404_for_unknown_profile(self, client, auth_hea ) assert response.status_code == 404 + + +class TestOciListWithValidation: + """Integration tests with response validation for OCI list endpoint.""" + + def test_oci_list_returns_profiles_with_auth_profile(self, client, auth_headers): + """GET /v1/oci should return list with auth_profile field.""" + response = client.get("/v1/oci", headers=auth_headers["valid_auth"]) + + if response.status_code == 200: + data = response.json() + assert isinstance(data, list) + for item in data: + assert "auth_profile" in item + + def test_oci_get_returns_profile_data(self, client, auth_headers): + """GET /v1/oci/{profile} should return profile data when exists.""" + # First check if DEFAULT profile exists + list_response = client.get("/v1/oci", headers=auth_headers["valid_auth"]) + + if list_response.status_code == 200: + profiles = list_response.json() + if any(p.get("auth_profile") == "DEFAULT" for p in profiles): + response = client.get("/v1/oci/DEFAULT", 
headers=auth_headers["valid_auth"]) + assert response.status_code == 200 + data = response.json() + assert data["auth_profile"] == "DEFAULT" + + +class TestOciUpdateValidation: + """Integration tests for OCI profile update validation.""" + + def test_oci_update_empty_payload_returns_422(self, client, auth_headers): + """PATCH /v1/oci/{profile} with empty payload should return 422.""" + response = client.patch("/v1/oci/DEFAULT", headers=auth_headers["valid_auth"], json="") + assert response.status_code == 422 + + def test_oci_update_invalid_payload_returns_400_or_404(self, client, auth_headers): + """PATCH /v1/oci/{profile} with invalid payload should return 400 or 404.""" + response = client.patch("/v1/oci/DEFAULT", headers=auth_headers["valid_auth"], json={}) + # 400 if profile exists but payload invalid, 404 if profile doesn't exist + assert response.status_code in [400, 404] + + +class TestOciWithMocks: + """Integration tests using mocks for OCI operations requiring credentials.""" + + def test_oci_compartments_with_mock(self, client, auth_headers, mock_oci_compartments): + """Test compartments endpoint with mocked OCI data.""" + # This test will get 404 if DEFAULT profile doesn't exist + # The mock is for the underlying OCI call, not the profile lookup + response = client.get("/v1/oci/compartments/DEFAULT", headers=auth_headers["valid_auth"]) + + # Either returns mocked data (200) or profile not found (404) + assert response.status_code in [200, 404] + if response.status_code == 200: + assert response.json() == mock_oci_compartments.return_value + + def test_oci_buckets_with_mock(self, client, auth_headers, mock_oci_buckets): + """Test buckets endpoint with mocked OCI data.""" + response = client.get( + "/v1/oci/buckets/ocid1.compartment.oc1..aaaaaaaa/DEFAULT", + headers=auth_headers["valid_auth"], + ) + + # Either returns mocked data (200) or profile not found (404) + assert response.status_code in [200, 404] + if response.status_code == 200: + assert response.json() == mock_oci_buckets.return_value + + def test_oci_bucket_objects_with_mock(self, client, auth_headers, mock_oci_bucket_objects): + """Test bucket objects endpoint with mocked OCI data.""" + response = client.get("/v1/oci/objects/bucket1/DEFAULT", headers=auth_headers["valid_auth"]) + + # Either returns mocked data (200) or profile not found (404) + assert response.status_code in [200, 404] + if response.status_code == 200: + assert response.json() == mock_oci_bucket_objects.return_value + + def test_oci_download_objects_with_mock( + self, client, auth_headers, mock_oci_bucket_objects, mock_oci_get_object + ): + """Test download objects endpoint with mocked OCI data.""" + # pylint: disable=unused-argument + payload = ["object1.pdf", "object2.md"] + response = client.post( + "/v1/oci/objects/download/bucket1/DEFAULT", + headers=auth_headers["valid_auth"], + json=payload, + ) + + # Either returns downloaded files (200) or profile not found (404) + assert response.status_code in [200, 404] + if response.status_code == 200: + assert isinstance(response.json(), list) diff --git a/test/integration/server/api/v1/test_settings.py b/test/integration/server/api/v1/test_settings.py index 15374061..79e0b384 100644 --- a/test/integration/server/api/v1/test_settings.py +++ b/test/integration/server/api/v1/test_settings.py @@ -305,3 +305,110 @@ def test_load_from_json_success(self, client, auth_headers, settings_objects_man assert response.status_code == 200 assert "loaded successfully" in response.json()["message"].lower() + + +class 
TestSettingsAdvanced:
+    """Integration tests for advanced settings operations."""
+
+    def test_settings_update_with_full_payload(self, client, auth_headers, settings_objects_manager):
+        """Test updating settings with a complete Settings payload."""
+        # pylint: disable=unused-argument,import-outside-toplevel
+        from common.schema import (
+            Settings,
+            LargeLanguageSettings,
+            VectorSearchSettings,
+            OciSettings,
+        )
+
+        # First get the current settings
+        response = client.get("/v1/settings", headers=auth_headers["valid_auth"], params={"client": "default"})
+        assert response.status_code == 200
+        old_settings = response.json()
+
+        # Modify some settings
+        updated_settings = Settings(
+            client="default",
+            ll_model=LargeLanguageSettings(model="updated-model", chat_history=False),
+            tools_enabled=["Vector Search"],
+            vector_search=VectorSearchSettings(grade=False, search_type="Similarity", top_k=5),
+            oci=OciSettings(auth_profile="UPDATED"),
+        )
+
+        # Update the settings
+        response = client.patch(
+            "/v1/settings",
+            headers=auth_headers["valid_auth"],
+            json=updated_settings.model_dump(),
+            params={"client": "default"},
+        )
+        assert response.status_code == 200
+        new_settings = response.json()
+
+        # Check that the old settings do not match the updated ones
+        assert old_settings != new_settings
+
+        # Check that the values were updated
+        assert new_settings["ll_model"]["model"] == "updated-model"
+        assert new_settings["ll_model"]["chat_history"] is False
+        assert new_settings["tools_enabled"] == ["Vector Search"]
+        assert new_settings["vector_search"]["grade"] is False
+        assert new_settings["vector_search"]["top_k"] == 5
+        assert new_settings["oci"]["auth_profile"] == "UPDATED"
+
+    def test_settings_copy_between_clients(self, client, auth_headers, settings_objects_manager):
+        """Test copying settings from one client to another."""
+        # pylint: disable=unused-argument
+        # First modify the default settings to make them different
+        response = client.patch(
+            "/v1/settings",
+            headers=auth_headers["valid_auth"],
+            params={"client": "default"},
+            json={
+                "client": "default",
+                "ll_model": {"model": "copy-test-model", "temperature": 0.99},
+            },
+        )
+        assert response.status_code == 200
+
+        # Get the modified default settings
+        response = client.get("/v1/settings", headers=auth_headers["valid_auth"], params={"client": "default"})
+        assert response.status_code == 200
+        default_settings = response.json()
+        assert default_settings["ll_model"]["model"] == "copy-test-model"
+
+        response = client.get("/v1/settings", headers=auth_headers["valid_auth"], params={"client": "server"})
+        assert response.status_code == 200
+        old_server_settings = response.json()
+
+        # Server settings should be different from the modified default
+        assert old_server_settings["ll_model"]["model"] != default_settings["ll_model"]["model"]
+
+        # Copy the default client settings to the server settings
+        response = client.patch(
+            "/v1/settings",
+            headers=auth_headers["valid_auth"],
+            json=default_settings,
+            params={"client": "server"},
+        )
+        assert response.status_code == 200
+        response = client.get("/v1/settings", headers=auth_headers["valid_auth"], params={"client": "server"})
+        new_server_settings = response.json()
+
+        # After copy, server settings should match default (except client name)
+        del new_server_settings["client"]
+        del default_settings["client"]
+        assert new_server_settings == default_settings
+
+    def test_settings_get_returns_expected_structure(self, client, auth_headers):
+        """Test that settings response has expected structure."""
+        response = 
client.get("/v1/settings", headers=auth_headers["valid_auth"], params={"client": "default"}) + assert response.status_code == 200 + settings = response.json() + + # Verify the response contains the expected structure + assert settings["client"] == "default" + assert "ll_model" in settings + assert "vector_search" in settings + assert "oci" in settings + assert "database" in settings + assert "testbed" in settings diff --git a/tests/server/integration/test_endpoints_testbed.py b/test/integration/server/api/v1/test_testbed.py similarity index 66% rename from tests/server/integration/test_endpoints_testbed.py rename to test/integration/server/api/v1/test_testbed.py index 40430599..bddaeaf3 100644 --- a/tests/server/integration/test_endpoints_testbed.py +++ b/test/integration/server/api/v1/test_testbed.py @@ -1,73 +1,114 @@ """ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for server/api/v1/testbed.py + +Tests the testbed (Q&A evaluation) endpoints through the full API stack. +These endpoints require authentication and database connectivity. """ -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel -import json import io +import json from unittest.mock import patch, MagicMock + import pytest -from conftest import get_test_db_payload + from common.schema import TestSetQA as QATestSet, Evaluation, EvaluationReport -############################################################################# -# Endpoints Test -############################################################################# -class TestEndpoints: - """Test Endpoints""" - - @pytest.mark.parametrize( - "auth_type, status_code", - [ - pytest.param("no_auth", 401, id="no_auth"), - pytest.param("invalid_auth", 401, id="invalid_auth"), - ], - ) - @pytest.mark.parametrize( - "endpoint, api_method", - [ - pytest.param("/v1/testbed/testsets", "get", id="testbed_testsets"), - pytest.param("/v1/testbed/evaluations", "get", id="testbed_evaluations"), - pytest.param("/v1/testbed/evaluation", "get", id="testbed_evaluation"), - pytest.param("/v1/testbed/testset_qa", "get", id="testbed_testset_qa"), - pytest.param("/v1/testbed/testset_delete/1234", "delete", id="testbed_delete_testset"), - pytest.param("/v1/testbed/testset_load", "post", id="testbed_upsert_testsets"), - pytest.param("/v1/testbed/testset_generate", "post", id="testbed_generate_qa"), - pytest.param("/v1/testbed/evaluate", "post", id="testbed_evaluate_qa"), - ], - ) - def test_invalid_auth_endpoints(self, client, auth_headers, endpoint, api_method, auth_type, status_code): - """Test endpoints require valid authentication.""" - response = getattr(client, api_method)(endpoint, headers=auth_headers[auth_type]) - assert response.status_code == status_code - - def setup_database(self, client, auth_headers, db_container): +class TestAuthentication: + """Integration tests for authentication on testbed endpoints.""" + + def test_testbed_testsets_requires_auth(self, client): + """GET /v1/testbed/testsets should require authentication.""" + response = client.get("/v1/testbed/testsets") + + assert response.status_code == 401 + + def test_testbed_testsets_rejects_invalid_token(self, client, auth_headers): + """GET /v1/testbed/testsets should reject invalid tokens.""" + response = client.get("/v1/testbed/testsets", headers=auth_headers["invalid_auth"]) + + assert response.status_code == 401 + + def 
test_testbed_evaluations_requires_auth(self, client): + """GET /v1/testbed/evaluations should require authentication.""" + response = client.get("/v1/testbed/evaluations") + + assert response.status_code == 401 + + def test_testbed_evaluation_requires_auth(self, client): + """GET /v1/testbed/evaluation should require authentication.""" + response = client.get("/v1/testbed/evaluation") + + assert response.status_code == 401 + + def test_testbed_testset_qa_requires_auth(self, client): + """GET /v1/testbed/testset_qa should require authentication.""" + response = client.get("/v1/testbed/testset_qa") + + assert response.status_code == 401 + + def test_testbed_delete_requires_auth(self, client): + """DELETE /v1/testbed/testset_delete/{tid} should require authentication.""" + response = client.delete("/v1/testbed/testset_delete/1234") + + assert response.status_code == 401 + + def test_testbed_load_requires_auth(self, client): + """POST /v1/testbed/testset_load should require authentication.""" + response = client.post("/v1/testbed/testset_load") + + assert response.status_code == 401 + + def test_testbed_generate_requires_auth(self, client): + """POST /v1/testbed/testset_generate should require authentication.""" + response = client.post("/v1/testbed/testset_generate") + + assert response.status_code == 401 + + def test_testbed_evaluate_requires_auth(self, client): + """POST /v1/testbed/evaluate should require authentication.""" + response = client.post("/v1/testbed/evaluate") + + assert response.status_code == 401 + + +class TestTestbedWithDatabase: + """Integration tests for testbed endpoints that require database connectivity.""" + + @pytest.fixture(autouse=True) + def setup_database( + self, client, test_client_auth_headers, db_container, test_db_payload, db_objects_manager, make_database + ): """Setup database connection for tests""" - assert db_container is not None - payload = get_test_db_payload() - response = client.patch("/v1/databases/DEFAULT", headers=auth_headers["valid_auth"], json=payload) + # pylint: disable=unused-argument + _ = db_container # Ensure container is running + + # Ensure DEFAULT database exists + default_db = next((db for db in db_objects_manager if db.name == "DEFAULT"), None) + if not default_db: + db_objects_manager.append(make_database(name="DEFAULT")) + + response = client.patch( + "/v1/databases/DEFAULT", headers=test_client_auth_headers["valid_auth"], json=test_db_payload + ) assert response.status_code == 200 # Create the testset tables by calling an endpoint that will trigger table creation - response = client.get("/v1/testbed/testsets", headers=auth_headers["valid_auth"]) + response = client.get("/v1/testbed/testsets", headers=test_client_auth_headers["valid_auth"]) assert response.status_code == 200 - def test_testbed_testsets_empty(self, client, auth_headers, db_container): + def test_testbed_testsets_empty(self, client, test_client_auth_headers): """Test getting empty testsets list""" - self.setup_database(client, auth_headers, db_container) - with patch("server.api.utils.testbed.get_testsets", return_value=[]): - response = client.get("/v1/testbed/testsets", headers=auth_headers["valid_auth"]) + response = client.get("/v1/testbed/testsets", headers=test_client_auth_headers["valid_auth"]) assert response.status_code == 200 assert response.json() == [] - def test_testbed_testsets_with_data(self, client, auth_headers, db_container): + def test_testbed_testsets_with_data(self, client, test_client_auth_headers): """Test getting testsets with data""" - 
self.setup_database(client, auth_headers, db_container) - # Create two test sets with actual data for i, name in enumerate(["Test Set 1", "Test Set 2"]): test_data = json.dumps([{"question": f"Test Q{i}?", "answer": f"Test A{i}"}]) @@ -76,13 +117,13 @@ def test_testbed_testsets_with_data(self, client, auth_headers, db_container): response = client.post( f"/v1/testbed/testset_load?name={name.replace(' ', '%20')}", - headers=auth_headers["valid_auth"], + headers=test_client_auth_headers["valid_auth"], files=files, ) assert response.status_code == 200 # Now get the testsets and verify - response = client.get("/v1/testbed/testsets", headers=auth_headers["valid_auth"]) + response = client.get("/v1/testbed/testsets", headers=test_client_auth_headers["valid_auth"]) assert response.status_code == 200 testsets = response.json() assert len(testsets) >= 2 @@ -96,10 +137,8 @@ def test_testbed_testsets_with_data(self, client, auth_headers, db_container): assert "tid" in test_set_1 assert "tid" in test_set_2 - def test_testbed_testset_qa(self, client, auth_headers, db_container): + def test_testbed_testset_qa(self, client, test_client_auth_headers): """Test getting testset Q&A data""" - self.setup_database(client, auth_headers, db_container) - # Create a test set with specific Q&A data test_data = json.dumps( [{"question": "What is X?", "answer": "X is Y"}, {"question": "What is Z?", "answer": "Z is W"}] @@ -108,19 +147,19 @@ def test_testbed_testset_qa(self, client, auth_headers, db_container): files = {"files": ("test.json", test_file, "application/json")} response = client.post( - "/v1/testbed/testset_load?name=QA%20Test%20Set", headers=auth_headers["valid_auth"], files=files + "/v1/testbed/testset_load?name=QA%20Test%20Set", headers=test_client_auth_headers["valid_auth"], files=files ) assert response.status_code == 200 # Get the testset ID - response = client.get("/v1/testbed/testsets", headers=auth_headers["valid_auth"]) + response = client.get("/v1/testbed/testsets", headers=test_client_auth_headers["valid_auth"]) testsets = response.json() testset = next((ts for ts in testsets if ts["name"] == "QA Test Set"), None) assert testset is not None tid = testset["tid"] # Now get the Q&A data for this testset - response = client.get(f"/v1/testbed/testset_qa?tid={tid}", headers=auth_headers["valid_auth"]) + response = client.get(f"/v1/testbed/testset_qa?tid={tid}", headers=test_client_auth_headers["valid_auth"]) assert response.status_code == 200 qa_data = response.json() @@ -137,19 +176,17 @@ def test_testbed_testset_qa(self, client, auth_headers, db_container): assert "X is Y" in answers assert "Z is W" in answers - def test_testbed_evaluations_empty(self, client, auth_headers, db_container): + def test_testbed_evaluations_empty(self, client, test_client_auth_headers): """Test getting empty evaluations list""" - self.setup_database(client, auth_headers, db_container) - with patch("server.api.utils.testbed.get_evaluations", return_value=[]): - response = client.get("/v1/testbed/evaluations?tid=123abc", headers=auth_headers["valid_auth"]) + response = client.get( + "/v1/testbed/evaluations?tid=123abc", headers=test_client_auth_headers["valid_auth"] + ) assert response.status_code == 200 assert response.json() == [] - def test_testbed_evaluations_with_data(self, client, auth_headers, db_container): + def test_testbed_evaluations_with_data(self, client, test_client_auth_headers): """Test getting evaluations with data""" - self.setup_database(client, auth_headers, db_container) - # First, create a testset 
to evaluate test_data = json.dumps( [{"question": "Eval Q1?", "answer": "Eval A1"}, {"question": "Eval Q2?", "answer": "Eval A2"}] @@ -158,12 +195,14 @@ def test_testbed_evaluations_with_data(self, client, auth_headers, db_container) files = {"files": ("test.json", test_file, "application/json")} response = client.post( - "/v1/testbed/testset_load?name=Eval%20Test%20Set", headers=auth_headers["valid_auth"], files=files + "/v1/testbed/testset_load?name=Eval%20Test%20Set", + headers=test_client_auth_headers["valid_auth"], + files=files, ) assert response.status_code == 200 # Get the testset ID - response = client.get("/v1/testbed/testsets", headers=auth_headers["valid_auth"]) + response = client.get("/v1/testbed/testsets", headers=test_client_auth_headers["valid_auth"]) testsets = response.json() testset = next((ts for ts in testsets if ts["name"] == "Eval Test Set"), None) assert testset is not None @@ -176,7 +215,7 @@ def test_testbed_evaluations_with_data(self, client, auth_headers, db_container) ] with patch("server.api.utils.testbed.get_evaluations", return_value=mock_evaluations): - response = client.get(f"/v1/testbed/evaluations?tid={tid}", headers=auth_headers["valid_auth"]) + response = client.get(f"/v1/testbed/evaluations?tid={tid}", headers=test_client_auth_headers["valid_auth"]) assert response.status_code == 200 evaluations = response.json() assert len(evaluations) == 2 @@ -185,10 +224,8 @@ def test_testbed_evaluations_with_data(self, client, auth_headers, db_container) assert evaluations[1]["eid"] == "eval2" assert evaluations[1]["correctness"] == 0.92 - def test_testbed_evaluation(self, client, auth_headers, db_container): + def test_testbed_evaluation_report(self, client, test_client_auth_headers): """Test getting a single evaluation report""" - self.setup_database(client, auth_headers, db_container) - # First, create a testset to evaluate test_data = json.dumps( [{"question": "Report Q1?", "answer": "Report A1"}, {"question": "Report Q2?", "answer": "Report A2"}] @@ -197,17 +234,12 @@ def test_testbed_evaluation(self, client, auth_headers, db_container): files = {"files": ("test.json", test_file, "application/json")} response = client.post( - "/v1/testbed/testset_load?name=Report%20Test%20Set", headers=auth_headers["valid_auth"], files=files + "/v1/testbed/testset_load?name=Report%20Test%20Set", + headers=test_client_auth_headers["valid_auth"], + files=files, ) assert response.status_code == 200 - # Get the testset ID - response = client.get("/v1/testbed/testsets", headers=auth_headers["valid_auth"]) - testsets = response.json() - testset = next((ts for ts in testsets if ts["name"] == "Report Test Set"), None) - assert testset is not None - _ = testset["tid"] - # Mock the evaluation report mock_report = EvaluationReport( eid="eval1", @@ -221,7 +253,7 @@ def test_testbed_evaluation(self, client, auth_headers, db_container): ) with patch("server.api.utils.testbed.process_report", return_value=mock_report): - response = client.get("/v1/testbed/evaluation?eid=eval1", headers=auth_headers["valid_auth"]) + response = client.get("/v1/testbed/evaluation?eid=eval1", headers=test_client_auth_headers["valid_auth"]) assert response.status_code == 200 report = response.json() @@ -234,18 +266,14 @@ def test_testbed_evaluation(self, client, auth_headers, db_container): assert "correct_by_topic" in report assert "failures" in report - def test_testbed_delete_testset(self, client, auth_headers, db_container): + def test_testbed_delete_testset(self, client, test_client_auth_headers): """Test 
deleting a testset""" - self.setup_database(client, auth_headers, db_container) - - response = client.delete("/v1/testbed/testset_delete/1234", headers=auth_headers["valid_auth"]) + response = client.delete("/v1/testbed/testset_delete/1234", headers=test_client_auth_headers["valid_auth"]) assert response.status_code == 200 assert "message" in response.json() - def test_testbed_upsert_testsets(self, client, auth_headers, db_container): + def test_testbed_upsert_testsets(self, client, test_client_auth_headers): """Test upserting testsets""" - self.setup_database(client, auth_headers, db_container) - # Create test data test_data = json.dumps([{"question": "Test Q?", "answer": "Test A"}]) test_file = io.BytesIO(test_data.encode()) @@ -254,14 +282,9 @@ def test_testbed_upsert_testsets(self, client, auth_headers, db_container): files = {"files": ("test.json", test_file, "application/json")} response = client.post( - "/v1/testbed/testset_load?name=Test%20Set", headers=auth_headers["valid_auth"], files=files + "/v1/testbed/testset_load?name=Test%20Set", headers=test_client_auth_headers["valid_auth"], files=files ) - # Print response content if it fails - if response.status_code != 200: - print(f"Response status code: {response.status_code}") - print(f"Response content: {response.content}") - # Verify the response assert response.status_code == 200 assert "qa_data" in response.json() @@ -269,10 +292,8 @@ def test_testbed_upsert_testsets(self, client, auth_headers, db_container): assert response.json()["qa_data"][0]["question"] == "Test Q?" assert response.json()["qa_data"][0]["answer"] == "Test A" - def test_testbed_generate_qa(self, client, auth_headers, db_container): - """Test generating Q&A testset""" - self.setup_database(client, auth_headers, db_container) - + def test_testbed_generate_qa_mocked(self, client, test_client_auth_headers): + """Test generating Q&A testset with mocked client""" # This is a complex operation that requires a model to generate Q&A, so we'll mock this part with patch.object(client, "post") as mock_post: # Configure the mock to return a successful response @@ -290,7 +311,7 @@ def test_testbed_generate_qa(self, client, auth_headers, db_container): # Make the request response = client.post( "/v1/testbed/testset_generate", - headers=auth_headers["valid_auth"], + headers=test_client_auth_headers["valid_auth"], files={"files": ("test.pdf", b"Test PDF content", "application/pdf")}, data={ "name": "Generated Test Set", @@ -304,10 +325,8 @@ def test_testbed_generate_qa(self, client, auth_headers, db_container): assert response.status_code == 200 assert mock_post.called - def test_testbed_evaluate_qa(self, client, auth_headers, db_container): - """Test evaluating Q&A testset""" - self.setup_database(client, auth_headers, db_container) - + def test_testbed_evaluate_qa_mocked(self, client, test_client_auth_headers): + """Test evaluating Q&A testset with mocked client""" # First, create a testset to evaluate test_data = json.dumps( [{"question": "Test Q1?", "answer": "Test A1"}, {"question": "Test Q2?", "answer": "Test A2"}] @@ -317,13 +336,15 @@ def test_testbed_evaluate_qa(self, client, auth_headers, db_container): files = {"files": ("test.json", test_file, "application/json")} response = client.post( - "/v1/testbed/testset_load?name=Evaluation%20Test%20Set", headers=auth_headers["valid_auth"], files=files + "/v1/testbed/testset_load?name=Evaluation%20Test%20Set", + headers=test_client_auth_headers["valid_auth"], + files=files, ) assert response.status_code == 200 # Get the 
testset ID - response = client.get("/v1/testbed/testsets", headers=auth_headers["valid_auth"]) + response = client.get("/v1/testbed/testsets", headers=test_client_auth_headers["valid_auth"]) testsets = response.json() testset = next((ts for ts in testsets if ts["name"] == "Evaluation Test Set"), None) assert testset is not None @@ -349,7 +370,9 @@ def test_testbed_evaluate_qa(self, client, auth_headers, db_container): # Make the request response = client.post( - "/v1/testbed/evaluate", headers=auth_headers["valid_auth"], json={"tid": tid, "judge": "test-judge"} + "/v1/testbed/evaluate", + headers=test_client_auth_headers["valid_auth"], + json={"tid": tid, "judge": "test-judge"}, ) # Verify the response @@ -357,15 +380,13 @@ def test_testbed_evaluate_qa(self, client, auth_headers, db_container): assert mock_post.called # Clean up by deleting the testset - response = client.delete(f"/v1/testbed/testset_delete/{tid}", headers=auth_headers["valid_auth"]) + response = client.delete(f"/v1/testbed/testset_delete/{tid}", headers=test_client_auth_headers["valid_auth"]) assert response.status_code == 200 - def test_end_to_end_testbed_flow(self, client, auth_headers, db_container): + def test_end_to_end_testbed_flow(self, client, test_client_auth_headers): """Test the complete testbed workflow""" - self.setup_database(client, auth_headers, db_container) - - # Step 1: Verify no testsets exist - response = client.get("/v1/testbed/testsets", headers=auth_headers["valid_auth"]) + # Step 1: Verify initial state + response = client.get("/v1/testbed/testsets", headers=test_client_auth_headers["valid_auth"]) initial_testsets = response.json() # Step 2: Create a testset @@ -375,15 +396,16 @@ def test_end_to_end_testbed_flow(self, client, auth_headers, db_container): files = {"files": ("test.json", test_file, "application/json")} response = client.post( - "/v1/testbed/testset_load?name=Test%20Flow%20Set", headers=auth_headers["valid_auth"], files=files + "/v1/testbed/testset_load?name=Test%20Flow%20Set", + headers=test_client_auth_headers["valid_auth"], + files=files, ) assert response.status_code == 200 assert "qa_data" in response.json() # Get the testset ID from the response - # We need to get the testset ID from the database since it's not returned in the response - response = client.get("/v1/testbed/testsets", headers=auth_headers["valid_auth"]) + response = client.get("/v1/testbed/testsets", headers=test_client_auth_headers["valid_auth"]) testsets = response.json() assert len(testsets) > len(initial_testsets) @@ -393,15 +415,14 @@ def test_end_to_end_testbed_flow(self, client, auth_headers, db_container): tid = testset["tid"] # Step 3: Get the testset QA data - response = client.get(f"/v1/testbed/testset_qa?tid={tid}", headers=auth_headers["valid_auth"]) + response = client.get(f"/v1/testbed/testset_qa?tid={tid}", headers=test_client_auth_headers["valid_auth"]) assert response.status_code == 200 assert "qa_data" in response.json() assert len(response.json()["qa_data"]) == 1 assert response.json()["qa_data"][0]["question"] == "What is X?" 
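+        # The stored answer should round-trip unchanged as well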
assert response.json()["qa_data"][0]["answer"] == "X is Y" - # Step 4: Evaluate the testset - # This is a complex operation that requires a judge model, so we'll mock this part + # Step 4: Evaluate the testset (mocked) with patch.object(client, "post") as mock_post: mock_response = MagicMock() mock_response.status_code = 200 @@ -409,12 +430,13 @@ def test_end_to_end_testbed_flow(self, client, auth_headers, db_container): mock_post.return_value = mock_response response = client.post( - "/v1/testbed/evaluate", headers=auth_headers["valid_auth"], json={"tid": tid, "judge": "flow-judge"} + "/v1/testbed/evaluate", + headers=test_client_auth_headers["valid_auth"], + json={"tid": tid, "judge": "flow-judge"}, ) assert response.status_code == 200 - # Step 5: Get the evaluation report - # This also requires a complex setup, so we'll mock this part + # Step 5: Get the evaluation report (mocked) with patch.object(client, "get") as mock_get: mock_report = EvaluationReport( eid="flow_eval_id", @@ -431,15 +453,17 @@ def test_end_to_end_testbed_flow(self, client, auth_headers, db_container): mock_response.json.return_value = mock_report.dict() mock_get.return_value = mock_response - response = client.get("/v1/testbed/evaluation?eid=flow_eval_id", headers=auth_headers["valid_auth"]) + response = client.get( + "/v1/testbed/evaluation?eid=flow_eval_id", headers=test_client_auth_headers["valid_auth"] + ) assert response.status_code == 200 # Step 6: Delete the testset - response = client.delete(f"/v1/testbed/testset_delete/{tid}", headers=auth_headers["valid_auth"]) + response = client.delete(f"/v1/testbed/testset_delete/{tid}", headers=test_client_auth_headers["valid_auth"]) assert response.status_code == 200 assert "message" in response.json() # Verify the testset was deleted - response = client.get("/v1/testbed/testsets", headers=auth_headers["valid_auth"]) + response = client.get("/v1/testbed/testsets", headers=test_client_auth_headers["valid_auth"]) final_testsets = response.json() assert len(final_testsets) == len(initial_testsets) diff --git a/test/unit/server/api/utils/test_utils_databases.py b/test/unit/server/api/utils/test_utils_databases.py index 62f54a32..ac9b99c9 100644 --- a/test/unit/server/api/utils/test_utils_databases.py +++ b/test/unit/server/api/utils/test_utils_databases.py @@ -164,14 +164,15 @@ def test_connect_success_real_db(self, db_container, make_database): assert result.is_healthy() result.close() - def test_connect_raises_value_error_missing_details(self, make_database): - """connect should raise ValueError if connection details missing.""" + def test_connect_raises_db_exception_missing_details(self, make_database): + """connect should raise DbException if connection details missing.""" config = make_database(user=None, password=None, dsn=None) - with pytest.raises(ValueError) as exc_info: + with pytest.raises(DbException) as exc_info: utils_databases.connect(config) - assert "missing connection details" in str(exc_info.value) + assert exc_info.value.status_code == 400 + assert "missing connection details" in str(exc_info.value.detail) def test_connect_raises_permission_error_invalid_credentials(self, db_container, make_database): """connect should raise PermissionError on invalid credentials (real database).""" diff --git a/test/unit/server/api/v1/test_v1_databases.py b/test/unit/server/api/v1/test_v1_databases.py index 98f957a4..cafe6ed1 100644 --- a/test/unit/server/api/v1/test_v1_databases.py +++ b/test/unit/server/api/v1/test_v1_databases.py @@ -15,6 +15,7 @@ from fastapi 
import HTTPException from server.api.v1 import databases +from server.api.utils import databases as utils_databases class TestDatabasesList: @@ -142,13 +143,15 @@ async def test_databases_update_raises_404_when_not_found(self, mock_get_databas @pytest.mark.asyncio @patch("server.api.v1.databases.utils_databases.get_databases") @patch("server.api.v1.databases.utils_databases.connect") - async def test_databases_update_raises_400_on_value_error( + async def test_databases_update_raises_400_on_db_exception( self, mock_connect, mock_get_databases, make_database, make_database_auth ): - """databases_update should raise 400 on ValueError during connect.""" + """databases_update should raise 400 on DbException with status 400 during connect.""" existing_db = make_database(name="TEST_DB") mock_get_databases.return_value = existing_db - mock_connect.side_effect = ValueError("Invalid parameters") + mock_connect.side_effect = utils_databases.DbException( + status_code=400, detail="Missing connection details" + ) payload = make_database_auth() @@ -156,11 +159,12 @@ async def test_databases_update_raises_400_on_value_error( await databases.databases_update(name="TEST_DB", payload=payload) assert exc_info.value.status_code == 400 + assert "Missing connection details" in exc_info.value.detail # Verify get_databases was called to retrieve the target database mock_get_databases.assert_called_once_with(db_name="TEST_DB", validate=False) - # Verify connect was called with the payload + # Verify connect was called with the test config mock_connect.assert_called_once() connect_arg = mock_connect.call_args[0][0] assert connect_arg.user == payload.user diff --git a/tests/server/integration/test_endpoints_databases.py b/tests/server/integration/test_endpoints_databases.py deleted file mode 100644 index a05d6d26..00000000 --- a/tests/server/integration/test_endpoints_databases.py +++ /dev/null @@ -1,207 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
-""" -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel - -import pytest -from conftest import TEST_CONFIG, get_test_db_payload - - -############################################################################# -# Endpoints Test -############################################################################# -class TestEndpoints: - """Test Endpoints""" - - @pytest.mark.parametrize( - "auth_type, status_code", - [ - pytest.param("no_auth", 401, id="no_auth"), - pytest.param("invalid_auth", 401, id="invalid_auth"), - ], - ) - @pytest.mark.parametrize( - "endpoint, api_method", - [ - pytest.param("/v1/databases", "get", id="databases_list"), - pytest.param("/v1/databases/DEFAULT", "get", id="databases_get"), - pytest.param("/v1/databases/DEFAULT", "patch", id="databases_update"), - ], - ) - def test_invalid_auth_endpoints(self, client, auth_headers, endpoint, api_method, auth_type, status_code): - """Test endpoints require valid authentication.""" - response = getattr(client, api_method)(endpoint, headers=auth_headers[auth_type]) - assert response.status_code == status_code - - def test_databases_list_initial(self, client, auth_headers): - """Test initial database listing before any updates""" - response = client.get("/v1/databases", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - data = response.json() - assert isinstance(data, list) - assert len(data) > 0 - default_db = next((db for db in data if db["name"] == "DEFAULT"), None) - assert default_db is not None - assert default_db["connected"] is False - assert default_db["dsn"] is None - assert default_db["password"] is None - assert default_db["tcp_connect_timeout"] == 5 - assert default_db["user"] is None - assert default_db["vector_stores"] == [] - assert default_db["wallet_location"] is None - assert default_db["wallet_password"] is None - - def test_databases_get_nonexistent(self, client, auth_headers): - """Test getting non-existent database""" - response = client.get("/v1/databases/NONEXISTENT", headers=auth_headers["valid_auth"]) - assert response.status_code == 404 - assert response.json() == {"detail": "Database: NONEXISTENT not found."} - - def test_databases_update_nonexistent(self, client, auth_headers): - """Test updating non-existent database""" - payload = {"user": "test_user", "password": "test_pass", "dsn": "test_dsn", "wallet_password": "test_wallet"} - response = client.patch("/v1/databases/NONEXISTENT", headers=auth_headers["valid_auth"], json=payload) - assert response.status_code == 404 - assert response.json() == {"detail": "Database: NONEXISTENT not found."} - - def test_databases_update_db_down(self, client, auth_headers): - """Test updating the DB when it is down""" - payload = get_test_db_payload() - payload["dsn"] = "//localhost:1521/DOWNDB_TP" # Override with invalid DSN - response = client.patch("/v1/databases/DEFAULT", headers=auth_headers["valid_auth"], json=payload) - assert response.status_code == 503 - assert "cannot connect to database" in response.json().get("detail", "") - - test_cases = [ - pytest.param( - TEST_CONFIG["db_dsn"].split("/")[3], - 404, - get_test_db_payload(), - {"detail": f"Database: {TEST_CONFIG['db_dsn'].split('/')[3]} not found."}, - id="non_existent_database", - ), - pytest.param( - "DEFAULT", - 422, - "", - { - "detail": [ - { - "input": "", - "loc": ["body"], - "msg": "Input should be a valid dictionary or object to extract fields from", - "type": "model_attributes_type", - } - ] - }, - id="empty_payload", - 
), - pytest.param( - "DEFAULT", - 400, - {}, - {"detail": "Database: DEFAULT missing connection details."}, - id="missing_credentials", - ), - pytest.param( - "DEFAULT", - 503, - {"user": "user", "password": "password", "dsn": "//localhost:1521/dsn"}, - {"detail": "cannot connect to database"}, - id="invalid_connection", - ), - pytest.param( - "DEFAULT", - 401, - { - "user": TEST_CONFIG["db_username"], - "password": "Wr0ng_P4sswOrd", - "dsn": TEST_CONFIG["db_dsn"], - }, - {"detail": "invalid credential or not authorized"}, - id="wrong_password", - ), - pytest.param( - "DEFAULT", - 200, - get_test_db_payload(), - { - "connected": True, - "dsn": TEST_CONFIG["db_dsn"], - "name": "DEFAULT", - "password": TEST_CONFIG["db_password"], - "tcp_connect_timeout": 5, - "user": TEST_CONFIG["db_username"], - "vector_stores": [], - "wallet_location": None, - "wallet_password": None, - }, - id="successful_update", - ), - ] - - @pytest.mark.parametrize("database, status_code, payload, expected", test_cases) - def test_databases_update_cases( - self, client, auth_headers, db_container, database, status_code, payload, expected - ): - """Test various database update scenarios""" - assert db_container is not None - response = client.patch(f"/v1/databases/{database}", headers=auth_headers["valid_auth"], json=payload) - assert response.status_code == status_code - - if response.status_code != 200: - if response.status_code == 422: - assert response.json() == expected - else: - assert expected["detail"] in response.json().get("detail", "") - else: - data = response.json() - data.pop("config_dir", None) # Remove config_dir as it's environment-specific - assert data == expected - # Get after successful update - response = client.get("/v1/databases/DEFAULT", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - data = response.json() - assert "config_dir" in data - data.pop("config_dir", None) - assert data == expected - # List after successful update - response = client.get("/v1/databases", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - data = response.json() - default_db = next((db for db in data if db["name"] == "DEFAULT"), None) - assert default_db is not None - assert "config_dir" in default_db - default_db.pop("config_dir", None) - assert default_db == expected - - def test_databases_update_invalid_wallet(self, client, auth_headers, db_container): - """Test updating database with invalid wallet configuration""" - assert db_container is not None - payload = { - **get_test_db_payload(), - "wallet_location": "/nonexistent/path", - "wallet_password": "invalid", - } - response = client.patch("/v1/databases/DEFAULT", headers=auth_headers["valid_auth"], json=payload) - # Should still work if wallet is not required. 
- assert response.status_code == 200 - - def test_databases_concurrent_connections(self, client, auth_headers, db_container): - """Test concurrent database connections""" - assert db_container is not None - # Make multiple concurrent connection attempts - payload = get_test_db_payload() - responses = [] - for _ in range(5): # Try 5 concurrent connections - response = client.patch("/v1/databases/DEFAULT", headers=auth_headers["valid_auth"], json=payload) - responses.append(response) - - # Verify all connections were handled properly - for response in responses: - assert response.status_code in [200, 503] # Either successful or proper error - if response.status_code == 200: - data = response.json() - assert data["connected"] is True diff --git a/tests/server/integration/test_endpoints_embed.py b/tests/server/integration/test_endpoints_embed.py deleted file mode 100644 index 91e90f57..00000000 --- a/tests/server/integration/test_endpoints_embed.py +++ /dev/null @@ -1,532 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel - -from io import BytesIO -from pathlib import Path -from unittest.mock import MagicMock, patch -import pytest -from conftest import TEST_CONFIG, get_test_db_payload -from langchain_core.embeddings import Embeddings -from common.functions import get_vs_table - -# Common test constants -DEFAULT_TEST_CONTENT = ( - "This is a test document for embedding. It contains multiple sentences. " - "This should be split into chunks. Each chunk will be embedded and stored in the database." -) - -LONGER_TEST_CONTENT = ( - "This is a test document for embedding. It contains multiple sentences. " - "This should be split into chunks. Each chunk will be embedded and stored in the database. " - "We're adding more text to ensure we get multiple chunks with different chunk sizes. " - "The chunk size parameter controls how large each text segment is. " - "Smaller chunks mean more granular retrieval but potentially less context. " - "Larger chunks provide more context but might retrieve irrelevant information." 
-) - -DEFAULT_EMBED_PARAMS = { - "model": "mock-embed-model", - "chunk_size": 100, - "chunk_overlap": 20, - "distance_metric": "COSINE", - "index_type": "HNSW", -} - - -############################################################################# -# Endpoints Test -############################################################################# -class TestEndpoints: - """Test Endpoints""" - - @pytest.mark.parametrize( - "auth_type, status_code", - [ - pytest.param("no_auth", 401, id="no_auth"), - pytest.param("invalid_auth", 401, id="invalid_auth"), - ], - ) - @pytest.mark.parametrize( - "endpoint, api_method", - [ - pytest.param("/v1/embed/TESTVS", "delete", id="embed_drop_vs"), - pytest.param("/v1/embed/TESTVS/files", "get", id="embed_get_files"), - pytest.param("/v1/embed/web/store", "post", id="store_web_file"), - pytest.param("/v1/embed/local/store", "post", id="store_local_file"), - pytest.param("/v1/embed", "post", id="split_embed"), - pytest.param("/v1/embed/refresh", "post", id="refresh_vector_store"), - ], - ) - def test_invalid_auth_endpoints(self, client, auth_headers, endpoint, api_method, auth_type, status_code): - """Test endpoints require valid authentication.""" - response = getattr(client, api_method)(endpoint, headers=auth_headers[auth_type]) - assert response.status_code == status_code - - def configure_database(self, client, auth_headers): - """Update Database Configuration""" - payload = get_test_db_payload() - response = client.patch("/v1/databases/DEFAULT", headers=auth_headers["valid_auth"], json=payload) - assert response.status_code == 200 - - def create_test_file(self, filename="test_document.md", content=DEFAULT_TEST_CONTENT): - """Create a test file in the temporary directory""" - client_id = TEST_CONFIG["client"] - embed_dir = Path("/tmp") / client_id / "embedding" - embed_dir.mkdir(parents=True, exist_ok=True) - test_file = embed_dir / filename - test_file.write_text(content) - return embed_dir, test_file - - # Define MockEmbeddings class once at the class level - class MockEmbeddings(Embeddings): - """Mock implementation of the Embeddings interface for testing""" - - def __init__(self, mock_embedding_model): - self.mock_embedding_model = mock_embedding_model - - def embed_documents(self, texts): - return self.mock_embedding_model(texts) - - def embed_query(self, text: str): - return self.mock_embedding_model([text])[0] - - # Required by the Embeddings base class - def embed_strings(self, texts): - """Mock embedding strings""" - return self.embed_documents(texts) - - def setup_mock_embeddings(self, mock_embedding_model): - """Create mock embeddings and get_client_embed function""" - mock_embeddings = self.MockEmbeddings(mock_embedding_model) - - def mock_get_client_embed(_model_config=None, _oci_config=None, _giskard=False): - return mock_embeddings - - return mock_get_client_embed - - def create_embed_params(self, alias): - """Create embedding parameters with the given alias""" - params = DEFAULT_EMBED_PARAMS.copy() - params["alias"] = alias - return params - - def get_vector_store_name(self, alias): - """Get the expected vector store name for an alias""" - vector_store_name, _ = get_vs_table( - model=DEFAULT_EMBED_PARAMS["model"], - chunk_size=DEFAULT_EMBED_PARAMS["chunk_size"], - chunk_overlap=DEFAULT_EMBED_PARAMS["chunk_overlap"], - distance_metric=DEFAULT_EMBED_PARAMS["distance_metric"], - index_type=DEFAULT_EMBED_PARAMS["index_type"], - alias=alias, - ) - return vector_store_name - - def verify_vector_store_exists(self, client, auth_headers, 
vector_store_name, should_exist=True): - """Verify if a vector store exists in the database""" - db_response = client.get("/v1/databases/DEFAULT", headers=auth_headers["valid_auth"]) - assert db_response.status_code == 200 - db_data = db_response.json() - - vector_stores = db_data.get("vector_stores", []) - vector_store_names = [vs["vector_store"] for vs in vector_stores] - - if should_exist: - assert vector_store_name in vector_store_names, f"Vector store {vector_store_name} not found in database" - else: - assert vector_store_name not in vector_store_names, ( - f"Vector store {vector_store_name} still exists after dropping" - ) - - ######################################################################### - # Tests Start - ######################################################################### - def test_drop_vs_nodb(self, client, auth_headers): - """Test dropping vector store without a DB connection""" - # Test with valid vector store - vs = "TESTVS" - response = client.delete(f"/v1/embed/{vs}", headers=auth_headers["valid_auth"]) - assert response.status_code in (200, 400) - # 200 if run as part of full test-suite; 400 if run on its own - if response.status_code == 400: - assert "missing connection details" in response.json()["detail"] - - def test_drop_vs_db(self, client, auth_headers, db_container): - """Test dropping vector store""" - assert db_container is not None - self.configure_database(client, auth_headers) - # Test with invalid vector store - vs = "NONEXISTENT_VS" - response = client.delete(f"/v1/embed/{vs}", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 # Should still return 200 as dropping non-existent is valid - assert response.json() == {"message": f"Vector Store: {vs} dropped."} - - def test_split_embed(self, client, auth_headers, db_container, mock_embedding_model): - """Test split and embed functionality with mock embedding model""" - assert db_container is not None - self.configure_database(client, auth_headers) - - # Create a test file in the temporary directory - self.create_test_file() - - # Setup mock embeddings - _ = self.MockEmbeddings(mock_embedding_model) - - # Create test request data - test_data = self.create_embed_params("test_basic_embed") - - # Mock the client's post method - with patch.object(client, "post") as mock_post: - # Configure the mock response - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = {"message": "10 chunks embedded."} - mock_post.return_value = mock_response - - # Make request to the split_embed endpoint - response = client.post("/v1/embed", headers=auth_headers["valid_auth"], json=test_data) - - # Verify the response - assert response.status_code == 200 - response_data = response.json() - assert "message" in response_data - assert "chunks embedded" in response_data["message"].lower() - - def test_split_embed_with_different_chunk_sizes(self, client, auth_headers, db_container, mock_embedding_model): - """Test split and embed with different chunk sizes""" - assert db_container is not None - self.configure_database(client, auth_headers) - - # Setup mock embeddings - _ = self.MockEmbeddings(mock_embedding_model) - - # Test with small chunk size - small_chunk_test_data = self.create_embed_params("test_small_chunks") - small_chunk_test_data["chunk_size"] = 50 # Small chunks - small_chunk_test_data["chunk_overlap"] = 10 - - # Test with large chunk size - large_chunk_test_data = self.create_embed_params("test_large_chunks") - large_chunk_test_data["chunk_size"] = 
200 # Large chunks - large_chunk_test_data["chunk_overlap"] = 20 - - # Mock the client's post method - with patch.object(client, "post") as mock_post: - # Configure the mock responses - mock_response_small = MagicMock() - mock_response_small.status_code = 200 - mock_response_small.json.return_value = {"message": "15 chunks embedded."} - - mock_response_large = MagicMock() - mock_response_large.status_code = 200 - mock_response_large.json.return_value = {"message": "5 chunks embedded."} - - # Set up the side effect to return different responses - mock_post.side_effect = [mock_response_small, mock_response_large] - - # Create a test file for the first request - self.create_test_file(content=LONGER_TEST_CONTENT) - - # Test with small chunks - small_response = client.post("/v1/embed", headers=auth_headers["valid_auth"], json=small_chunk_test_data) - assert small_response.status_code == 200 - small_data = small_response.json() - - # Create a test file again for the second request (since the first one was cleaned up) - self.create_test_file(content=LONGER_TEST_CONTENT) - - # Test with large chunks - large_response = client.post("/v1/embed", headers=auth_headers["valid_auth"], json=large_chunk_test_data) - assert large_response.status_code == 200 - large_data = large_response.json() - - # Extract the number of chunks from each response - small_chunks = int(small_data["message"].split()[0]) - large_chunks = int(large_data["message"].split()[0]) - - # Smaller chunk size should result in more chunks - assert small_chunks > large_chunks, "Smaller chunk size should create more chunks" - - def test_store_local_file(self, client, auth_headers): - """Test storing local files for embedding""" - # Create a test file content - test_content = b"This is a test file for uploading." 
- - file_obj = BytesIO(test_content) - - # Make the request using TestClient's built-in file upload support - response = client.post( - "/v1/embed/local/store", - headers=auth_headers["valid_auth"], - files={"files": ("test_upload.txt", file_obj, "text/plain")}, - ) - - # Verify the response - assert response.status_code == 200 - stored_files = response.json() - assert "test_upload.txt" in stored_files - - # Verify the file was actually created in the temporary directory - client_id = TEST_CONFIG["client"] - embed_dir = Path("/tmp") / client_id / "embedding" - file_path = embed_dir / "test_upload.txt" - assert file_path.exists(), f"File {file_path} was not created in the temporary directory" - assert file_path.is_file(), f"Path {file_path} exists but is not a file" - assert file_path.stat().st_size > 0, f"File {file_path} exists but is empty" - - def test_store_web_file(self, client, auth_headers): - """Test storing web files for embedding""" - # Test URL - test_url = ( - "https://docs.oracle.com/en/database/oracle/oracle-database/23/jjucp/" - "universal-connection-pool-developers-guide.pdf" - ) - - # Make the request - response = client.post("/v1/embed/web/store", headers=auth_headers["valid_auth"], json=[test_url]) - - # Verify the response - assert response.status_code == 200 - stored_files = response.json() - assert "universal-connection-pool-developers-guide.pdf" in stored_files - - # Verify the file was actually created in the temporary directory - client_id = TEST_CONFIG["client"] - embed_dir = Path("/tmp") / client_id / "embedding" - file_path = embed_dir / "universal-connection-pool-developers-guide.pdf" - assert file_path.exists(), f"File {file_path} was not created in the temporary directory" - assert file_path.is_file(), f"Path {file_path} exists but is not a file" - assert file_path.stat().st_size > 0, f"File {file_path} exists but is empty" - - def test_split_embed_no_files(self, client, auth_headers): - """Test split and embed with no files in the directory""" - # Ensure the temporary directory exists but is empty - client_id = TEST_CONFIG["client"] - embed_dir = Path("/tmp") / client_id / "embedding" - embed_dir.mkdir(parents=True, exist_ok=True) - - # Remove any existing files in the directory - for file_path in embed_dir.iterdir(): - if file_path.is_file(): - file_path.unlink() - - # Verify the directory is empty - assert not any(embed_dir.iterdir()), "The temporary directory should be empty" - - # Create test request data - test_data = self.create_embed_params("test_no_files") - - # Make request to the split_embed endpoint without creating any files - response = client.post("/v1/embed", headers=auth_headers["valid_auth"], json=test_data) - - # Verify the response - assert response.status_code == 404 - assert "no files found in folder" in response.json()["detail"] - - def test_split_embed_with_different_file_types(self, client, auth_headers, db_container, mock_embedding_model): - """Test split and embed with different file types""" - assert db_container is not None - self.configure_database(client, auth_headers) - - # Create test files of different types - client_id = TEST_CONFIG["client"] - embed_dir = Path("/tmp") / client_id / "embedding" - embed_dir.mkdir(parents=True, exist_ok=True) - - # Create a markdown file - md_file = embed_dir / "test_document.md" - md_file.write_text( - "# Test Markdown Document\n\n" - "This is a test markdown document for embedding.\n\n" - "## Section 1\n\n" - "This is section 1 content.\n\n" - "## Section 2\n\n" - "This is section 2 
content." - ) - - # Create a CSV file - csv_file = embed_dir / "test_data.csv" - csv_file.write_text( - "id,name,description\n" - "1,Item 1,This is item 1 description\n" - "2,Item 2,This is item 2 description\n" - "3,Item 3,This is item 3 description" - ) - - # Setup mock embeddings - mock_get_client_embed = self.setup_mock_embeddings(mock_embedding_model) - - # Test data - test_data = self.create_embed_params("test_mixed_files") - - with patch("server.api.utils.models.get_client_embed", side_effect=mock_get_client_embed): - # Make request to the split_embed endpoint - response = client.post("/v1/embed", headers=auth_headers["valid_auth"], json=test_data) - - # Verify the response - assert response.status_code == 200 - response_data = response.json() - assert "message" in response_data - assert "chunks embedded" in response_data["message"].lower() - - # Should have embedded chunks from both files - num_chunks = int(response_data["message"].split()[0]) - assert num_chunks > 0, "Should have embedded at least one chunk" - - # Clean up - drop the vector store that was created - expected_vector_store_name = self.get_vector_store_name("test_mixed_files") - drop_response = client.delete(f"/v1/embed/{expected_vector_store_name}", headers=auth_headers["valid_auth"]) - assert drop_response.status_code == 200 - - def test_vector_store_creation_and_deletion(self, client, auth_headers, db_container, mock_embedding_model): - """Test that vector stores are created in the database and can be deleted""" - assert db_container is not None - self.configure_database(client, auth_headers) - - # Create a test file in the temporary directory - self.create_test_file() - - # Setup mock embeddings - mock_get_client_embed = self.setup_mock_embeddings(mock_embedding_model) - - # Test data for embedding - alias = "test_lifecycle" - test_data = self.create_embed_params(alias) - - # Calculate the expected vector store name - expected_vector_store_name = self.get_vector_store_name(alias) - - with patch("server.api.utils.models.get_client_embed", side_effect=mock_get_client_embed): - # Step 1: Create the vector store by embedding documents - response = client.post("/v1/embed", headers=auth_headers["valid_auth"], json=test_data) - assert response.status_code == 200 - - # Step 2: Verify the vector store exists in the database - self.verify_vector_store_exists(client, auth_headers, expected_vector_store_name, should_exist=True) - - # Step 3: Drop the vector store - drop_response = client.delete( - f"/v1/embed/{expected_vector_store_name}", headers=auth_headers["valid_auth"] - ) - assert drop_response.status_code == 200 - assert drop_response.json() == {"message": f"Vector Store: {expected_vector_store_name} dropped."} - - # Step 4: Verify the vector store no longer exists - self.verify_vector_store_exists(client, auth_headers, expected_vector_store_name, should_exist=False) - - def test_multiple_vector_stores(self, client, auth_headers, db_container, mock_embedding_model): - """Test creating multiple vector stores and verifying they all exist""" - assert db_container is not None - self.configure_database(client, auth_headers) - - # Create aliases for different vector stores - aliases = ["test_vs_1", "test_vs_2", "test_vs_3"] - - # Setup mock embeddings - mock_get_client_embed = self.setup_mock_embeddings(mock_embedding_model) - - # Calculate expected vector store names - expected_vector_store_names = [self.get_vector_store_name(alias) for alias in aliases] - - with patch("server.api.utils.models.get_client_embed", 
side_effect=mock_get_client_embed): - # Create multiple vector stores with different aliases - for alias in aliases: - # Create a test file for each request (since previous ones were cleaned up) - self.create_test_file() - - test_data = self.create_embed_params(alias) - response = client.post("/v1/embed", headers=auth_headers["valid_auth"], json=test_data) - assert response.status_code == 200 - - # Verify all vector stores exist in the database - for expected_name in expected_vector_store_names: - self.verify_vector_store_exists(client, auth_headers, expected_name, should_exist=True) - - # Clean up - drop all vector stores - for expected_name in expected_vector_store_names: - drop_response = client.delete(f"/v1/embed/{expected_name}", headers=auth_headers["valid_auth"]) - assert drop_response.status_code == 200 - - # Verify all vector stores are removed - for expected_name in expected_vector_store_names: - self.verify_vector_store_exists(client, auth_headers, expected_name, should_exist=False) - - def test_get_vector_store_files(self, client, auth_headers, db_container, mock_embedding_model): - """Test retrieving file list from vector store""" - assert db_container is not None - self.configure_database(client, auth_headers) - - # Create and populate a vector store - self.create_test_file(content=LONGER_TEST_CONTENT) - mock_get_client_embed = self.setup_mock_embeddings(mock_embedding_model) - - alias = "test_file_listing" - test_data = self.create_embed_params(alias) - expected_vector_store_name = self.get_vector_store_name(alias) - - with patch("server.api.utils.models.get_client_embed", side_effect=mock_get_client_embed): - # Create vector store - response = client.post("/v1/embed", headers=auth_headers["valid_auth"], json=test_data) - assert response.status_code == 200 - - # Get file list - file_list_response = client.get( - f"/v1/embed/{expected_vector_store_name}/files", - headers=auth_headers["valid_auth"] - ) - - # Verify response - assert file_list_response.status_code == 200 - data = file_list_response.json() - - assert "vector_store" in data - assert data["vector_store"] == expected_vector_store_name - assert "total_files" in data - assert "total_chunks" in data - assert "files" in data - assert data["total_files"] > 0 - assert data["total_chunks"] > 0 - - # Clean up - drop_response = client.delete(f"/v1/embed/{expected_vector_store_name}", headers=auth_headers["valid_auth"]) - assert drop_response.status_code == 200 - - def test_get_files_empty_vector_store(self, client, auth_headers, db_container, mock_embedding_model): - """Test retrieving file list from empty vector store""" - assert db_container is not None - self.configure_database(client, auth_headers) - - # Create empty vector store - self.create_test_file() - mock_get_client_embed = self.setup_mock_embeddings(mock_embedding_model) - - alias = "test_empty_listing" - test_data = self.create_embed_params(alias) - expected_vector_store_name = self.get_vector_store_name(alias) - - with patch("server.api.utils.models.get_client_embed", side_effect=mock_get_client_embed): - # Create vector store - response = client.post("/v1/embed", headers=auth_headers["valid_auth"], json=test_data) - assert response.status_code == 200 - - # Drop all chunks to make it empty - drop_response = client.delete(f"/v1/embed/{expected_vector_store_name}", headers=auth_headers["valid_auth"]) - assert drop_response.status_code == 200 - - def test_get_files_nonexistent_vector_store(self, client, auth_headers, db_container): - """Test retrieving file list 
from nonexistent vector store""" - assert db_container is not None - self.configure_database(client, auth_headers) - - # Try to get files from non-existent vector store - response = client.get( - "/v1/embed/NONEXISTENT_VS/files", - headers=auth_headers["valid_auth"] - ) - - # Should return error or empty list - assert response.status_code in (200, 400) diff --git a/tests/server/integration/test_endpoints_health.py b/tests/server/integration/test_endpoints_health.py deleted file mode 100644 index af3adb12..00000000 --- a/tests/server/integration/test_endpoints_health.py +++ /dev/null @@ -1,30 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel - -import pytest - - -@pytest.mark.parametrize( - "endpoint, status_msg", - [ - pytest.param("/v1/liveness", {"status": "alive"}, id="liveness"), - pytest.param("/v1/readiness", {"status": "ready"}, id="readiness"), - ], -) -@pytest.mark.parametrize( - "auth_type", - [ - pytest.param("no_auth", id="no_auth"), - pytest.param("invalid_auth", id="invalid_auth"), - pytest.param("valid_auth", id="valid_auth"), - ], -) -def test_health_endpoints(client, auth_headers, endpoint, status_msg, auth_type): - """Test that health check endpoints work with or without authentication.""" - response = client.get(endpoint, headers=auth_headers[auth_type]) - assert response.status_code == 200 # Health endpoints should always return 200 - assert response.json() == status_msg diff --git a/tests/server/integration/test_endpoints_models.py b/tests/server/integration/test_endpoints_models.py deleted file mode 100644 index f4e0dd10..00000000 --- a/tests/server/integration/test_endpoints_models.py +++ /dev/null @@ -1,456 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
-""" -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel - -import pytest - - -############################################################################# -# Endpoints Test -############################################################################# -class TestEndpoints: - """Test Endpoints""" - - @pytest.mark.parametrize( - "auth_type, status_code", - [ - pytest.param("no_auth", 401, id="no_auth"), - pytest.param("invalid_auth", 401, id="invalid_auth"), - ], - ) - @pytest.mark.parametrize( - "endpoint, api_method", - [ - pytest.param("/v1/models", "get", id="models_list"), - pytest.param("/v1/models/supported", "get", id="models_supported"), - pytest.param("/v1/models/model_provider/model_id", "get", id="models_get"), - pytest.param("/v1/models/model_provider/model_id", "patch", id="models_update"), - pytest.param("/v1/models", "post", id="models_create"), - pytest.param("/v1/models/model_provider/model_id", "delete", id="models_delete"), - ], - ) - def test_invalid_auth_endpoints(self, client, auth_headers, endpoint, api_method, auth_type, status_code): - """Test endpoints require valid authentication.""" - response = getattr(client, api_method)(endpoint, headers=auth_headers[auth_type]) - assert response.status_code == status_code - - def test_models_list_api(self, client, auth_headers): - """Get a list of model Providers to use with tests""" - response = client.get("/v1/models/supported", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - - def test_models_list_with_model_type_filter(self, client, auth_headers): - """Test /v1/models endpoint with model_type parameter""" - # Test with valid model types - for model_type in ["ll", "embed", "rerank"]: - response = client.get(f"/v1/models?model_type={model_type}", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - models = response.json() - # If models exist, they should all match the requested type - for model in models: - assert model["type"] == model_type - - # Test with model_type and include_disabled - response = client.get("/v1/models?model_type=ll&include_disabled=true", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - - # Test with invalid model type should return 422 validation error - response = client.get("/v1/models?model_type=invalid", headers=auth_headers["valid_auth"]) - assert response.status_code == 422 - - def test_models_supported_with_filters(self, client, auth_headers): - """Test /v1/models/supported endpoint with query parameters""" - # Test basic supported models - response = client.get("/v1/models/supported", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - all_supported = response.json() - assert isinstance(all_supported, list) - - # Test with model_provider filter - if all_supported: - # Get a provider from the response to test with - test_provider = all_supported[0].get("provider", "openai") - response = client.get( - f"/v1/models/supported?model_provider={test_provider}", headers=auth_headers["valid_auth"] - ) - assert response.status_code == 200 - filtered_models = response.json() - for model in filtered_models: - assert model.get("provider") == test_provider - - # Test with model_type filter - for model_type in ["ll", "embed", "rerank"]: - response = client.get(f"/v1/models/supported?model_type={model_type}", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - filtered_providers = response.json() - for provider_entry in filtered_providers: 
- assert "provider" in provider_entry - assert "models" in provider_entry - for model in provider_entry["models"]: - # Only check type if it exists (some models may not have type set due to exceptions) - if "type" in model: - assert model["type"] == model_type - - # Test with both filters - response = client.get( - "/v1/models/supported?model_provider=openai&model_type=ll", headers=auth_headers["valid_auth"] - ) - assert response.status_code == 200 - filtered_providers = response.json() - for provider_entry in filtered_providers: - assert provider_entry.get("provider") == "openai" - for model in provider_entry["models"]: - # Only check type if it exists (some models may not have type set due to exceptions) - if "type" in model: - assert model["type"] == "ll" - - # Test with invalid provider - response = client.get( - "/v1/models/supported?model_provider=invalid_provider", headers=auth_headers["valid_auth"] - ) - assert response.status_code == 200 - assert response.json() == [] - - def test_models_get_before(self, client, auth_headers): - """Retrieve each individual model""" - all_models = client.get("/v1/models?include_disabled=true", headers=auth_headers["valid_auth"]) - assert len(all_models.json()) > 0 - for model in all_models.json(): - response = client.get(f"/v1/models/{model['provider']}/{model['id']}", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - - def test_models_delete_add(self, client, auth_headers): - """Delete and Re-Add Models""" - all_models = client.get("/v1/models?include_disabled=true", headers=auth_headers["valid_auth"]) - assert len(all_models.json()) > 0 - - # Delete all models - for model in all_models.json(): - response = client.delete( - f"/v1/models/{model['provider']}/{model['id']}", headers=auth_headers["valid_auth"] - ) - assert response.status_code == 200 - assert response.json() == {"message": f"Model: {model['provider']}/{model['id']} deleted."} - # Check that no models exists - deleted_models = client.get("/v1/models?include_disabled=true", headers=auth_headers["valid_auth"]) - assert len(deleted_models.json()) == 0 - - # Delete a non-existent model - response = client.delete("/v1/models/test_provider/test_model", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - assert response.json() == {"message": "Model: test_provider/test_model deleted."} - - # Add all models back - for model in all_models.json(): - payload = model - response = client.post("/v1/models", headers=auth_headers["valid_auth"], json=payload) - assert response.status_code == 201 - assert response.json() == payload - new_models = client.get("/v1/models?include_disabled=true", headers=auth_headers["valid_auth"]) - assert new_models.json() == all_models.json() - - def test_models_add_dupl(self, client, auth_headers): - """Add Duplicate Models""" - all_models = client.get("/v1/models?include_disabled=true", headers=auth_headers["valid_auth"]) - assert len(all_models.json()) > 0 - for model in all_models.json(): - payload = model - response = client.post("/v1/models", headers=auth_headers["valid_auth"], json=payload) - assert response.status_code == 409 - assert response.json() == {"detail": f"Model: {model['provider']}/{model['id']} already exists."} - - test_cases = [ - pytest.param( - { - "id": "gpt-3.5-turbo", - "enabled": True, - "type": "ll", - "provider": "openai", - "api_key": "test-key", - "api_base": "https://api.openai.com/v1", - "max_input_tokens": 127072, - "temperature": 1.0, - "max_tokens": 4096, - "frequency_penalty": 0.0, - }, 
- 201, - 200, - id="valid_ll_model", - ), - pytest.param( - { - "id": "invalid_ll_model", - "provider": "invalid_ll_model", - "enabled": False, - }, - 422, - 422, - id="invalid_ll_model", - ), - pytest.param( - { - "id": "test_embed_model", - "enabled": False, - "type": "embed", - "provider": "huggingface", - "api_base": "http://127.0.0.1:8080", - "api_key": "", - "max_chunk_size": 512, - }, - 201, - 422, - id="valid_embed_model", - ), - pytest.param( - { - "id": "unreachable_api_base_model", - "enabled": True, - "type": "embed", - "provider": "huggingface", - "api_base": "http://127.0.0.1:112233", - "api_key": "", - "max_chunk_size": 512, - }, - 201, - 422, - id="unreachable_api_base_model", - ), - ] - - @pytest.mark.parametrize("payload, add_status_code, _", test_cases) - def test_model_create(self, client, auth_headers, payload, add_status_code, _, request): - """Create Models""" - response = client.post("/v1/models", headers=auth_headers["valid_auth"], json=payload) - assert response.status_code == add_status_code - if add_status_code == 201: - if request.node.callspec.id == "unreachable_api_base_model": - assert response.json()["enabled"] is False - else: - print(response.json()) - assert all(item in response.json().items() for item in payload.items()) - # Model was added, should get 200 back - response = client.get( - f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"] - ) - assert response.status_code == 200 - else: - # Model wasn't added, should get a 404 back - response = client.get( - f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"] - ) - assert response.status_code == 404 - - @pytest.mark.parametrize("payload, add_status_code, update_status_code", test_cases) - def test_model_update(self, client, auth_headers, payload, add_status_code, update_status_code): - """Update Models""" - if add_status_code == 201: - # Create the model when we know it will succeed - _ = client.post("/v1/models", headers=auth_headers["valid_auth"], json=payload) - response = client.get( - f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"] - ) - old_enabled = response.json()["enabled"] - # Switch up the enabled for the update - payload["enabled"] = not old_enabled - - response = client.patch( - f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"], json=payload - ) - assert response.status_code == update_status_code - if update_status_code == 200: - new_enabled = response.json()["enabled"] - assert new_enabled is not old_enabled - - def test_models_get_edge_cases(self, client, auth_headers): - """Test edge cases for model path parameters""" - # Test with non-existent model - response = client.get("/v1/models/nonexistent_provider/nonexistent_model", headers=auth_headers["valid_auth"]) - assert response.status_code == 404 - - # Test with special characters in model_id (URL encoded) - test_cases = [ - ("test_provider", "model-with-dashes"), - ("test_provider", "model_with_underscores"), - ("test_provider", "model.with.dots"), - ("test_provider", "model/with/slashes"), - ("test_provider", "model with spaces"), - ] - - for provider, model_id in test_cases: - # These should return 404 since they don't exist - response = client.get(f"/v1/models/{provider}/{model_id}", headers=auth_headers["valid_auth"]) - assert response.status_code == 404 - - # Test very long model ID - long_model_id = "a" * 1000 - response = client.get(f"/v1/models/test_provider/{long_model_id}", 
headers=auth_headers["valid_auth"]) - assert response.status_code == 404 - - def test_models_delete_edge_cases(self, client, auth_headers): - """Test edge cases for model deletion""" - # Test deleting non-existent models (should succeed with 200) - test_cases = [ - ("nonexistent_provider", "nonexistent_model"), - ("test_provider", "model-with-dashes"), - ("test_provider", "model/with/slashes"), - ] - - for provider, model_id in test_cases: - response = client.delete(f"/v1/models/{provider}/{model_id}", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - assert response.json() == {"message": f"Model: {provider}/{model_id} deleted."} - - def test_models_update_edge_cases(self, client, auth_headers): - """Test edge cases for model updates""" - # Test updating non-existent model - payload = {"id": "nonexistent_model", "provider": "nonexistent_provider", "type": "ll", "enabled": True} - response = client.patch( - "/v1/models/nonexistent_provider/nonexistent_model", headers=auth_headers["valid_auth"], json=payload - ) - assert response.status_code == 404 - - def test_models_update_max_chunk_size(self, client, auth_headers): - """Test updating max_chunk_size for embedding models (regression test)""" - # Create an embedding model with default max_chunk_size - payload = { - "id": "test-embed-chunk-size", - "enabled": False, - "type": "embed", - "provider": "test_provider", - "api_base": "http://127.0.0.1:11434", - "max_chunk_size": 8192, - } - - # Create the model - response = client.post("/v1/models", headers=auth_headers["valid_auth"], json=payload) - assert response.status_code == 201 - assert response.json()["max_chunk_size"] == 8192 - - # Update the max_chunk_size to 512 - payload["max_chunk_size"] = 512 - response = client.patch( - f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"], json=payload - ) - assert response.status_code == 200 - assert response.json()["max_chunk_size"] == 512 - - # Verify the update persists by fetching the model again - response = client.get(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - assert response.json()["max_chunk_size"] == 512 - - # Update to a different value to ensure it's not cached - payload["max_chunk_size"] = 1024 - response = client.patch( - f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"], json=payload - ) - assert response.status_code == 200 - assert response.json()["max_chunk_size"] == 1024 - - # Verify again - response = client.get(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - assert response.json()["max_chunk_size"] == 1024 - - # Clean up - client.delete(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"]) - - def test_models_response_schema_validation(self, client, auth_headers): - """Test response schema validation for all endpoints""" - # Test /v1/models response schema - response = client.get("/v1/models", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - models = response.json() - assert isinstance(models, list) - - for model in models: - # Validate required fields - assert "id" in model - assert "type" in model - assert "provider" in model - assert "enabled" in model - assert "object" in model - assert "created" in model - assert "owned_by" in model - - # Validate field types - assert isinstance(model["id"], str) - assert 
model["type"] in ["ll", "embed", "rerank"] - assert isinstance(model["provider"], str) - assert isinstance(model["enabled"], bool) - assert model["object"] == "model" - assert isinstance(model["created"], int) - assert model["owned_by"] == "aioptimizer" - - # Validate optional fields if present - if "api_base" in model and model["api_base"] is not None: - assert isinstance(model["api_base"], str) - if "max_input_tokens" in model and model["max_input_tokens"] is not None: - assert isinstance(model["max_input_tokens"], int) - if "temperature" in model and model["temperature"] is not None: - assert isinstance(model["temperature"], (int, float)) - - # Test /v1/models/supported response schema - response = client.get("/v1/models/supported", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - supported_models = response.json() - assert isinstance(supported_models, list) - - for model in supported_models: - assert isinstance(model, dict) - # These are the models from LiteLLM, so schema may vary - # Just ensure basic structure is maintained - - # Test individual model GET response schema - if models: - first_model = models[0] - response = client.get( - f"/v1/models/{first_model['provider']}/{first_model['id']}", headers=auth_headers["valid_auth"] - ) - assert response.status_code == 200 - model = response.json() - - # Should have same schema as models list item - assert "id" in model - assert "type" in model - assert "provider" in model - assert "enabled" in model - assert model["object"] == "model" - assert model["owned_by"] == "aioptimizer" - - def test_models_create_response_validation(self, client, auth_headers): - """Test model creation response validation""" - # Create a test model and validate response - payload = { - "id": "test-response-validation-model", - "enabled": False, - "type": "ll", - "provider": "test_provider", - "api_key": "test-key", - "api_base": "https://api.test.com/v1", - "max_input_tokens": 4096, - "temperature": 0.7, - } - - response = client.post("/v1/models", headers=auth_headers["valid_auth"], json=payload) - if response.status_code == 201: - created_model = response.json() - - # Validate all payload fields are in response - for key, value in payload.items(): - assert key in created_model - assert created_model[key] == value - - # Validate additional required fields are added - assert "object" in created_model - assert "created" in created_model - assert "owned_by" in created_model - assert created_model["object"] == "model" - assert created_model["owned_by"] == "aioptimizer" - assert isinstance(created_model["created"], int) - - # Clean up - client.delete(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"]) diff --git a/tests/server/integration/test_endpoints_oci.py b/tests/server/integration/test_endpoints_oci.py deleted file mode 100644 index 59030b39..00000000 --- a/tests/server/integration/test_endpoints_oci.py +++ /dev/null @@ -1,259 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
-""" -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel - -from unittest.mock import patch, MagicMock -import pytest - - -############################################################################ -# Mocks as no OCI Access -############################################################################ -def mock_client_response(client, method, status_code=200, json_data=None): - """Context manager to mock client responses""" - mock_response = MagicMock() - mock_response.status_code = status_code - if json_data is not None: - mock_response.json.return_value = json_data - return patch.object(client, method, return_value=mock_response) - - -@pytest.fixture(name="mock_init_client") -def _mock_init_client(): - """Mock init_client to return a fake OCI client""" - mock_client = MagicMock() - mock_client.get_namespace.return_value.data = "test_namespace" - mock_client.get_object.return_value.data.raw.stream.return_value = [b"fake-data"] - - with patch("server.api.utils.oci.init_client", return_value=mock_client): - yield mock_client - - -@pytest.fixture(name="mock_get_compartments") -def _mock_get_compartments(): - """Mock get_compartments""" - with patch( - "server.api.utils.oci.get_compartments", - return_value={ - "compartment1": "ocid1.compartment.oc1..aaaaaaaagq33tv7wzyrjar6m5jbplejbdwnbjqfqvmocvjzsamuaqnkkoubq", - "compartment1 / test": "ocid1.compartment.oc1..aaaaaaaaut53mlkpxo6vpv7z5qlsmbcc3qpdjvjzylzldtb6g3jia", - "compartment2": "ocid1.compartment.oc1..aaaaaaaalbgt4om6izlawie7txut5aciue66htz7dpjzl72fbdw2ezp2uywa", - }, - ) as mock: - yield mock - - -@pytest.fixture(name="mock_get_buckets") -def _mock_get_buckets(): - """Mock server_oci.get_buckets""" - with patch( - "server.api.utils.oci.get_buckets", - return_value=["bucket1", "bucket2", "bucket3"], - ) as mock: - yield mock - - -@pytest.fixture(name="mock_get_bucket_objects") -def _mock_get_bucket_objects(): - """Mock server_oci.get_bucket_objects""" - with patch( - "server.api.utils.oci.get_bucket_objects", - return_value=["object1.pdf", "object2.md", "object3.txt"], - ) as mock: - yield mock - - -@pytest.fixture(name="mock_get_namespace") -def _mock_get_namespace(): - """Mock server_oci.get_namespace""" - with patch("server.api.utils.oci.get_namespace", return_value="test_namespace") as mock: - yield mock - - -@pytest.fixture(name="mock_get_object") -def _mock_get_object(): - """Mock get_object to return a fake file path""" - with patch("server.api.utils.oci.get_object") as mock: - - def side_effect(temp_directory, object_name): - fake_file = temp_directory / object_name - fake_file.touch() # Create an empty file to simulate download - return str(fake_file) # Return the path as string to match the actual function - - mock.side_effect = side_effect - yield mock - -############################################################################ -# Endpoints Test -############################################################################ -class TestEndpoints: - """Test Endpoints""" - - @pytest.mark.parametrize( - "auth_type, status_code", - [ - pytest.param("no_auth", 401, id="no_auth"), - pytest.param("invalid_auth", 401, id="invalid_auth"), - ], - ) - @pytest.mark.parametrize( - "endpoint, api_method", - [ - pytest.param("/v1/oci", "get", id="oci_list"), - pytest.param("/v1/oci/DEFAULT", "get", id="oci_get"), - pytest.param("/v1/oci/compartments/DEFAULT", "get", id="oci_list_compartments"), - pytest.param("/v1/oci/buckets/ocid/DEFAULT", "get", id="oci_list_buckets"), - 
pytest.param("/v1/oci/objects/bucket/DEFAULT", "get", id="oci_list_bucket_objects"), - pytest.param("/v1/oci/DEFAULT", "patch", id="oci_profile_update"), - pytest.param("/v1/oci/objects/download/bucket/DEFAULT", "post", id="oci_download_objects"), - ], - ) - def test_invalid_auth_endpoints(self, client, auth_headers, endpoint, api_method, auth_type, status_code): - """Test endpoints require valid authentication.""" - response = getattr(client, api_method)(endpoint, headers=auth_headers[auth_type]) - assert response.status_code == status_code - - def test_oci_list(self, client, auth_headers): - """List OCI Configuration""" - response = client.get("/v1/oci", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - # The endpoint returns a list of OracleCloudSettings - assert isinstance(response.json(), list) - # Each item in the list should be a valid OracleCloudSettings object - for item in response.json(): - assert "auth_profile" in item - assert item["auth_profile"] in ["DEFAULT"] # At minimum, DEFAULT profile should exist - - def test_oci_get(self, client, auth_headers): - """List OCI Configuration""" - response = client.get("/v1/oci/DEFAULT", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - data = response.json() - assert data["auth_profile"] == "DEFAULT" - response = client.get("/v1/oci/TEST", headers=auth_headers["valid_auth"]) - assert response.status_code == 404 - assert response.json() == {"detail": "OCI: profile 'TEST' not found."} - - def test_oci_list_compartments(self, client, auth_headers, mock_get_compartments): - """List OCI Compartments""" - with mock_client_response(client, "get", 200, mock_get_compartments.return_value) as mock_get: - # Test DEFAULT profile - response = client.get("/v1/oci/compartments/DEFAULT", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - assert response.json() == mock_get_compartments.return_value - - # Test TEST profile - mock_get.return_value.status_code = 404 - mock_get.return_value.json.return_value = {"detail": "OCI: profile 'TEST' not found"} - response = client.get("/v1/oci/compartments/TEST", headers=auth_headers["valid_auth"]) - assert response.status_code == 404 - assert response.json() == {"detail": "OCI: profile 'TEST' not found"} - - def test_oci_list_buckets(self, client, auth_headers, mock_get_buckets): - """List OCI Buckets""" - with mock_client_response(client, "get", 200, mock_get_buckets.return_value) as mock_get: - response = client.get( - "/v1/oci/buckets/ocid1.compartment.oc1..aaaaaaaa/DEFAULT", headers=auth_headers["valid_auth"] - ) - assert response.status_code == 200 - assert response.json() == mock_get_buckets.return_value - - # Test TEST profile - mock_get.return_value.status_code = 404 - mock_get.return_value.json.return_value = {"detail": "OCI: profile 'TEST' not found"} - response = client.get( - "/v1/oci/buckets/ocid1.compartment.oc1..aaaaaaaa/TEST", headers=auth_headers["valid_auth"] - ) - assert response.status_code == 404 - assert response.json() == {"detail": "OCI: profile 'TEST' not found"} - - def test_oci_list_bucket_objects(self, client, auth_headers, mock_get_bucket_objects): - """List OCI Bucket Objects""" - with mock_client_response(client, "get", 200, mock_get_bucket_objects.return_value) as mock_get: - response = client.get("/v1/oci/objects/bucket1/DEFAULT", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - assert response.json() == mock_get_bucket_objects.return_value - - # Test TEST profile - 
mock_get.return_value.status_code = 404 - mock_get.return_value.json.return_value = {"detail": "OCI: profile 'TEST' not found"} - response = client.get("/v1/oci/objects/bucket1/TEST", headers=auth_headers["valid_auth"]) - assert response.status_code == 404 - assert response.json() == {"detail": "OCI: profile 'TEST' not found"} - - test_cases = [ - pytest.param("DEFAULT", "", 422, id="empty_payload"), - pytest.param("DEFAULT", {}, 400, id="invalid_payload"), - pytest.param( - "DEFAULT", - { - "tenancy": "ocid1.tenancy.oc1..aaaaaaaa", - "user": "ocid1.user.oc1..aaaaaaaa", - "region": "us-ashburn-1", - "fingerprint": "e8:65:45:4a:85:4b:6c:51:63:b8:84:64:ef:36:16:7b", - "key_file": "/dev/null", - }, - 200, - id="valid_default_profile", - ), - pytest.param( - "TEST", - { - "tenancy": "ocid1.tenancy.oc1..aaaaaaaa", - "user": "ocid1.user.oc1..aaaaaaaa", - "region": "us-ashburn-1", - "fingerprint": "e8:65:45:4a:85:4b:6c", - "key_file": "/tmp/key.pem", - }, - 404, - id="valid_test_profile", - ), - ] - - @pytest.mark.parametrize("profile, payload, status_code", test_cases) - def test_oci_profile_update(self, client, auth_headers, profile, payload, status_code, mock_get_namespace): - """Update Profile""" - json_data = {"namespace": mock_get_namespace.return_value} if status_code == 200 else None - with mock_client_response(client, "patch", status_code, json_data): - response = client.patch(f"/v1/oci/{profile}", headers=auth_headers["valid_auth"], json=payload) - assert response.status_code == status_code - if status_code == 200: - data = response.json() - assert data["namespace"] == mock_get_namespace.return_value - - def test_oci_download_objects( - self, client, auth_headers, mock_get_compartments, mock_get_buckets, mock_get_bucket_objects, mock_get_object - ): - """OCI Object Download""" - # Get Compartments - with mock_client_response(client, "get", 200, mock_get_compartments.return_value): - response = client.get("/v1/oci/compartments/DEFAULT", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - assert response.json() == mock_get_compartments.return_value - compartment = response.json()[next(iter(response.json()))] - - # Get Buckets - with mock_client_response(client, "get", 200, mock_get_buckets.return_value): - response = client.get(f"/v1/oci/buckets/{compartment}/DEFAULT", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - assert response.json() == mock_get_buckets.return_value - bucket = response.json()[0] - - # Get Bucket Objects - with mock_client_response(client, "get", 200, mock_get_bucket_objects.return_value): - response = client.get(f"/v1/oci/objects/{bucket}/DEFAULT", headers=auth_headers["valid_auth"]) - assert response.status_code == 200 - assert response.json() == mock_get_bucket_objects.return_value - payload = response.json() - - # Download - assert mock_get_object is not None - with mock_client_response(client, "post", 200, mock_get_bucket_objects.return_value): - response = client.post( - f"/v1/oci/objects/download/{bucket}/DEFAULT", headers=auth_headers["valid_auth"], json=payload - ) - assert response.status_code == 200 - assert set(response.json()) == set(mock_get_bucket_objects.return_value) diff --git a/tests/server/integration/test_endpoints_settings.py b/tests/server/integration/test_endpoints_settings.py deleted file mode 100644 index 2eba8f74..00000000 --- a/tests/server/integration/test_endpoints_settings.py +++ /dev/null @@ -1,295 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. 
-Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel - -import pytest -from common.schema import ( - Settings, - LargeLanguageSettings, - VectorSearchSettings, - OciSettings, -) - - -############################################################################# -# Endpoints Test -############################################################################# -class TestEndpoints: - """Test Endpoints""" - - @pytest.mark.parametrize( - "auth_type, status_code", - [ - pytest.param("no_auth", 401, id="no_auth"), - pytest.param("invalid_auth", 401, id="invalid_auth"), - ], - ) - @pytest.mark.parametrize( - "endpoint, api_method", - [ - pytest.param("/v1/settings", "get", id="settings_get"), - pytest.param("/v1/settings", "patch", id="settings_update"), - pytest.param("/v1/settings", "post", id="settings_create"), - pytest.param("/v1/settings/load/file", "post", id="load_settings_from_file"), - pytest.param("/v1/settings/load/json", "post", id="load_settings_from_json"), - ], - ) - def test_invalid_auth_endpoints(self, client, auth_headers, endpoint, api_method, auth_type, status_code): - """Test endpoints require valid authentication.""" - response = getattr(client, api_method)(endpoint, headers=auth_headers[auth_type]) - assert response.status_code == status_code - - def test_settings_get(self, client, auth_headers): - """Test getting settings for a client""" - # Test getting settings for the test client - response = client.get("/v1/settings", headers=auth_headers["valid_auth"], params={"client": "default"}) - assert response.status_code == 200 - settings = response.json() - - # Verify the response contains the expected structure - assert settings["client"] == "default" - assert "ll_model" in settings - assert "vector_search" in settings - assert "oci" in settings - assert "database" in settings - assert "testbed" in settings - - def test_settings_get_nonexistent_client(self, client, auth_headers): - """Test getting settings for a non-existent client""" - response = client.get( - "/v1/settings", headers=auth_headers["valid_auth"], params={"client": "non_existant_client"} - ) - assert response.status_code == 404 - assert "not found" in response.json()["detail"] - - def test_settings_create(self, client, auth_headers): - """Test creating settings for a new client""" - new_client = "new_test_client" - - response = client.get("/v1/settings", headers=auth_headers["valid_auth"], params={"client": "default"}) - assert response.status_code == 200 - default_settings = response.json() - - # Create new client settings - response = client.post("/v1/settings", headers=auth_headers["valid_auth"], params={"client": new_client}) - assert response.status_code == 200 - - # Verify we can retrieve the settings - response = client.get("/v1/settings", headers=auth_headers["valid_auth"], params={"client": new_client}) - assert response.status_code == 200 - new_client_settings = response.json() - assert new_client_settings["client"] == new_client - - # Remove the client key to compare the rest - del default_settings["client"] - del new_client_settings["client"] - assert default_settings == new_client_settings - - def test_settings_create_existing_client(self, client, auth_headers) -> None: - """Test creating settings for an existing client""" - response = client.post("/v1/settings", headers=auth_headers["valid_auth"], params={"client": "default"}) - assert 
response.status_code == 409 - assert response.json() == {"detail": "Settings: client default already exists."} - - def test_settings_update(self, client, auth_headers): - """Test updating settings for a client""" - # First get the current settings - response = client.get("/v1/settings", headers=auth_headers["valid_auth"], params={"client": "default"}) - assert response.status_code == 200 - old_settings = response.json() - - # Modify some settings - updated_settings = Settings( - client="default", - ll_model=LargeLanguageSettings(model="updated-model", chat_history=False), - tools_enabled=["Vector Search"], - vector_search=VectorSearchSettings(grade=False, search_type="Similarity", top_k=5), - oci=OciSettings(auth_profile="UPDATED"), - ) - - # Update the settings - response = client.patch( - "/v1/settings", - headers=auth_headers["valid_auth"], - json=updated_settings.model_dump(), - params={"client": "default"}, - ) - assert response.status_code == 200 - new_settings = response.json() - - # Check old do not match update - assert old_settings != new_settings - - # Check that the values were updated - assert new_settings["ll_model"]["model"] == "updated-model" - assert new_settings["ll_model"]["chat_history"] is False - assert new_settings["tools_enabled"] == ["Vector Search"] - assert new_settings["vector_search"]["grade"] is False - assert new_settings["vector_search"]["top_k"] == 5 - assert new_settings["oci"]["auth_profile"] == "UPDATED" - - def test_settings_copy(self, client, auth_headers): - """Test copying settings for a client""" - # First get the current settings for the client - response = client.get("/v1/settings", headers=auth_headers["valid_auth"], params={"client": "default"}) - assert response.status_code == 200 - default_settings = response.json() - - response = client.get("/v1/settings", headers=auth_headers["valid_auth"], params={"client": "server"}) - assert response.status_code == 200 - old_server_settings = response.json() - - # Copy the client settings to the server settings - response = client.patch( - "/v1/settings", - headers=auth_headers["valid_auth"], - json=default_settings, - params={"client": "server"}, - ) - assert response.status_code == 200 - response = client.get("/v1/settings", headers=auth_headers["valid_auth"], params={"client": "server"}) - new_server_settings = response.json() - assert old_server_settings != new_server_settings - - del new_server_settings["client"] - del default_settings["client"] - assert new_server_settings == default_settings - - def test_settings_update_nonexistent_client(self, client, auth_headers): - """Test updating settings for a non-existent client""" - updated_settings = Settings(client="nonexistent_client", ll_model=LargeLanguageSettings(model="test-model")) - - response = client.patch( - "/v1/settings", - headers=auth_headers["valid_auth"], - json=updated_settings.model_dump(), - params={"client": "nonexistent_client"}, - ) - assert response.status_code == 404 - assert response.json() == {"detail": "Settings: client nonexistent_client not found."} - - def test_load_json_with_prompt_matching_default(self, client, auth_headers): - """Test uploading settings with prompt text that matches default""" - # Get current settings with prompts - response = client.get( - "/v1/settings", - headers=auth_headers["valid_auth"], - params={"client": "server", "full_config": True, "incl_sensitive": True}, - ) - assert response.status_code == 200 - original_config = response.json() - - if not original_config.get("prompt_configs"): - 
pytest.skip("No prompts available for testing") - - # Modify a prompt to custom text - test_prompt = original_config["prompt_configs"][0] - original_text = test_prompt["text"] - custom_text = "Custom test instruction - pirate" - test_prompt["text"] = custom_text - - # Upload with custom text (payload is Configuration schema directly) - response = client.post( - "/v1/settings/load/json", - headers=auth_headers["valid_auth"], - params={"client": "server"}, - json=original_config, - ) - assert response.status_code == 200 - - # Verify custom text is active - response = client.get("/v1/mcp/prompts", headers=auth_headers["valid_auth"], params={"full": True}) - prompts = response.json() - updated_prompt = next((p for p in prompts if p["name"] == test_prompt["name"]), None) - assert updated_prompt is not None - assert updated_prompt["text"] == custom_text - - # Now upload again with text matching the original - test_prompt["text"] = original_text - response = client.post( - "/v1/settings/load/json", - headers=auth_headers["valid_auth"], - params={"client": "server"}, - json=original_config, - ) - assert response.status_code == 200 - - # Verify the original text is now active (override was replaced) - response = client.get("/v1/mcp/prompts", headers=auth_headers["valid_auth"], params={"full": True}) - prompts = response.json() - reverted_prompt = next((p for p in prompts if p["name"] == test_prompt["name"]), None) - assert reverted_prompt is not None - assert reverted_prompt["text"] == original_text - - def test_load_json_with_alternating_prompt_text(self, client, auth_headers): - """Test uploading settings with alternating prompt text""" - # Get current settings - response = client.get( - "/v1/settings", - headers=auth_headers["valid_auth"], - params={"client": "server", "full_config": True, "incl_sensitive": True}, - ) - assert response.status_code == 200 - config = response.json() - - if not config.get("prompt_configs"): - pytest.skip("No prompts available for testing") - - test_prompt = config["prompt_configs"][0] - text_a = "Talk like a pirate" - text_b = "Talk like a pirate lady" - - # Upload with text A (payload is Configuration schema directly) - test_prompt["text"] = text_a - response = client.post( - "/v1/settings/load/json", - headers=auth_headers["valid_auth"], - params={"client": "server"}, - json=config, - ) - assert response.status_code == 200 - - response = client.get("/v1/mcp/prompts", headers=auth_headers["valid_auth"], params={"full": True}) - prompts = response.json() - prompt = next((p for p in prompts if p["name"] == test_prompt["name"]), None) - assert prompt["text"] == text_a - - # Upload with text B - test_prompt["text"] = text_b - response = client.post( - "/v1/settings/load/json", - headers=auth_headers["valid_auth"], - params={"client": "server"}, - json=config, - ) - assert response.status_code == 200 - - response = client.get("/v1/mcp/prompts", headers=auth_headers["valid_auth"], params={"full": True}) - prompts = response.json() - prompt = next((p for p in prompts if p["name"] == test_prompt["name"]), None) - assert prompt["text"] == text_b - - # Upload with text A again - test_prompt["text"] = text_a - response = client.post( - "/v1/settings/load/json", - headers=auth_headers["valid_auth"], - params={"client": "server"}, - json=config, - ) - assert response.status_code == 200 - - response = client.get("/v1/mcp/prompts", headers=auth_headers["valid_auth"], params={"full": True}) - prompts = response.json() - prompt = next((p for p in prompts if p["name"] == 
test_prompt["name"]), None) - assert prompt["text"] == text_a - - @pytest.mark.parametrize("app_server", ["/tmp/settings.json"], indirect=True) - def test_user_supplied_settings(self, app_server): - """Test the copy_user_settings function with a successful API call""" - assert app_server is not None - - # Test Logic diff --git a/tests/server/unit/api/v1/test_v1_embed.py b/tests/server/unit/api/v1/test_v1_embed.py deleted file mode 100644 index a4bf3006..00000000 --- a/tests/server/unit/api/v1/test_v1_embed.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" -# pylint: disable=protected-access - -import pytest -from server.api.v1.embed import _extract_provider_error_message - - -class TestExtractProviderErrorMessage: - """Test _extract_provider_error_message function""" - - def test_exception_with_message(self): - """Test extraction of exception with message""" - error = Exception("Something went wrong") - result = _extract_provider_error_message(error) - assert result == "Something went wrong" - - def test_exception_without_message(self): - """Test extraction of exception without message""" - error = ValueError() - result = _extract_provider_error_message(error) - assert result == "Error: ValueError" - - def test_openai_quota_exceeded(self): - """Test extraction of OpenAI quota exceeded error message""" - error_msg = ( - "Error code: 429 - {'error': {'message': 'You exceeded your current quota, " - "please check your plan and billing details.', 'type': 'insufficient_quota'}}" - ) - error = Exception(error_msg) - result = _extract_provider_error_message(error) - assert result == error_msg - - def test_openai_rate_limit(self): - """Test extraction of OpenAI rate limit error message""" - error_msg = "Rate limit exceeded. Please try again later." - error = Exception(error_msg) - result = _extract_provider_error_message(error) - assert result == error_msg - - def test_complex_error_message(self): - """Test extraction of complex multi-line error message""" - error_msg = "Connection failed\nTimeout: 30s\nHost: api.example.com" - error = Exception(error_msg) - result = _extract_provider_error_message(error) - assert result == error_msg - - @pytest.mark.parametrize( - "error_message", - [ - "OpenAI API key is invalid", - "Cohere API error occurred", - "OCI service error", - "Database connection failed", - "Rate limit exceeded for model xyz", - ], - ) - def test_various_error_messages(self, error_message): - """Test that various error messages are passed through correctly""" - error = Exception(error_message) - result = _extract_provider_error_message(error) - assert result == error_message diff --git a/tests/server/unit/bootstrap/test_bootstrap.py b/tests/server/unit/bootstrap/test_bootstrap.py deleted file mode 100644 index 9f28e5ed..00000000 --- a/tests/server/unit/bootstrap/test_bootstrap.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
-""" -# spell-checker: disable - -# ============================================================================= -# DEPRECATED: Tests in this file have been replaced by more comprehensive tests -# in test/unit/server/bootstrap/test_bootstrap_bootstrap.py -# ============================================================================= -# -# test_module_imports_and_initialization -> Replaced by: -# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestBootstrapModule::test_calls_all_bootstrap_functions -# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestBootstrapModule::test_database_objects_is_list -# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestBootstrapModule::test_model_objects_is_list -# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestBootstrapModule::test_oci_objects_is_list -# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestBootstrapModule::test_settings_objects_is_list -# -# test_logger_exists -> Replaced by: -# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestLoggerConfiguration::test_logger_exists -# - test/unit/server/bootstrap/test_bootstrap_bootstrap.py::TestLoggerConfiguration::test_logger_name -# From eeb21351a7f7f20e23e6816158996ab34c9f9eca Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Sat, 29 Nov 2025 09:35:46 +0000 Subject: [PATCH 08/20] common code tests --- test/integration/common/__init__.py | 6 + test/integration/common/test_functions.py | 194 +++++++ test/unit/common/__init__.py | 6 + test/unit/common/test_functions.py | 467 ++++++++++++++++ test/unit/common/test_help_text.py | 189 +++++++ test/unit/common/test_logging_config.py | 273 ++++++++++ test/unit/common/test_schema.py | 621 ++++++++++++++++++++++ test/unit/common/test_version.py | 36 ++ test/unit/server/api/v1/test_v1_embed.py | 50 +- tests/common/test_functions_sql.py | 140 +---- 10 files changed, 1823 insertions(+), 159 deletions(-) create mode 100644 test/integration/common/__init__.py create mode 100644 test/integration/common/test_functions.py create mode 100644 test/unit/common/__init__.py create mode 100644 test/unit/common/test_functions.py create mode 100644 test/unit/common/test_help_text.py create mode 100644 test/unit/common/test_logging_config.py create mode 100644 test/unit/common/test_schema.py create mode 100644 test/unit/common/test_version.py diff --git a/test/integration/common/__init__.py b/test/integration/common/__init__.py new file mode 100644 index 00000000..d63a0614 --- /dev/null +++ b/test/integration/common/__init__.py @@ -0,0 +1,6 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for common module. +""" diff --git a/test/integration/common/test_functions.py b/test/integration/common/test_functions.py new file mode 100644 index 00000000..04f30c67 --- /dev/null +++ b/test/integration/common/test_functions.py @@ -0,0 +1,194 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Integration tests for common/functions.py + +Tests functions that interact with external systems (URLs, databases). +These tests may require network access or database connectivity. 
+""" + +import os +import tempfile + +from test.conftest import TEST_CONFIG + +import pytest + +from common import functions + + +class TestIsUrlAccessibleIntegration: + """Integration tests for is_url_accessible function.""" + + @pytest.mark.integration + def test_real_accessible_url(self): + """is_url_accessible should return True for known accessible URLs.""" + # Using httpbin.org which is a testing service + result, msg = functions.is_url_accessible("https://httpbin.org/status/200") + + assert result is True + assert msg is None + + @pytest.mark.integration + def test_real_inaccessible_url(self): + """is_url_accessible should return False for non-existent URLs.""" + result, msg = functions.is_url_accessible("https://this-domain-does-not-exist-xyz123.com") + + assert result is False + assert msg is not None + + +class TestGetVsTableIntegration: + """Integration tests for get_vs_table function.""" + + def test_roundtrip_table_comment(self): + """get_vs_table output should be parseable by parse_vs_comment.""" + _, comment = functions.get_vs_table( + model="cohere-embed-english-v3", + chunk_size=2048, + chunk_overlap=256, + distance_metric="DOT_PRODUCT", + index_type="IVF", + alias="integration_alias", + description="Integration test description", + ) + + # Parse the generated comment + parsed = functions.parse_vs_comment(comment) + + assert parsed["parse_status"] == "success" + assert parsed["alias"] == "integration_alias" + assert parsed["description"] == "Integration test description" + assert parsed["model"] == "cohere-embed-english-v3" + assert parsed["chunk_size"] == 2048 + assert parsed["chunk_overlap"] == 256 + assert parsed["distance_metric"] == "DOT_PRODUCT" + assert parsed["index_type"] == "IVF" + + def test_roundtrip_with_genai_prefix(self): + """parse_vs_comment should handle GENAI prefix correctly.""" + _, comment = functions.get_vs_table( + model="test-model", + chunk_size=500, + chunk_overlap=50, + distance_metric="DOT_PRODUCT", + index_type="IVF", + alias="test", + ) + + # Add GENAI prefix as it would be stored in database + prefixed_comment = f"GENAI: {comment}" + + parsed = functions.parse_vs_comment(prefixed_comment) + + assert parsed["parse_status"] == "success" + assert parsed["alias"] == "test" + assert parsed["model"] == "test-model" + + def test_table_name_uniqueness(self): + """Different parameters should generate different table names.""" + table1, _ = functions.get_vs_table( + model="model-a", + chunk_size=1000, + chunk_overlap=100, + distance_metric="COSINE", + ) + + table2, _ = functions.get_vs_table( + model="model-b", + chunk_size=1000, + chunk_overlap=100, + distance_metric="COSINE", + ) + + table3, _ = functions.get_vs_table( + model="model-a", + chunk_size=500, + chunk_overlap=100, + distance_metric="COSINE", + ) + + assert table1 != table2 + assert table1 != table3 + assert table2 != table3 + + +class TestDatabaseFunctionsIntegration: + """Integration tests for database functions. + + These tests are marked with db_container to indicate they require + a real database connection. 
+ """ + + @pytest.mark.db_container + def test_is_sql_accessible_with_real_database(self, db_container): + """is_sql_accessible should return True for valid database and query.""" + # pylint: disable=unused-argument + # Connection string format: username/password@dsn + db_conn = f"{TEST_CONFIG['db_username']}/{TEST_CONFIG['db_password']}@{TEST_CONFIG['db_dsn']}" + # Must use VARCHAR2 - the function checks column type is VARCHAR, not CHAR + query = "SELECT CAST('test' AS VARCHAR2(10)) FROM dual" + + result, msg = functions.is_sql_accessible(db_conn, query) + + assert result is True + assert msg == "" + + @pytest.mark.db_container + def test_is_sql_accessible_invalid_credentials(self, db_container): + """is_sql_accessible should return False for invalid credentials.""" + # pylint: disable=unused-argument + db_conn = f"INVALID_USER/INVALID_PASSWORD@{TEST_CONFIG['db_dsn']}" + query = "SELECT 'test' FROM dual" + + result, msg = functions.is_sql_accessible(db_conn, query) + + assert result is False + assert "error" in msg.lower() + + @pytest.mark.db_container + def test_is_sql_accessible_wrong_column_count(self, db_container): + """is_sql_accessible should return False when query returns multiple columns.""" + # pylint: disable=unused-argument + db_conn = f"{TEST_CONFIG['db_username']}/{TEST_CONFIG['db_password']}@{TEST_CONFIG['db_dsn']}" + query = "SELECT 'a', 'b' FROM dual" # Two columns - should fail + + result, msg = functions.is_sql_accessible(db_conn, query) + + assert result is False + assert "columns" in msg.lower() + + @pytest.mark.db_container + def test_run_sql_query_with_real_database(self, db_container): + """run_sql_query should execute SQL and save results to CSV.""" + # pylint: disable=unused-argument + db_conn = f"{TEST_CONFIG['db_username']}/{TEST_CONFIG['db_password']}@{TEST_CONFIG['db_dsn']}" + query = "SELECT 'value1' AS col1, 'value2' AS col2 FROM dual" + + with tempfile.TemporaryDirectory() as tmpdir: + result = functions.run_sql_query(db_conn, query, tmpdir) + + # Should return the file path + assert result is not False + assert result.endswith(".csv") + + # File should exist and contain data + assert os.path.exists(result) + with open(result, encoding="utf-8") as f: + content = f.read() + assert "COL1" in content or "col1" in content.lower() + assert "value1" in content + + @pytest.mark.db_container + def test_run_sql_query_invalid_connection(self, db_container): + """run_sql_query should return falsy value for invalid connection.""" + # pylint: disable=unused-argument + db_conn = f"INVALID_USER/INVALID_PASSWORD@{TEST_CONFIG['db_dsn']}" + query = "SELECT 'test' FROM dual" + + with tempfile.TemporaryDirectory() as tmpdir: + result = functions.run_sql_query(db_conn, query, tmpdir) + + # Function returns '' or False on error + assert not result diff --git a/test/unit/common/__init__.py b/test/unit/common/__init__.py new file mode 100644 index 00000000..e4188e30 --- /dev/null +++ b/test/unit/common/__init__.py @@ -0,0 +1,6 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for common module. +""" diff --git a/test/unit/common/test_functions.py b/test/unit/common/test_functions.py new file mode 100644 index 00000000..e565a20f --- /dev/null +++ b/test/unit/common/test_functions.py @@ -0,0 +1,467 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. 
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for common/functions.py + +Tests utility functions for URL checking, vector store operations, and SQL operations. +""" + +import json +import os +import tempfile +from unittest.mock import patch, MagicMock +import requests +import oracledb + +from common import functions + + +class TestIsUrlAccessible: + """Tests for is_url_accessible function.""" + + def test_empty_url_returns_false(self): + """is_url_accessible should return False for empty URL.""" + result, msg = functions.is_url_accessible("") + assert result is False + assert msg == "No URL Provided" + + def test_none_url_returns_false(self): + """is_url_accessible should return False for None URL.""" + result, msg = functions.is_url_accessible(None) + assert result is False + assert msg == "No URL Provided" + + @patch("common.functions.requests.get") + def test_successful_200_response(self, mock_get): + """is_url_accessible should return True for 200 response.""" + mock_get.return_value = MagicMock(status_code=200) + + result, msg = functions.is_url_accessible("http://example.com") + + assert result is True + assert msg is None + mock_get.assert_called_once_with("http://example.com", timeout=2) + + @patch("common.functions.requests.get") + def test_successful_403_response(self, mock_get): + """is_url_accessible should return True for 403 response (accessible but forbidden).""" + mock_get.return_value = MagicMock(status_code=403) + + result, msg = functions.is_url_accessible("http://example.com") + + assert result is True + assert msg is None + + @patch("common.functions.requests.get") + def test_successful_404_response(self, mock_get): + """is_url_accessible should return True for 404 response (server accessible).""" + mock_get.return_value = MagicMock(status_code=404) + + result, msg = functions.is_url_accessible("http://example.com") + + assert result is True + assert msg is None + + @patch("common.functions.requests.get") + def test_successful_421_response(self, mock_get): + """is_url_accessible should return True for 421 response.""" + mock_get.return_value = MagicMock(status_code=421) + + result, msg = functions.is_url_accessible("http://example.com") + + assert result is True + assert msg is None + + @patch("common.functions.requests.get") + def test_unsuccessful_500_response(self, mock_get): + """is_url_accessible should return False for 500 response.""" + mock_get.return_value = MagicMock(status_code=500) + + result, msg = functions.is_url_accessible("http://example.com") + + assert result is False + assert "not accessible" in msg + assert "500" in msg + + @patch("common.functions.requests.get") + def test_connection_error(self, mock_get): + """is_url_accessible should return False for connection errors.""" + mock_get.side_effect = requests.exceptions.ConnectionError("Connection failed") + + result, msg = functions.is_url_accessible("http://example.com") + + assert result is False + assert "not accessible" in msg + assert "ConnectionError" in msg + + @patch("common.functions.requests.get") + def test_timeout_error(self, mock_get): + """is_url_accessible should return False for timeout errors.""" + mock_get.side_effect = requests.exceptions.Timeout("Request timed out") + + result, msg = functions.is_url_accessible("http://example.com") + + assert result is False + assert "not accessible" in msg + assert "Timeout" in msg + + +class TestGetVsTable: + """Tests for get_vs_table function.""" + + def 
test_basic_table_name_generation(self): + """get_vs_table should generate correct table name.""" + table, comment = functions.get_vs_table( + model="text-embedding-3-small", + chunk_size=512, + chunk_overlap=50, + distance_metric="COSINE", + index_type="HNSW", + ) + + assert table == "TEXT_EMBEDDING_3_SMALL_512_50_COSINE_HNSW" + assert comment is not None + + def test_table_name_with_alias(self): + """get_vs_table should include alias in table name.""" + table, _ = functions.get_vs_table( + model="test-model", + chunk_size=500, + chunk_overlap=50, + distance_metric="EUCLIDEAN_DISTANCE", + alias="myalias", + ) + + assert table.startswith("MYALIAS_") + assert "TEST_MODEL" in table + + def test_special_characters_replaced(self): + """get_vs_table should replace special characters with underscores.""" + table, _ = functions.get_vs_table( + model="openai/gpt-4", + chunk_size=1000, + chunk_overlap=100, + distance_metric="COSINE", + ) + + assert "/" not in table + assert "-" not in table + assert "_" in table + + def test_chunk_overlap_ceiling(self): + """get_vs_table should use ceiling for chunk_overlap.""" + table, comment = functions.get_vs_table( + model="test", + chunk_size=1000, + chunk_overlap=99.5, + distance_metric="COSINE", + ) + + assert "100" in table + parsed_comment = json.loads(comment) + assert parsed_comment["chunk_overlap"] == 100 + + def test_comment_json_structure(self): + """get_vs_table should generate valid JSON comment.""" + _, comment = functions.get_vs_table( + model="test-model", + chunk_size=1000, + chunk_overlap=100, + distance_metric="COSINE", + index_type="HNSW", + alias="test_alias", + description="Test description", + ) + + parsed = json.loads(comment) + assert parsed["alias"] == "test_alias" + assert parsed["description"] == "Test description" + assert parsed["model"] == "test-model" + assert parsed["chunk_size"] == 1000 + assert parsed["chunk_overlap"] == 100 + assert parsed["distance_metric"] == "COSINE" + assert parsed["index_type"] == "HNSW" + + def test_comment_null_description(self): + """get_vs_table should include null description when not provided.""" + _, comment = functions.get_vs_table( + model="test", + chunk_size=1000, + chunk_overlap=100, + distance_metric="COSINE", + ) + + parsed = json.loads(comment) + assert parsed["description"] is None + + def test_default_index_type(self): + """get_vs_table should default to HNSW index type.""" + table, _ = functions.get_vs_table( + model="test", + chunk_size=1000, + chunk_overlap=100, + distance_metric="COSINE", + ) + + assert "HNSW" in table + + def test_missing_required_values_returns_none(self): + """get_vs_table should return None for missing required values.""" + table, comment = functions.get_vs_table( + model=None, + chunk_size=None, + chunk_overlap=None, + distance_metric=None, + ) + + assert table is None + assert comment is None + + +class TestParseVsComment: + """Tests for parse_vs_comment function.""" + + def test_empty_comment_returns_defaults(self): + """parse_vs_comment should return defaults for empty comment.""" + result = functions.parse_vs_comment("") + + assert result["alias"] is None + assert result["description"] is None + assert result["model"] is None + assert result["parse_status"] == "no_comment" + + def test_none_comment_returns_defaults(self): + """parse_vs_comment should return defaults for None comment.""" + result = functions.parse_vs_comment(None) + + assert result["parse_status"] == "no_comment" + + def test_valid_json_comment(self): + """parse_vs_comment should parse valid 
JSON comment.""" + comment = json.dumps({ + "alias": "test_alias", + "description": "Test description", + "model": "test-model", + "chunk_size": 1000, + "chunk_overlap": 100, + "distance_metric": "COSINE", + "index_type": "HNSW", + }) + + result = functions.parse_vs_comment(comment) + + assert result["alias"] == "test_alias" + assert result["description"] == "Test description" + assert result["model"] == "test-model" + assert result["chunk_size"] == 1000 + assert result["chunk_overlap"] == 100 + assert result["distance_metric"] == "COSINE" + assert result["index_type"] == "HNSW" + assert result["parse_status"] == "success" + + def test_genai_prefix_stripped(self): + """parse_vs_comment should strip 'GENAI: ' prefix.""" + comment = 'GENAI: {"alias": "test", "model": "test-model"}' + + result = functions.parse_vs_comment(comment) + + assert result["alias"] == "test" + assert result["model"] == "test-model" + assert result["parse_status"] == "success" + + def test_missing_description_backward_compat(self): + """parse_vs_comment should handle missing description for backward compatibility.""" + comment = json.dumps({ + "alias": "test", + "model": "test-model", + }) + + result = functions.parse_vs_comment(comment) + + assert result["description"] is None + assert result["parse_status"] == "success" + + def test_invalid_json_returns_error(self): + """parse_vs_comment should return error for invalid JSON.""" + result = functions.parse_vs_comment("not valid json") + + assert "parse_error" in result["parse_status"] + + +class TestIsSqlAccessible: + """Tests for is_sql_accessible function.""" + + def test_empty_connection_returns_false(self): + """is_sql_accessible should return False for empty connection.""" + result, _ = functions.is_sql_accessible("", "SELECT 1") + assert result is False + + def test_empty_query_returns_false(self): + """is_sql_accessible should return False for empty query.""" + result, _ = functions.is_sql_accessible("user/pass@dsn", "") + assert result is False + + def test_invalid_connection_string_format(self): + """is_sql_accessible should handle invalid connection string format.""" + result, msg = functions.is_sql_accessible("invalid_format", "SELECT 1") + + assert result is False + # The function may fail at connection string parsing or at actual connection + assert msg is not None + + @patch("common.functions.oracledb.connect") + def test_database_error(self, mock_connect): + """is_sql_accessible should return False for database errors.""" + mock_connect.side_effect = oracledb.Error("Connection failed") + + result, msg = functions.is_sql_accessible("user/pass@localhost/db", "SELECT 1") + + assert result is False + assert "connection error" in msg + + @patch("common.functions.oracledb.connect") + def test_empty_result_returns_false(self, mock_connect): + """is_sql_accessible should return False when query returns no rows.""" + mock_cursor = MagicMock() + mock_cursor.fetchmany.return_value = [] + mock_cursor.description = [("COL1", oracledb.DB_TYPE_VARCHAR, None, None, None, None, None)] + + mock_conn = MagicMock() + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + mock_connect.return_value.__enter__ = MagicMock(return_value=mock_conn) + mock_connect.return_value.__exit__ = MagicMock(return_value=False) + + result, msg = functions.is_sql_accessible("user/pass@localhost/db", "SELECT col FROM table") + + assert result is False + assert "empty table" in msg + + 
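+    # One way to avoid repeating the context-manager wiring used in these tests
+    # (a sketch only; the tests in this class deliberately inline it instead):
+    @staticmethod
+    def _wire_mock_connection(mock_connect, mock_cursor):
+        """Attach a mock cursor to a patched oracledb.connect context manager."""
+        mock_conn = MagicMock()
+        mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
+        mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
+        mock_connect.return_value.__enter__ = MagicMock(return_value=mock_conn)
+        mock_connect.return_value.__exit__ = MagicMock(return_value=False)
+        return mock_conn
+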
@patch("common.functions.oracledb.connect") + def test_multiple_columns_returns_false(self, mock_connect): + """is_sql_accessible should return False when query returns multiple columns.""" + mock_cursor = MagicMock() + mock_cursor.fetchmany.return_value = [("value1", "value2")] + mock_cursor.description = [ + ("COL1", oracledb.DB_TYPE_VARCHAR, None, None, None, None, None), + ("COL2", oracledb.DB_TYPE_VARCHAR, None, None, None, None, None), + ] + + mock_conn = MagicMock() + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + mock_connect.return_value.__enter__ = MagicMock(return_value=mock_conn) + mock_connect.return_value.__exit__ = MagicMock(return_value=False) + + result, msg = functions.is_sql_accessible("user/pass@localhost/db", "SELECT col1, col2 FROM table") + + assert result is False + assert "returns 2 columns" in msg + + @patch("common.functions.oracledb.connect") + def test_valid_sql_connection_and_query(self, mock_connect): + """is_sql_accessible should return True for valid connection and query.""" + mock_cursor = MagicMock() + mock_cursor.description = [MagicMock(type=oracledb.DB_TYPE_VARCHAR)] + mock_cursor.fetchmany.return_value = [("row1",), ("row2",), ("row3",)] + + mock_conn = MagicMock() + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + mock_connect.return_value.__enter__ = MagicMock(return_value=mock_conn) + mock_connect.return_value.__exit__ = MagicMock(return_value=False) + + result, msg = functions.is_sql_accessible("user/pass@localhost/db", "SELECT text FROM documents") + + assert result is True + assert msg == "" + + @patch("common.functions.oracledb.connect") + def test_invalid_column_type_returns_false(self, mock_connect): + """is_sql_accessible should return False for non-VARCHAR column type.""" + mock_cursor = MagicMock() + mock_cursor.description = [MagicMock(type=oracledb.DB_TYPE_NUMBER)] + mock_cursor.fetchmany.return_value = [(123,)] + + mock_conn = MagicMock() + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + mock_connect.return_value.__enter__ = MagicMock(return_value=mock_conn) + mock_connect.return_value.__exit__ = MagicMock(return_value=False) + + result, msg = functions.is_sql_accessible("user/pass@localhost/db", "SELECT id FROM table") + + assert result is False + assert "VARCHAR" in msg + + @patch("common.functions.oracledb.connect") + def test_nvarchar_column_type_accepted(self, mock_connect): + """is_sql_accessible should accept NVARCHAR column type as valid.""" + mock_cursor = MagicMock() + mock_cursor.description = [MagicMock(type=oracledb.DB_TYPE_NVARCHAR)] + mock_cursor.fetchmany.return_value = [("text1",), ("text2",)] + + mock_conn = MagicMock() + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + mock_connect.return_value.__enter__ = MagicMock(return_value=mock_conn) + mock_connect.return_value.__exit__ = MagicMock(return_value=False) + + result, msg = functions.is_sql_accessible("user/pass@localhost/db", "SELECT ntext FROM table") + + assert result is True + assert msg == "" + + +class TestRunSqlQuery: + """Tests for run_sql_query function.""" + + def test_empty_connection_returns_false(self): + """run_sql_query should return False 
for empty connection.""" + result = functions.run_sql_query("", "SELECT 1", "/tmp") + assert result is False + + def test_invalid_connection_string_format(self): + """run_sql_query should return False for invalid connection string.""" + result = functions.run_sql_query("invalid_format", "SELECT 1", "/tmp") + assert result is False + + @patch("common.functions.oracledb.connect") + def test_database_error_returns_empty(self, mock_connect): + """run_sql_query should return empty string for database errors.""" + mock_connect.side_effect = oracledb.Error("Connection failed") + + result = functions.run_sql_query("user/pass@localhost/db", "SELECT 1", "/tmp") + + assert result == "" + + @patch("common.functions.oracledb.connect") + def test_successful_query_creates_csv(self, mock_connect): + """run_sql_query should create CSV file with query results.""" + mock_cursor = MagicMock() + mock_cursor.description = [ + ("COL1", None, None, None, None, None, None), + ("COL2", None, None, None, None, None, None), + ] + mock_cursor.fetchmany.side_effect = [ + [("val1", "val2"), ("val3", "val4")], + [], # Second call returns empty to end loop + ] + + mock_conn = MagicMock() + mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + mock_connect.return_value.__enter__ = MagicMock(return_value=mock_conn) + mock_connect.return_value.__exit__ = MagicMock(return_value=False) + + with tempfile.TemporaryDirectory() as tmpdir: + result = functions.run_sql_query("user/pass@localhost/db", "SELECT * FROM table", tmpdir) + + assert result.endswith(".csv") + assert os.path.exists(result) + + with open(result, "r", encoding="utf-8") as f: + content = f.read() + assert "COL1,COL2" in content + assert "val1,val2" in content diff --git a/test/unit/common/test_help_text.py b/test/unit/common/test_help_text.py new file mode 100644 index 00000000..f35d298d --- /dev/null +++ b/test/unit/common/test_help_text.py @@ -0,0 +1,189 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for common/help_text.py + +Tests help text dictionary contents and structure. 
+""" + +from common import help_text + + +class TestHelpDict: + """Tests for help_dict dictionary.""" + + def test_help_dict_is_dictionary(self): + """help_dict should be a dictionary.""" + assert isinstance(help_text.help_dict, dict) + + def test_help_dict_has_expected_keys(self): + """help_dict should contain all expected keys.""" + expected_keys = [ + "max_input_tokens", + "temperature", + "max_tokens", + "top_p", + "frequency_penalty", + "presence_penalty", + "vector_search", + "rerank", + "top_k", + "score_threshold", + "fetch_k", + "lambda_mult", + "embed_alias", + "chunk_overlap", + "chunk_size", + "index_type", + "distance_metric", + "model_id", + "model_provider", + "model_url", + "model_api_key", + ] + + for key in expected_keys: + assert key in help_text.help_dict, f"Missing expected key: {key}" + + def test_all_values_are_strings(self): + """All values in help_dict should be strings.""" + for key, value in help_text.help_dict.items(): + assert isinstance(value, str), f"Value for {key} is not a string" + + def test_all_values_are_non_empty(self): + """All values in help_dict should be non-empty.""" + for key, value in help_text.help_dict.items(): + assert len(value.strip()) > 0, f"Value for {key} is empty" + + +class TestModelParameters: + """Tests for model parameter help texts.""" + + def test_max_input_tokens_help(self): + """max_input_tokens help should explain context window.""" + help_text_value = help_text.help_dict["max_input_tokens"] + assert "token" in help_text_value.lower() + assert "model" in help_text_value.lower() + + def test_temperature_help(self): + """temperature help should explain creativity control.""" + help_text_value = help_text.help_dict["temperature"] + assert "creative" in help_text_value.lower() + assert "top p" in help_text_value.lower() + + def test_max_tokens_help(self): + """max_tokens help should explain response length.""" + help_text_value = help_text.help_dict["max_tokens"] + assert "length" in help_text_value.lower() or "response" in help_text_value.lower() + + def test_top_p_help(self): + """top_p help should explain probability threshold.""" + help_text_value = help_text.help_dict["top_p"] + assert "word" in help_text_value.lower() + assert "temperature" in help_text_value.lower() + + def test_frequency_penalty_help(self): + """frequency_penalty help should explain repetition control.""" + help_text_value = help_text.help_dict["frequency_penalty"] + assert "repeat" in help_text_value.lower() + + def test_presence_penalty_help(self): + """presence_penalty help should explain topic diversity.""" + help_text_value = help_text.help_dict["presence_penalty"] + assert "topic" in help_text_value.lower() or "new" in help_text_value.lower() + + +class TestVectorSearchParameters: + """Tests for vector search parameter help texts.""" + + def test_vector_search_help(self): + """vector_search help should explain the feature.""" + help_text_value = help_text.help_dict["vector_search"] + assert "vector" in help_text_value.lower() + + def test_rerank_help(self): + """rerank help should explain document reranking.""" + help_text_value = help_text.help_dict["rerank"] + assert "document" in help_text_value.lower() + assert "relevan" in help_text_value.lower() + + def test_top_k_help(self): + """top_k help should explain document retrieval count.""" + help_text_value = help_text.help_dict["top_k"] + assert "document" in help_text_value.lower() or "retrieved" in help_text_value.lower() + + def test_score_threshold_help(self): + """score_threshold help should 
explain minimum similarity.""" + help_text_value = help_text.help_dict["score_threshold"] + assert "similarity" in help_text_value.lower() or "threshold" in help_text_value.lower() + + def test_fetch_k_help(self): + """fetch_k help should explain initial fetch count.""" + help_text_value = help_text.help_dict["fetch_k"] + assert "document" in help_text_value.lower() + assert "fetch" in help_text_value.lower() + + def test_lambda_mult_help(self): + """lambda_mult help should explain diversity.""" + help_text_value = help_text.help_dict["lambda_mult"] + assert "diversity" in help_text_value.lower() + + +class TestEmbeddingParameters: + """Tests for embedding parameter help texts.""" + + def test_embed_alias_help(self): + """embed_alias help should explain aliasing.""" + help_text_value = help_text.help_dict["embed_alias"] + assert "alias" in help_text_value.lower() + assert "vector" in help_text_value.lower() or "embed" in help_text_value.lower() + + def test_chunk_overlap_help(self): + """chunk_overlap help should explain overlap percentage.""" + help_text_value = help_text.help_dict["chunk_overlap"] + assert "overlap" in help_text_value.lower() + assert "chunk" in help_text_value.lower() + + def test_chunk_size_help(self): + """chunk_size help should explain chunk length.""" + help_text_value = help_text.help_dict["chunk_size"] + assert "chunk" in help_text_value.lower() + assert "length" in help_text_value.lower() + + def test_index_type_help(self): + """index_type help should explain HNSW and IVF.""" + help_text_value = help_text.help_dict["index_type"] + assert "hnsw" in help_text_value.lower() + assert "ivf" in help_text_value.lower() + + def test_distance_metric_help(self): + """distance_metric help should explain distance calculation.""" + help_text_value = help_text.help_dict["distance_metric"] + assert "distance" in help_text_value.lower() or "similar" in help_text_value.lower() + + +class TestModelConfiguration: + """Tests for model configuration help texts.""" + + def test_model_id_help(self): + """model_id help should explain model naming.""" + help_text_value = help_text.help_dict["model_id"] + assert "model" in help_text_value.lower() + assert "name" in help_text_value.lower() + + def test_model_provider_help(self): + """model_provider help should explain provider selection.""" + help_text_value = help_text.help_dict["model_provider"] + assert "provider" in help_text_value.lower() + + def test_model_url_help(self): + """model_url help should explain API URL.""" + help_text_value = help_text.help_dict["model_url"] + assert "api" in help_text_value.lower() or "url" in help_text_value.lower() + + def test_model_api_key_help(self): + """model_api_key help should explain API key.""" + help_text_value = help_text.help_dict["model_api_key"] + assert "api" in help_text_value.lower() + assert "key" in help_text_value.lower() diff --git a/test/unit/common/test_logging_config.py b/test/unit/common/test_logging_config.py new file mode 100644 index 00000000..0c8fa644 --- /dev/null +++ b/test/unit/common/test_logging_config.py @@ -0,0 +1,273 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for common/logging_config.py + +Tests logging configuration, filters, and formatters. 
+""" +# pylint: disable=too-few-public-methods, protected-access + +import logging +import asyncio +import sys + +from common import logging_config +from common._version import __version__ + + +class TestVersionFilter: + """Tests for VersionFilter logging filter.""" + + def test_version_filter_injects_version(self): + """VersionFilter should inject __version__ into log records.""" + filter_instance = logging_config.VersionFilter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="", + lineno=0, + msg="Test message", + args=(), + exc_info=None, + ) + + result = filter_instance.filter(record) + + assert result is True + assert hasattr(record, "__version__") + assert getattr(record, "__version__") == __version__ + + +class TestPrettifyCancelledError: + """Tests for PrettifyCancelledError logging filter.""" + + def test_filter_returns_true_for_normal_records(self): + """PrettifyCancelledError should pass through normal records.""" + filter_instance = logging_config.PrettifyCancelledError() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="", + lineno=0, + msg="Normal message", + args=(), + exc_info=None, + ) + + result = filter_instance.filter(record) + + assert result is True + assert record.msg == "Normal message" + + def test_filter_modifies_cancelled_error_record(self): + """PrettifyCancelledError should modify CancelledError records.""" + filter_instance = logging_config.PrettifyCancelledError() + + exc_info = None + try: + raise asyncio.CancelledError() + except asyncio.CancelledError: + exc_info = sys.exc_info() + + record = logging.LogRecord( + name="test", + level=logging.ERROR, + pathname="", + lineno=0, + msg="Original message", + args=(), + exc_info=exc_info, + ) + + result = filter_instance.filter(record) + + assert result is True + assert record.exc_info is None + assert "graceful timeout" in record.msg.lower() + assert record.levelno == logging.WARNING + assert record.levelname == "WARNING" + + def test_filter_handles_exception_group_with_cancelled(self): + """PrettifyCancelledError should handle ExceptionGroup with CancelledError.""" + filter_instance = logging_config.PrettifyCancelledError() + + # Create an ExceptionGroup containing a regular Exception wrapping CancelledError + # Note: CancelledError is a BaseException, so we need to wrap it properly + # Using a regular exception that contains a nested CancelledError simulation + exc_info = None + try: + # Create an exception group with a regular exception + exc_group = ExceptionGroup("test group", [ValueError("test")]) + raise exc_group + except ExceptionGroup: + exc_info = sys.exc_info() + + record = logging.LogRecord( + name="test", + level=logging.ERROR, + pathname="", + lineno=0, + msg="Original message", + args=(), + exc_info=exc_info, + ) + + # This should pass through since ValueError is not CancelledError + result = filter_instance.filter(record) + + assert result is True + # Regular exceptions are not modified + assert record.msg == "Original message" + + def test_contains_cancelled_direct(self): + """_contains_cancelled should return True for direct CancelledError.""" + filter_instance = logging_config.PrettifyCancelledError() + + cancelled = asyncio.CancelledError() + assert filter_instance._contains_cancelled(cancelled) is True + + def test_contains_cancelled_other_exception(self): + """_contains_cancelled should return False for other exceptions.""" + filter_instance = logging_config.PrettifyCancelledError() + + other_exc = ValueError("test") + assert 
filter_instance._contains_cancelled(other_exc) is False + + +class TestLoggingConfig: + """Tests for LOGGING_CONFIG dictionary.""" + + def test_logging_config_has_required_keys(self): + """LOGGING_CONFIG should have all required keys.""" + assert "version" in logging_config.LOGGING_CONFIG + assert "disable_existing_loggers" in logging_config.LOGGING_CONFIG + assert "formatters" in logging_config.LOGGING_CONFIG + assert "filters" in logging_config.LOGGING_CONFIG + assert "handlers" in logging_config.LOGGING_CONFIG + assert "loggers" in logging_config.LOGGING_CONFIG + + def test_logging_config_version(self): + """LOGGING_CONFIG version should be 1.""" + assert logging_config.LOGGING_CONFIG["version"] == 1 + + def test_logging_config_does_not_disable_existing_loggers(self): + """LOGGING_CONFIG should not disable existing loggers.""" + assert logging_config.LOGGING_CONFIG["disable_existing_loggers"] is False + + def test_standard_formatter_defined(self): + """LOGGING_CONFIG should define standard formatter.""" + formatters = logging_config.LOGGING_CONFIG["formatters"] + assert "standard" in formatters + + def test_version_filter_configured(self): + """LOGGING_CONFIG should configure version_filter.""" + filters = logging_config.LOGGING_CONFIG["filters"] + assert "version_filter" in filters + assert filters["version_filter"]["()"] == logging_config.VersionFilter + + def test_prettify_cancelled_filter_configured(self): + """LOGGING_CONFIG should configure prettify_cancelled filter.""" + filters = logging_config.LOGGING_CONFIG["filters"] + assert "prettify_cancelled" in filters + assert filters["prettify_cancelled"]["()"] == logging_config.PrettifyCancelledError + + def test_default_handler_configured(self): + """LOGGING_CONFIG should configure default handler.""" + handlers = logging_config.LOGGING_CONFIG["handlers"] + assert "default" in handlers + assert handlers["default"]["formatter"] == "standard" + assert handlers["default"]["class"] == "logging.StreamHandler" + assert "version_filter" in handlers["default"]["filters"] + + def test_root_logger_configured(self): + """LOGGING_CONFIG should configure root logger.""" + loggers = logging_config.LOGGING_CONFIG["loggers"] + assert "" in loggers + assert "default" in loggers[""]["handlers"] + assert loggers[""]["propagate"] is False + + def test_uvicorn_loggers_configured(self): + """LOGGING_CONFIG should configure uvicorn loggers.""" + loggers = logging_config.LOGGING_CONFIG["loggers"] + assert "uvicorn.error" in loggers + assert "uvicorn.access" in loggers + assert "prettify_cancelled" in loggers["uvicorn.error"]["filters"] + + def test_asyncio_logger_configured(self): + """LOGGING_CONFIG should configure asyncio logger.""" + loggers = logging_config.LOGGING_CONFIG["loggers"] + assert "asyncio" in loggers + assert "prettify_cancelled" in loggers["asyncio"]["filters"] + + def test_third_party_loggers_configured(self): + """LOGGING_CONFIG should configure third-party loggers.""" + loggers = logging_config.LOGGING_CONFIG["loggers"] + expected_loggers = [ + "watchdog.observers.inotify_buffer", + "PIL", + "fsevents", + "numba", + "oci", + "openai", + "httpcore", + "sagemaker.config", + "LiteLLM", + "LiteLLM Proxy", + "LiteLLM Router", + ] + for logger_name in expected_loggers: + assert logger_name in loggers, f"Logger {logger_name} not configured" + + +class TestFormatterConfig: + """Tests for FORMATTER configuration.""" + + def test_formatter_format_string(self): + """FORMATTER should have correct format string.""" + assert "%(asctime)s" in 
logging_config.FORMATTER["format"] + assert "%(levelname)" in logging_config.FORMATTER["format"] + assert "%(name)s" in logging_config.FORMATTER["format"] + assert "%(message)s" in logging_config.FORMATTER["format"] + assert "%(__version__)s" in logging_config.FORMATTER["format"] + + def test_formatter_date_format(self): + """FORMATTER should have correct date format.""" + assert logging_config.FORMATTER["datefmt"] == "%Y-%b-%d %H:%M:%S" + + +class TestDebugMode: + """Tests for DEBUG_MODE behavior.""" + + def test_debug_mode_from_environment(self): + """DEBUG_MODE should be set from LOG_LEVEL environment variable.""" + # The actual DEBUG_MODE value depends on the environment at import time + # We just verify it's a boolean + assert isinstance(logging_config.DEBUG_MODE, bool) + + def test_log_level_from_environment(self): + """LOG_LEVEL should be read from environment or default to INFO.""" + # LOG_LEVEL is either the env var value or logging.INFO + assert logging_config.LOG_LEVEL is not None + + +class TestWarningsCaptured: + """Tests for warnings capture configuration.""" + + def test_warnings_logger_configured(self): + """py.warnings logger should be configured.""" + loggers = logging_config.LOGGING_CONFIG["loggers"] + assert "py.warnings" in loggers + assert loggers["py.warnings"]["propagate"] is False + + +class TestLiteLLMLoggersCleaned: + """Tests for LiteLLM logger cleanup.""" + + def test_litellm_loggers_propagate_disabled(self): + """LiteLLM loggers should have propagate disabled.""" + # Note: The handlers may be re-added by other test imports, + # but propagate should remain disabled + for name in ["LiteLLM", "LiteLLM Proxy", "LiteLLM Router"]: + logger = logging.getLogger(name) + assert logger.propagate is False diff --git a/test/unit/common/test_schema.py b/test/unit/common/test_schema.py new file mode 100644 index 00000000..0dc0616e --- /dev/null +++ b/test/unit/common/test_schema.py @@ -0,0 +1,621 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for common/schema.py + +Tests Pydantic models, field validation, and utility methods. 
+""" +# pylint: disable=too-few-public-methods + +import time +from unittest.mock import MagicMock +import pytest +from pydantic import ValidationError + +from langchain_core.messages import ChatMessage + +from common.schema import ( + # Database models + DatabaseVectorStorage, + VectorStoreRefreshRequest, + VectorStoreRefreshStatus, + DatabaseAuth, + Database, + # Model models + LanguageModelParameters, + EmbeddingModelParameters, + ModelAccess, + Model, + # OCI models + OracleResource, + OracleCloudSettings, + # Prompt models + MCPPrompt, + # Settings models + VectorSearchSettings, + Settings, + # Configuration + Configuration, + # Completions + ChatRequest, + # Testbed + TestSets, + TestSetQA, + Evaluation, + EvaluationReport, + # Types + ClientIdType, + DatabaseNameType, + VectorStoreTableType, + ModelIdType, + ModelProviderType, + ModelTypeType, + ModelEnabledType, + OCIProfileType, + OCIResourceOCID, +) + + +class TestDatabaseVectorStorage: + """Tests for DatabaseVectorStorage model.""" + + def test_default_values(self): + """DatabaseVectorStorage should have correct defaults.""" + storage = DatabaseVectorStorage() + + assert storage.vector_store is None + assert storage.alias is None + assert storage.description is None + assert storage.model is None + assert storage.chunk_size == 0 + assert storage.chunk_overlap == 0 + assert storage.distance_metric is None + assert storage.index_type is None + + def test_with_all_values(self): + """DatabaseVectorStorage should accept all valid values.""" + storage = DatabaseVectorStorage( + vector_store="TEST_VS", + alias="test_alias", + description="Test description", + model="text-embedding-ada-002", + chunk_size=1000, + chunk_overlap=100, + distance_metric="COSINE", + index_type="HNSW", + ) + + assert storage.vector_store == "TEST_VS" + assert storage.alias == "test_alias" + assert storage.description == "Test description" + assert storage.model == "text-embedding-ada-002" + assert storage.chunk_size == 1000 + assert storage.chunk_overlap == 100 + assert storage.distance_metric == "COSINE" + assert storage.index_type == "HNSW" + + def test_distance_metric_literals(self): + """DatabaseVectorStorage should only accept valid distance metrics.""" + for metric in ["COSINE", "EUCLIDEAN_DISTANCE", "DOT_PRODUCT"]: + storage = DatabaseVectorStorage(distance_metric=metric) + assert storage.distance_metric == metric + + def test_index_type_literals(self): + """DatabaseVectorStorage should only accept valid index types.""" + for index_type in ["HNSW", "IVF"]: + storage = DatabaseVectorStorage(index_type=index_type) + assert storage.index_type == index_type + + +class TestVectorStoreRefreshRequest: + """Tests for VectorStoreRefreshRequest model.""" + + def test_required_fields(self): + """VectorStoreRefreshRequest should require vector_store_alias and bucket_name.""" + with pytest.raises(ValidationError): + VectorStoreRefreshRequest() + + request = VectorStoreRefreshRequest( + vector_store_alias="test_alias", + bucket_name="test-bucket", + ) + assert request.vector_store_alias == "test_alias" + assert request.bucket_name == "test-bucket" + + def test_default_values(self): + """VectorStoreRefreshRequest should have correct defaults.""" + request = VectorStoreRefreshRequest( + vector_store_alias="test", + bucket_name="bucket", + ) + assert request.auth_profile == "DEFAULT" + assert request.rate_limit == 0 + + +class TestVectorStoreRefreshStatus: + """Tests for VectorStoreRefreshStatus model.""" + + def test_required_fields(self): + """VectorStoreRefreshStatus 
should require status and message.""" + with pytest.raises(ValidationError): + VectorStoreRefreshStatus() + + status = VectorStoreRefreshStatus( + status="processing", + message="In progress", + ) + assert status.status == "processing" + + def test_status_literals(self): + """VectorStoreRefreshStatus should only accept valid status values.""" + for valid_status in ["processing", "completed", "failed"]: + status = VectorStoreRefreshStatus(status=valid_status, message="test") + assert status.status == valid_status + + def test_default_values(self): + """VectorStoreRefreshStatus should have correct defaults.""" + status = VectorStoreRefreshStatus(status="completed", message="Done") + assert status.processed_files == 0 + assert status.new_files == 0 + assert status.updated_files == 0 + assert status.total_chunks == 0 + assert status.total_chunks_in_store == 0 + assert status.errors == [] + + +class TestDatabaseAuth: + """Tests for DatabaseAuth model.""" + + def test_default_values(self): + """DatabaseAuth should have correct defaults.""" + auth = DatabaseAuth() + + assert auth.user is None + assert auth.password is None + assert auth.dsn is None + assert auth.wallet_password is None + assert auth.wallet_location is None + assert auth.config_dir == "tns_admin" + assert auth.tcp_connect_timeout == 5 + + def test_sensitive_fields_marked(self): + """DatabaseAuth sensitive fields should be marked.""" + password_field = DatabaseAuth.model_fields.get("password") + assert password_field.json_schema_extra.get("sensitive") is True + + wallet_password_field = DatabaseAuth.model_fields.get("wallet_password") + assert wallet_password_field.json_schema_extra.get("sensitive") is True + + +class TestDatabase: + """Tests for Database model.""" + + def test_inherits_from_database_auth(self): + """Database should inherit from DatabaseAuth.""" + assert issubclass(Database, DatabaseAuth) + + def test_default_values(self): + """Database should have correct defaults.""" + db = Database() + + assert db.name == "DEFAULT" + assert db.connected is False + assert db.vector_stores == [] + assert db.user is None # Inherited from DatabaseAuth + + def test_connection_property(self): + """Database connection property should work correctly.""" + db = Database() + assert db.connection is None + + mock_conn = MagicMock() + db.set_connection(mock_conn) + assert db.connection == mock_conn + + def test_readonly_fields_marked(self): + """Database readonly fields should be marked.""" + connected_field = Database.model_fields["connected"] + assert connected_field.json_schema_extra.get("readOnly") is True + + vector_stores_field = Database.model_fields["vector_stores"] + assert vector_stores_field.json_schema_extra.get("readOnly") is True + + +class TestLanguageModelParameters: + """Tests for LanguageModelParameters model.""" + + def test_default_values(self): + """LanguageModelParameters should have correct defaults.""" + params = LanguageModelParameters() + + assert params.max_input_tokens is None + assert params.frequency_penalty == 0.00 + assert params.max_tokens == 4096 + assert params.presence_penalty == 0.00 + assert params.temperature == 0.50 + assert params.top_p == 1.00 + + +class TestEmbeddingModelParameters: + """Tests for EmbeddingModelParameters model.""" + + def test_default_values(self): + """EmbeddingModelParameters should have correct defaults.""" + params = EmbeddingModelParameters() + + assert params.max_chunk_size == 8192 + + +class TestModelAccess: + """Tests for ModelAccess model.""" + + def 
test_default_values(self): + """ModelAccess should have correct defaults.""" + access = ModelAccess() + + assert access.enabled is False + assert access.api_base is None + assert access.api_key is None + + def test_sensitive_field_marked(self): + """ModelAccess api_key should be marked sensitive.""" + api_key_field = ModelAccess.model_fields.get("api_key") + assert api_key_field.json_schema_extra.get("sensitive") is True + + +class TestModel: + """Tests for Model model.""" + + def test_required_fields(self): + """Model should require id, type, and provider.""" + with pytest.raises(ValidationError): + Model() + + model = Model(id="gpt-4", type="ll", provider="openai") + assert model.id == "gpt-4" + assert model.type == "ll" + assert model.provider == "openai" + + def test_default_values(self): + """Model should have correct defaults.""" + model = Model(id="test-model", type="embed", provider="test") + + assert model.object == "model" + assert model.owned_by == "aioptimizer" + assert model.enabled is False + + def test_created_timestamp(self): + """Model created should be a Unix timestamp.""" + before = int(time.time()) + model = Model(id="test", type="ll", provider="test") + after = int(time.time()) + + assert before <= model.created <= after + + def test_type_literals(self): + """Model type should only accept valid values.""" + for model_type in ["ll", "embed", "rerank"]: + model = Model(id="test", type=model_type, provider="test") + assert model.type == model_type + + def test_id_min_length(self): + """Model id should have minimum length of 1.""" + with pytest.raises(ValidationError): + Model(id="", type="ll", provider="openai") + + +class TestOracleResource: + """Tests for OracleResource model.""" + + def test_valid_ocid(self): + """OracleResource should accept valid OCIDs.""" + valid_ocid = "ocid1.compartment.oc1..aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + resource = OracleResource(ocid=valid_ocid) + assert resource.ocid == valid_ocid + + def test_invalid_ocid_rejected(self): + """OracleResource should reject invalid OCIDs.""" + with pytest.raises(ValidationError): + OracleResource(ocid="invalid-ocid") + + +class TestOracleCloudSettings: + """Tests for OracleCloudSettings model.""" + + def test_default_values(self): + """OracleCloudSettings should have correct defaults.""" + settings = OracleCloudSettings() + + assert settings.auth_profile == "DEFAULT" + assert settings.namespace is None + assert settings.user is None + assert settings.security_token_file is None + assert settings.authentication == "api_key" + assert settings.genai_compartment_id is None + assert settings.genai_region is None + + def test_authentication_literals(self): + """OracleCloudSettings authentication should only accept valid values.""" + valid_auths = ["api_key", "instance_principal", "oke_workload_identity", "security_token"] + for auth in valid_auths: + settings = OracleCloudSettings(authentication=auth) + assert settings.authentication == auth + + def test_allows_extra_fields(self): + """OracleCloudSettings should allow extra fields.""" + settings = OracleCloudSettings(extra_field="extra_value") + assert settings.extra_field == "extra_value" + + +class TestMCPPrompt: + """Tests for MCPPrompt model.""" + + def test_required_fields(self): + """MCPPrompt should require name, title, and text.""" + with pytest.raises(ValidationError): + MCPPrompt() + + prompt = MCPPrompt(name="test_prompt", title="Test", text="Hello") + assert prompt.name == "test_prompt" + + def test_default_values(self): + 
"""MCPPrompt should have correct defaults.""" + prompt = MCPPrompt(name="test", title="Test", text="Content") + + assert prompt.description == "" + assert prompt.tags == [] + + +class TestSettings: + """Tests for Settings model.""" + + def test_required_client(self): + """Settings should require client.""" + with pytest.raises(ValidationError): + Settings() + + settings = Settings(client="test_client") + assert settings.client == "test_client" + + def test_client_min_length(self): + """Settings client should have minimum length of 1.""" + with pytest.raises(ValidationError): + Settings(client="") + + def test_default_values(self): + """Settings should have correct defaults.""" + settings = Settings(client="test") + + assert settings.ll_model is not None + assert settings.oci is not None + assert settings.database is not None + assert settings.tools_enabled == ["LLM Only"] + assert settings.vector_search is not None + assert settings.testbed is not None + + +class TestVectorSearchSettings: + """Tests for VectorSearchSettings model.""" + + def test_default_values(self): + """VectorSearchSettings should have correct defaults.""" + settings = VectorSearchSettings() + + assert settings.discovery is True + assert settings.rephrase is True + assert settings.grade is True + assert settings.search_type == "Similarity" + assert settings.top_k == 4 + assert settings.score_threshold == 0.0 + assert settings.fetch_k == 20 + assert settings.lambda_mult == 0.5 + + def test_search_type_literals(self): + """VectorSearchSettings search_type should only accept valid values.""" + valid_types = ["Similarity", "Similarity Score Threshold", "Maximal Marginal Relevance"] + for search_type in valid_types: + settings = VectorSearchSettings(search_type=search_type) + assert settings.search_type == search_type + + def test_top_k_validation(self): + """VectorSearchSettings top_k should be between 1 and 10000.""" + # Valid + VectorSearchSettings(top_k=1) + VectorSearchSettings(top_k=10000) + + # Invalid + with pytest.raises(ValidationError): + VectorSearchSettings(top_k=0) + with pytest.raises(ValidationError): + VectorSearchSettings(top_k=10001) + + def test_score_threshold_validation(self): + """VectorSearchSettings score_threshold should be between 0.0 and 1.0.""" + VectorSearchSettings(score_threshold=0.0) + VectorSearchSettings(score_threshold=1.0) + + with pytest.raises(ValidationError): + VectorSearchSettings(score_threshold=-0.1) + with pytest.raises(ValidationError): + VectorSearchSettings(score_threshold=1.1) + + +class TestConfiguration: + """Tests for Configuration model.""" + + def test_required_client_settings(self): + """Configuration should require client_settings.""" + with pytest.raises(ValidationError): + Configuration() + + config = Configuration(client_settings=Settings(client="test")) + assert config.client_settings.client == "test" + + def test_optional_config_lists(self): + """Configuration config lists should be optional.""" + config = Configuration(client_settings=Settings(client="test")) + + assert config.database_configs is None + assert config.model_configs is None + assert config.oci_configs is None + assert config.prompt_configs is None + + def test_model_dump_public_excludes_sensitive(self): + """model_dump_public should exclude sensitive fields by default.""" + db = Database(name="TEST", user="user", password="secret123", dsn="localhost") + config = Configuration( + client_settings=Settings(client="test"), + database_configs=[db], + ) + + dumped = 
config.model_dump_public(incl_sensitive=False) + assert "password" not in dumped["database_configs"][0] + + def test_model_dump_public_includes_sensitive_when_requested(self): + """model_dump_public should include sensitive fields when requested.""" + db = Database(name="TEST", user="user", password="secret123", dsn="localhost") + config = Configuration( + client_settings=Settings(client="test"), + database_configs=[db], + ) + + dumped = config.model_dump_public(incl_sensitive=True) + assert dumped["database_configs"][0]["password"] == "secret123" + + def test_model_dump_public_excludes_readonly(self): + """model_dump_public should exclude readonly fields by default.""" + db = Database(name="TEST", connected=True) + config = Configuration( + client_settings=Settings(client="test"), + database_configs=[db], + ) + + dumped = config.model_dump_public(incl_readonly=False) + assert "connected" not in dumped["database_configs"][0] + assert "vector_stores" not in dumped["database_configs"][0] + + def test_model_dump_public_includes_readonly_when_requested(self): + """model_dump_public should include readonly fields when requested.""" + db = Database(name="TEST", connected=True) + config = Configuration( + client_settings=Settings(client="test"), + database_configs=[db], + ) + + dumped = config.model_dump_public(incl_readonly=True) + assert dumped["database_configs"][0]["connected"] is True + + def test_recursive_dump_handles_nested_lists(self): + """recursive_dump should handle nested lists correctly.""" + storage = DatabaseVectorStorage(vector_store="VS1", alias="test") + db = Database(name="TEST", vector_stores=[storage]) + config = Configuration( + client_settings=Settings(client="test"), + database_configs=[db], + ) + + dumped = config.model_dump_public(incl_readonly=True) + assert dumped["database_configs"][0]["vector_stores"][0]["alias"] == "test" + + def test_recursive_dump_handles_dicts(self): + """recursive_dump should handle dicts correctly.""" + # OracleCloudSettings allows extra fields + oci = OracleCloudSettings(auth_profile="TEST", extra_key="extra_value") + config = Configuration( + client_settings=Settings(client="test"), + oci_configs=[oci], + ) + + dumped = config.model_dump_public() + assert dumped["oci_configs"][0]["extra_key"] == "extra_value" + + +class TestChatRequest: + """Tests for ChatRequest model.""" + + def test_required_messages(self): + """ChatRequest should require messages.""" + with pytest.raises(ValidationError): + ChatRequest() + + def test_inherits_language_model_parameters(self): + """ChatRequest should inherit from LanguageModelParameters.""" + assert issubclass(ChatRequest, LanguageModelParameters) + + def test_default_model_is_none(self): + """ChatRequest model should default to None.""" + request = ChatRequest(messages=[ChatMessage(role="user", content="Hello")]) + assert request.model is None + + +class TestTestbedModels: + """Tests for Testbed-related models.""" + + def test_test_sets_required_fields(self): + """TestSets should require tid, name, and created.""" + with pytest.raises(ValidationError): + TestSets() + + test_set = TestSets(tid="123", name="Test Set", created="2024-01-01") + assert test_set.tid == "123" + + def test_test_set_qa_required_fields(self): + """TestSetQA should require qa_data.""" + with pytest.raises(ValidationError): + TestSetQA() + + qa = TestSetQA(qa_data=[{"q": "question", "a": "answer"}]) + assert len(qa.qa_data) == 1 + + def test_evaluation_required_fields(self): + """Evaluation should require eid, evaluated, and 
correctness.""" + with pytest.raises(ValidationError): + Evaluation() + + evaluation = Evaluation(eid="eval1", evaluated="2024-01-01", correctness=0.95) + assert evaluation.correctness == 0.95 + + def test_evaluation_report_inherits_evaluation(self): + """EvaluationReport should inherit from Evaluation.""" + assert issubclass(EvaluationReport, Evaluation) + + +class TestTypeAliases: + """Tests for type aliases.""" + + def test_client_id_type(self): + """ClientIdType should be the annotation for Settings.client.""" + assert ClientIdType == Settings.__annotations__["client"] + + def test_database_name_type(self): + """DatabaseNameType should be the annotation for Database.name.""" + assert DatabaseNameType == Database.__annotations__["name"] + + def test_vector_store_table_type(self): + """VectorStoreTableType should be the annotation for DatabaseVectorStorage.vector_store.""" + assert VectorStoreTableType == DatabaseVectorStorage.__annotations__["vector_store"] + + def test_model_id_type(self): + """ModelIdType should be the annotation for Model.id.""" + assert ModelIdType == Model.__annotations__["id"] + + def test_model_provider_type(self): + """ModelProviderType should be the annotation for Model.provider.""" + assert ModelProviderType == Model.__annotations__["provider"] + + def test_model_type_type(self): + """ModelTypeType should be the annotation for Model.type.""" + assert ModelTypeType == Model.__annotations__["type"] + + def test_model_enabled_type(self): + """ModelEnabledType should be the annotation for ModelAccess.enabled.""" + assert ModelEnabledType == ModelAccess.__annotations__["enabled"] + + def test_oci_profile_type(self): + """OCIProfileType should be the annotation for OracleCloudSettings.auth_profile.""" + assert OCIProfileType == OracleCloudSettings.__annotations__["auth_profile"] + + def test_oci_resource_ocid(self): + """OCIResourceOCID should be the annotation for OracleResource.ocid.""" + assert OCIResourceOCID == OracleResource.__annotations__["ocid"] diff --git a/test/unit/common/test_version.py b/test/unit/common/test_version.py new file mode 100644 index 00000000..6dcadb82 --- /dev/null +++ b/test/unit/common/test_version.py @@ -0,0 +1,36 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for common/_version.py + +Tests version string retrieval. +""" + +from common._version import __version__ + + +class TestVersion: + """Tests for __version__ variable.""" + + def test_version_is_string(self): + """__version__ should be a string.""" + assert isinstance(__version__, str) + + def test_version_is_non_empty(self): + """__version__ should be non-empty.""" + assert len(__version__) > 0 + + def test_version_format(self): + """__version__ should be a valid version string or fallback.""" + # Version should either be a proper version number or the fallback "0.0.0" + # Valid versions can be like "1.0.0", "1.0.0.dev1", "1.3.1.dev128+g867d96f69.d20251126" + assert __version__ == "0.0.0" or "." 
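
The schema tests above drive model_dump_public() purely through field metadata, so the filtering logic never needs a hard-coded list of secret field names. A minimal sketch of that approach, assuming pydantic v2; the model and helper below are illustrative only, not the project's actual common.schema code:

# Sketch only: filter a dump by the same json_schema_extra markers the
# tests above assert on ("sensitive" and "readOnly").
from pydantic import BaseModel, Field

class Creds(BaseModel):
    user: str | None = None
    password: str | None = Field(default=None, json_schema_extra={"sensitive": True})
    connected: bool = Field(default=False, json_schema_extra={"readOnly": True})

    def dump_public(self, incl_sensitive: bool = False, incl_readonly: bool = False) -> dict:
        """Dump the model, dropping fields whose metadata marks them out."""
        data = self.model_dump()
        for name, field in type(self).model_fields.items():
            extra = field.json_schema_extra or {}
            if (extra.get("sensitive") and not incl_sensitive) or (
                extra.get("readOnly") and not incl_readonly
            ):
                data.pop(name, None)
        return data

# Creds(user="scott", password="tiger").dump_public() -> {"user": "scott"}
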
diff --git a/test/unit/common/test_version.py b/test/unit/common/test_version.py
new file mode 100644
index 00000000..6dcadb82
--- /dev/null
+++ b/test/unit/common/test_version.py
@@ -0,0 +1,36 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+
+Unit tests for common/_version.py
+
+Tests version string retrieval.
+"""
+
+from common._version import __version__
+
+
+class TestVersion:
+    """Tests for __version__ variable."""
+
+    def test_version_is_string(self):
+        """__version__ should be a string."""
+        assert isinstance(__version__, str)
+
+    def test_version_is_non_empty(self):
+        """__version__ should be non-empty."""
+        assert len(__version__) > 0
+
+    def test_version_format(self):
+        """__version__ should be a valid version string or fallback."""
+        # Version should either be a proper version number or the fallback "0.0.0"
+        # Valid versions can be like "1.0.0", "1.0.0.dev1", "1.3.1.dev128+g867d96f69.d20251126"
+        assert __version__ == "0.0.0" or "." in __version__
+
+    def test_version_no_leading_whitespace(self):
+        """__version__ should not have leading whitespace."""
+        assert __version__ == __version__.lstrip()
+
+    def test_version_no_trailing_whitespace(self):
+        """__version__ should not have trailing whitespace."""
+        assert __version__ == __version__.rstrip()
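
test_version_format above tolerates both a real version string and a "0.0.0" fallback. One common shape for such a module, shown here as an assumption (the repository's actual common/_version.py is not part of this patch, and the distribution name below is hypothetical):

from importlib.metadata import PackageNotFoundError, version

try:
    # Resolve the installed distribution's version; setuptools-scm style
    # dev builds such as 1.3.1.dev128+g867d96f69.d20251126 pass through.
    __version__ = version("ai-optimizer")  # hypothetical distribution name
except PackageNotFoundError:
    # Not installed, e.g. running from a raw source checkout
    __version__ = "0.0.0"
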
diff --git a/test/unit/server/api/v1/test_v1_embed.py b/test/unit/server/api/v1/test_v1_embed.py
index 2a2dcdc8..aef37044 100644
--- a/test/unit/server/api/v1/test_v1_embed.py
+++ b/test/unit/server/api/v1/test_v1_embed.py
@@ -5,8 +5,7 @@
 Unit tests for server/api/v1/embed.py
 Tests for document embedding and vector store endpoints.
 """
-# pylint: disable=protected-access
-# pylint: disable=redefined-outer-name
+# pylint: disable=protected-access,redefined-outer-name
 # Pytest fixtures use parameter injection where fixture names match parameters
 
 from io import BytesIO
@@ -27,14 +26,16 @@
 @pytest.fixture
 def split_embed_mocks():
     """Fixture providing bundled mocks for split_embed tests."""
-    with patch("server.api.v1.embed.utils_oci.get") as mock_oci_get, \
-        patch("server.api.v1.embed.utils_embed.get_temp_directory") as mock_get_temp, \
-        patch("server.api.v1.embed.utils_embed.load_and_split_documents") as mock_load_split, \
-        patch("server.api.v1.embed.utils_models.get_client_embed") as mock_get_embed, \
-        patch("server.api.v1.embed.functions.get_vs_table") as mock_get_vs_table, \
-        patch("server.api.v1.embed.utils_embed.populate_vs") as mock_populate, \
-        patch("server.api.v1.embed.utils_databases.get_client_database") as mock_get_db, \
-        patch("shutil.rmtree") as mock_rmtree:
+    with (
+        patch("server.api.v1.embed.utils_oci.get") as mock_oci_get,
+        patch("server.api.v1.embed.utils_embed.get_temp_directory") as mock_get_temp,
+        patch("server.api.v1.embed.utils_embed.load_and_split_documents") as mock_load_split,
+        patch("server.api.v1.embed.utils_models.get_client_embed") as mock_get_embed,
+        patch("server.api.v1.embed.functions.get_vs_table") as mock_get_vs_table,
+        patch("server.api.v1.embed.utils_embed.populate_vs") as mock_populate,
+        patch("server.api.v1.embed.utils_databases.get_client_database") as mock_get_db,
+        patch("shutil.rmtree") as mock_rmtree,
+    ):
         yield {
             "oci_get": mock_oci_get,
             "get_temp": mock_get_temp,
@@ -50,15 +51,17 @@
 @pytest.fixture
 def refresh_vector_store_mocks():
     """Fixture providing bundled mocks for refresh_vector_store tests."""
-    with patch("server.api.v1.embed.utils_oci.get") as mock_oci_get, \
-        patch("server.api.v1.embed.utils_databases.get_client_database") as mock_get_db, \
-        patch("server.api.v1.embed.utils_embed.get_vector_store_by_alias") as mock_get_vs, \
-        patch("server.api.v1.embed.utils_oci.get_bucket_objects_with_metadata") as mock_get_objects, \
-        patch("server.api.v1.embed.utils_embed.get_processed_objects_metadata") as mock_get_processed, \
-        patch("server.api.v1.embed.utils_oci.detect_changed_objects") as mock_detect_changed, \
-        patch("server.api.v1.embed.utils_embed.get_total_chunks_count") as mock_get_chunks, \
-        patch("server.api.v1.embed.utils_models.get_client_embed") as mock_get_embed, \
-        patch("server.api.v1.embed.utils_embed.refresh_vector_store_from_bucket") as mock_refresh:
+    with (
+        patch("server.api.v1.embed.utils_oci.get") as mock_oci_get,
+        patch("server.api.v1.embed.utils_databases.get_client_database") as mock_get_db,
+        patch("server.api.v1.embed.utils_embed.get_vector_store_by_alias") as mock_get_vs,
+        patch("server.api.v1.embed.utils_oci.get_bucket_objects_with_metadata") as mock_get_objects,
+        patch("server.api.v1.embed.utils_embed.get_processed_objects_metadata") as mock_get_processed,
+        patch("server.api.v1.embed.utils_oci.detect_changed_objects") as mock_detect_changed,
+        patch("server.api.v1.embed.utils_embed.get_total_chunks_count") as mock_get_chunks,
+        patch("server.api.v1.embed.utils_models.get_client_embed") as mock_get_embed,
+        patch("server.api.v1.embed.utils_embed.refresh_vector_store_from_bucket") as mock_refresh,
+    ):
         yield {
             "oci_get": mock_oci_get,
             "get_db": mock_get_db,
@@ -340,6 +343,7 @@ async def test_store_web_file_pdf_success(self, mock_session_class, mock_get_tem
 
         assert result.status_code == 200
 
+
 class TestStoreLocalFile:
     """Tests for the store_local_file endpoint."""
 
@@ -434,9 +438,7 @@ async def test_split_embed_raises_404_when_folder_not_found(self, mock_get_temp,
         assert exc_info.value.status_code == 404
 
     @pytest.mark.asyncio
-    async def test_split_embed_success(
-        self, split_embed_mocks, tmp_path, make_oci_config, make_database
-    ):
+    async def test_split_embed_success(self, split_embed_mocks, tmp_path, make_oci_config, make_database):
         """split_embed should process files and populate vector store."""
         mocks = split_embed_mocks
         mocks["oci_get"].return_value = make_oci_config()
@@ -529,9 +531,7 @@ async def test_split_embed_raises_500_on_generic_exception(
         assert "Unexpected error occurred" in exc_info.value.detail
 
     @pytest.mark.asyncio
-    async def test_split_embed_loads_file_metadata(
-        self, split_embed_mocks, tmp_path, make_oci_config, make_database
-    ):
+    async def test_split_embed_loads_file_metadata(self, split_embed_mocks, tmp_path, make_oci_config, make_database):
         """split_embed should load file metadata when available."""
         mocks = split_embed_mocks
         mocks["oci_get"].return_value = make_oci_config()
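
The fixture hunks above replace backslash continuations with parenthesized context managers, a form the grammar allows from Python 3.10 onward. The same construct in isolation, with contextlib.nullcontext standing in for the patched mocks:

from contextlib import nullcontext

with (
    nullcontext("a") as first,
    nullcontext("b") as second,
):
    # Both managers are entered left to right, exactly as with the
    # backslash-continued form, but without fragile trailing whitespace.
    assert (first, second) == ("a", "b")
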
""" # spell-checker: disable -from unittest.mock import Mock, patch +from unittest.mock import patch import pytest -import oracledb from common import functions -class TestIsSQLAccessible: - """Tests for the is_sql_accessible function""" - - def test_valid_sql_connection_and_query(self): - """Test that a valid SQL connection and query returns (True, '')""" - # Mock the oracledb connection and cursor - mock_cursor = Mock() - mock_cursor.description = [Mock(type=oracledb.DB_TYPE_VARCHAR)] - mock_cursor.fetchmany.return_value = [("row1",), ("row2",), ("row3",)] - - mock_connection = Mock() - mock_connection.__enter__ = Mock(return_value=mock_connection) - mock_connection.__exit__ = Mock(return_value=None) - mock_connection.cursor.return_value.__enter__ = Mock(return_value=mock_cursor) - mock_connection.cursor.return_value.__exit__ = Mock(return_value=None) - - with patch("oracledb.connect", return_value=mock_connection): - ok, msg = functions.is_sql_accessible("testuser/testpass@testdsn", "SELECT text FROM documents") - - assert ok is True, "Expected SQL validation to succeed with valid connection and query" - assert msg == "", f"Expected no error message, got: {msg}" - - def test_invalid_connection_string_format(self): - """Test that an invalid connection string format returns (False, error_msg)""" - ok, msg = functions.is_sql_accessible("invalid_connection_string", "SELECT * FROM table") - - assert ok is False, "Expected SQL validation to fail with invalid connection string" - # The function logs "Wrong connection string" but returns the connection error - assert msg != "", "Expected an error message, got empty string" - # Either the ValueError message or the connection error should be present - assert "connection error" in msg.lower() or "Wrong connection string" in msg, \ - f"Expected connection error or 'Wrong connection string' in error, got: {msg}" - - def test_empty_result_set(self): - """Test that a query returning no rows returns (False, error_msg)""" - mock_cursor = Mock() - mock_cursor.description = [Mock(type=oracledb.DB_TYPE_VARCHAR)] - mock_cursor.fetchmany.return_value = [] # Empty result set - - mock_connection = Mock() - mock_connection.__enter__ = Mock(return_value=mock_connection) - mock_connection.__exit__ = Mock(return_value=None) - mock_connection.cursor.return_value.__enter__ = Mock(return_value=mock_cursor) - mock_connection.cursor.return_value.__exit__ = Mock(return_value=None) - - with patch("oracledb.connect", return_value=mock_connection): - ok, msg = functions.is_sql_accessible("testuser/testpass@testdsn", "SELECT text FROM empty_table") - - assert ok is False, "Expected SQL validation to fail with empty result set" - assert "empty table" in msg, f"Expected 'empty table' in error, got: {msg}" - - def test_multiple_columns_returned(self): - """Test that a query returning multiple columns returns (False, error_msg)""" - mock_cursor = Mock() - mock_cursor.description = [ - Mock(type=oracledb.DB_TYPE_VARCHAR), - Mock(type=oracledb.DB_TYPE_VARCHAR), - ] - mock_cursor.fetchmany.return_value = [("col1", "col2")] - - mock_connection = Mock() - mock_connection.__enter__ = Mock(return_value=mock_connection) - mock_connection.__exit__ = Mock(return_value=None) - mock_connection.cursor.return_value.__enter__ = Mock(return_value=mock_cursor) - mock_connection.cursor.return_value.__exit__ = Mock(return_value=None) - - with patch("oracledb.connect", return_value=mock_connection): - ok, msg = functions.is_sql_accessible("testuser/testpass@testdsn", "SELECT col1, col2 FROM 
table") - - assert ok is False, "Expected SQL validation to fail with multiple columns" - assert "2 columns" in msg, f"Expected '2 columns' in error, got: {msg}" - - def test_invalid_column_type(self): - """Test that a query returning non-VARCHAR column returns (False, error_msg)""" - mock_cursor = Mock() - mock_cursor.description = [Mock(type=oracledb.DB_TYPE_NUMBER)] - mock_cursor.fetchmany.return_value = [(123,)] - - mock_connection = Mock() - mock_connection.__enter__ = Mock(return_value=mock_connection) - mock_connection.__exit__ = Mock(return_value=None) - mock_connection.cursor.return_value.__enter__ = Mock(return_value=mock_cursor) - mock_connection.cursor.return_value.__exit__ = Mock(return_value=None) - - with patch("oracledb.connect", return_value=mock_connection): - ok, msg = functions.is_sql_accessible("testuser/testpass@testdsn", "SELECT id FROM table") - - assert ok is False, "Expected SQL validation to fail with non-VARCHAR column type" - assert "VARCHAR" in msg, f"Expected 'VARCHAR' in error, got: {msg}" - - def test_database_connection_error(self): - """Test that a database connection error returns (False, error_msg)""" - with patch("oracledb.connect", side_effect=oracledb.Error("Connection failed")): - ok, msg = functions.is_sql_accessible("testuser/testpass@testdsn", "SELECT text FROM table") - - assert ok is False, "Expected SQL validation to fail with connection error" - assert "connection error" in msg.lower(), f"Expected 'connection error' in message, got: {msg}" - - def test_empty_connection_string(self): - """Test that empty connection string returns (False, '')""" - ok, msg = functions.is_sql_accessible("", "SELECT * FROM table") - - assert ok is False, "Expected SQL validation to fail with empty connection string" - assert msg == "", f"Expected empty error message, got: {msg}" - - def test_empty_query(self): - """Test that empty query returns (False, '')""" - ok, msg = functions.is_sql_accessible("testuser/testpass@testdsn", "") - - assert ok is False, "Expected SQL validation to fail with empty query" - assert msg == "", f"Expected empty error message, got: {msg}" - - def test_nvarchar_column_type_accepted(self): - """Test that NVARCHAR column type is accepted as valid""" - mock_cursor = Mock() - mock_cursor.description = [Mock(type=oracledb.DB_TYPE_NVARCHAR)] - mock_cursor.fetchmany.return_value = [("text1",), ("text2",)] - - mock_connection = Mock() - mock_connection.__enter__ = Mock(return_value=mock_connection) - mock_connection.__exit__ = Mock(return_value=None) - mock_connection.cursor.return_value.__enter__ = Mock(return_value=mock_cursor) - mock_connection.cursor.return_value.__exit__ = Mock(return_value=None) - - with patch("oracledb.connect", return_value=mock_connection): - ok, msg = functions.is_sql_accessible("testuser/testpass@testdsn", "SELECT ntext FROM table") - - assert ok is True, "Expected SQL validation to succeed with NVARCHAR column type" - assert msg == "", f"Expected no error message, got: {msg}" - - class TestFileSourceDataSQLValidation: """ Tests for FileSourceData.is_valid() method with SQL source From ad1293db44f29d5120342e447103da853c63f2eb Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Sun, 30 Nov 2025 12:09:33 +0000 Subject: [PATCH 09/20] Migrated client tests --- test/integration/client/__init__.py | 0 test/integration/client/conftest.py | 432 ++++++++++++ test/integration/client/content/__init__.py | 0 .../client/content/config/__init__.py | 0 .../client/content/config/tabs/__init__.py | 0 
From ad1293db44f29d5120342e447103da853c63f2eb Mon Sep 17 00:00:00 2001
From: gotsysdba
Date: Sun, 30 Nov 2025 12:09:33 +0000
Subject: [PATCH 09/20] Migrated client tests

---
 test/integration/client/__init__.py           |   0
 test/integration/client/conftest.py           | 432 ++++++++++++
 test/integration/client/content/__init__.py   |   0
 .../client/content/config/__init__.py         |   0
 .../client/content/config/tabs/__init__.py    |   0
 .../content/config/tabs/test_databases.py     |  53 +-
 .../client}/content/config/tabs/test_mcp.py   |   0
 .../content/config/tabs/test_models.py        |  12 +-
 .../client}/content/config/tabs/test_oci.py   |   0
 .../content/config/tabs/test_settings.py      |   0
 .../client}/content/config/test_config.py     |   6 +-
 .../client}/content/test_api_server.py        |  20 +-
 .../client}/content/test_chatbot.py           |  13 +-
 .../client/content/test_testbed.py            |  32 +
 .../client/content/tools/__init__.py          |   0
 .../client/content/tools/tabs/__init__.py     |   0
 .../content/tools/tabs/test_prompt_eng.py     |   0
 .../content/tools/tabs/test_split_embed.py    |   8 +-
 .../client}/content/tools/test_tools.py       |   2 +-
 test/integration/client/utils/__init__.py     |   0
 .../client}/utils/test_st_footer.py           |   0
 .../client}/utils/test_vs_options.py          |   0
 test/shared_fixtures.py                       |  52 ++
 test/unit/client/__init__.py                  |   0
 test/unit/client/conftest.py                  |  44 ++
 test/unit/client/content/__init__.py          |   0
 test/unit/client/content/config/__init__.py   |   0
 .../client/content/config/tabs/__init__.py    |   0
 .../content/config/tabs/test_mcp_unit.py      |   0
 .../content/config/tabs/test_models_unit.py   |   0
 .../unit/client}/content/test_chatbot_unit.py |  29 +-
 .../unit/client}/content/test_testbed_unit.py |   0
 test/unit/client/content/tools/__init__.py    |   0
 .../client/content/tools/tabs/__init__.py     |   0
 .../tools/tabs/test_split_embed_unit.py       |   0
 test/unit/client/utils/__init__.py            |   0
 .../unit/client}/utils/test_client_unit.py    |   0
 .../unit/client}/utils/test_st_common_unit.py |   0
 .../client}/utils/test_vs_options_unit.py     |   0
 test/unit/server/api/conftest.py              |   5 +
 tests/README.md                               |  77 --
 .../integration/content/test_testbed.py       | 665 ------------------
 .../integration/utils/test_st_common.py       |   9 -
 tests/common/test_functions_sql.py            | 139 ----
 tests/conftest.py                             | 509 --------------
 45 files changed, 637 insertions(+), 1470 deletions(-)
 create mode 100644 test/integration/client/__init__.py
 create mode 100644 test/integration/client/conftest.py
 create mode 100644 test/integration/client/content/__init__.py
 create mode 100644 test/integration/client/content/config/__init__.py
 create mode 100644 test/integration/client/content/config/tabs/__init__.py
 rename {tests/client/integration => test/integration/client}/content/config/tabs/test_databases.py (84%)
 rename {tests/client/integration => test/integration/client}/content/config/tabs/test_mcp.py (100%)
 rename {tests/client/integration => test/integration/client}/content/config/tabs/test_models.py (98%)
 rename {tests/client/integration => test/integration/client}/content/config/tabs/test_oci.py (100%)
 rename {tests/client/integration => test/integration/client}/content/config/tabs/test_settings.py (100%)
 rename {tests/client/integration => test/integration/client}/content/config/test_config.py (98%)
 rename {tests/client/integration => test/integration/client}/content/test_api_server.py (73%)
 rename {tests/client/integration => test/integration/client}/content/test_chatbot.py (98%)
 create mode 100644 test/integration/client/content/test_testbed.py
 create mode 100644 test/integration/client/content/tools/__init__.py
 create mode 100644 test/integration/client/content/tools/tabs/__init__.py
 rename {tests/client/integration => test/integration/client}/content/tools/tabs/test_prompt_eng.py (100%)
 rename {tests/client/integration => test/integration/client}/content/tools/tabs/test_split_embed.py (99%)
 rename {tests/client/integration => test/integration/client}/content/tools/test_tools.py (98%)
 create mode 100644 test/integration/client/utils/__init__.py
 rename {tests/client/integration => test/integration/client}/utils/test_st_footer.py (100%)
 rename {tests/client/integration => test/integration/client}/utils/test_vs_options.py (100%)
 create mode 100644 test/unit/client/__init__.py
 create mode 100644 test/unit/client/conftest.py
 create mode 100644 test/unit/client/content/__init__.py
 create mode 100644 test/unit/client/content/config/__init__.py
 create mode 100644 test/unit/client/content/config/tabs/__init__.py
 rename {tests/client/unit => test/unit/client}/content/config/tabs/test_mcp_unit.py (100%)
 rename {tests/client/unit => test/unit/client}/content/config/tabs/test_models_unit.py (100%)
 rename {tests/client/unit => test/unit/client}/content/test_chatbot_unit.py (96%)
 rename {tests/client/unit => test/unit/client}/content/test_testbed_unit.py (100%)
 create mode 100644 test/unit/client/content/tools/__init__.py
 create mode 100644 test/unit/client/content/tools/tabs/__init__.py
 rename {tests/client/unit => test/unit/client}/content/tools/tabs/test_split_embed_unit.py (100%)
 create mode 100644 test/unit/client/utils/__init__.py
 rename {tests/client/unit => test/unit/client}/utils/test_client_unit.py (100%)
 rename {tests/client/unit => test/unit/client}/utils/test_st_common_unit.py (100%)
 rename {tests/client/unit => test/unit/client}/utils/test_vs_options_unit.py (100%)
 delete mode 100644 tests/README.md
 delete mode 100644 tests/client/integration/content/test_testbed.py
 delete mode 100644 tests/client/integration/utils/test_st_common.py
 delete mode 100644 tests/common/test_functions_sql.py
 delete mode 100644 tests/conftest.py
+""" + +# pylint: disable=redefined-outer-name + +import os +import sys +import time +import socket +import subprocess +from contextlib import contextmanager + +# Re-export shared fixtures for pytest discovery +from test.shared_fixtures import ( # noqa: F401 pylint: disable=unused-import + make_database, + make_model, + make_oci_config, + make_ll_settings, + make_settings, + make_configuration, + TEST_DB_USER, + TEST_DB_PASSWORD, + TEST_DB_DSN, + TEST_AUTH_TOKEN, + sample_vector_store_data, + sample_vector_store_data_alt, + sample_vector_stores_list, +) + +# Import TEST_DB_CONFIG for use in helper functions +# Note: db_container, db_connection, db_transaction fixtures are inherited +# from test/conftest.py - do not re-import here to avoid multiple containers +from test.db_fixtures import TEST_DB_CONFIG + +import pytest +import requests + +# Lazy import to avoid circular imports - stored in module-level variable +_app_test_class = None + + +def get_app_test(): + """Lazy import of Streamlit's AppTest.""" + global _app_test_class # pylint: disable=global-statement + if _app_test_class is None: + from streamlit.testing.v1 import AppTest # pylint: disable=import-outside-toplevel + + _app_test_class = AppTest + return _app_test_class + + +################################################# +# Environment Setup for Client Tests +################################################# + +# Clear environment variables that might interfere with tests +API_VARS = ["API_SERVER_KEY", "API_SERVER_URL", "API_SERVER_PORT"] +DB_VARS = ["DB_USERNAME", "DB_PASSWORD", "DB_DSN", "DB_WALLET_PASSWORD", "TNS_ADMIN"] +MODEL_VARS = [ + "ON_PREM_OLLAMA_URL", + "ON_PREM_HF_URL", + "OPENAI_API_KEY", + "PPLX_API_KEY", + "COHERE_API_KEY", +] + +for env_var in [ + *API_VARS, + *DB_VARS, + *MODEL_VARS, + *[var for var in os.environ if var.startswith("OCI_")], +]: + os.environ.pop(env_var, None) + +# Test client configuration +TEST_CLIENT = "client_test" +TEST_SERVER_PORT = 8015 + +# Set up test environment +os.environ["CONFIG_FILE"] = "/non/existent/path/config.json" +os.environ["OCI_CLI_CONFIG_FILE"] = "/non/existent/path" +os.environ["API_SERVER_KEY"] = TEST_AUTH_TOKEN +os.environ["API_SERVER_URL"] = "http://localhost" +os.environ["API_SERVER_PORT"] = str(TEST_SERVER_PORT) + + +################################################# +# Fixtures for Client Tests +################################################# + + +@pytest.fixture(name="auth_headers") +def _auth_headers(): + """Return common header configurations for testing.""" + return { + "no_auth": {}, + "invalid_auth": {"Authorization": "Bearer invalid-token", "client": TEST_CLIENT}, + "valid_auth": {"Authorization": f"Bearer {TEST_AUTH_TOKEN}", "client": TEST_CLIENT}, + } + + +@pytest.fixture(scope="session") +def app_server(request): + """Start the FastAPI server for Streamlit and wait for it to be ready.""" + + def is_port_in_use(port): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 + + config_file = getattr(request, "param", None) + + # If config_file is passed, include it in the subprocess command + cmd = ["python", "launch_server.py"] + if config_file: + cmd.extend(["-c", config_file]) + + # Create environment with explicit port setting + # This ensures the server uses the correct port even when other conftest files + # may have modified os.environ during test collection + env = os.environ.copy() + env["API_SERVER_PORT"] = str(TEST_SERVER_PORT) + env["API_SERVER_KEY"] = TEST_AUTH_TOKEN + env["CONFIG_FILE"] = 
"/non/existent/path/config.json" + env["OCI_CLI_CONFIG_FILE"] = "/non/existent/path" + + server_process = subprocess.Popen(cmd, cwd="src", env=env) # pylint: disable=consider-using-with + + try: + # Wait for server to be ready (up to 30 seconds) + max_wait = 30 + start_time = time.time() + while not is_port_in_use(TEST_SERVER_PORT): + if time.time() - start_time > max_wait: + raise TimeoutError("Server failed to start within 30 seconds") + time.sleep(0.5) + + yield server_process + + finally: + # Terminate the server after tests + server_process.terminate() + server_process.wait() + + +@pytest.fixture +def app_test(auth_headers): + """Establish Streamlit State for Client to Operate. + + This fixture mimics what launch_client.py does in init_configs_state(), + loading the full configuration including all *_configs (database_configs, model_configs, + oci_configs, etc.) into session state, just like the real application does. + """ + app_test_cls = get_app_test() + + def _app_test(page): + # Convert relative paths like "../src/client/..." to absolute paths + # Tests use paths relative to old structure, convert to absolute + if page.startswith("../src/"): + # Get project root (test/integration/client -> project root) + project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + page = os.path.join(project_root, page.replace("../src/", "src/")) + at = app_test_cls.from_file(page, default_timeout=30) + # Use constants directly instead of os.environ to avoid issues when + # other conftest files pop these variables during test collection + at.session_state.server = { + "key": TEST_AUTH_TOKEN, + "url": "http://localhost", + "port": TEST_SERVER_PORT, + "control": True, + } + server_url = f"{at.session_state.server['url']}:{at.session_state.server['port']}" + + # First, create the client (POST) - this initializes client settings on the server + # If client already exists (409), that's fine - we just need it to exist + requests.post( + url=f"{server_url}/v1/settings", + headers=auth_headers["valid_auth"], + params={"client": TEST_CLIENT}, + timeout=120, + ) + + # Load full config like launch_client.py does in init_configs_state() + full_config = requests.get( + url=f"{server_url}/v1/settings", + headers=auth_headers["valid_auth"], + params={ + "client": TEST_CLIENT, + "full_config": True, + "incl_sensitive": True, + "incl_readonly": True, + }, + timeout=120, + ).json() + # Load all config items into session state + for key, value in full_config.items(): + at.session_state[key] = value + return at + + return _app_test + + +################################################# +# Helper Functions +################################################# + + +def setup_test_database(app_test_instance): + """Configure and connect to test database for integration tests. + + This helper function: + 1. Updates database config with test credentials + 2. Patches the database on the server + 3. 
+
+
+#################################################
+# Helper Functions
+#################################################
+
+
+def setup_test_database(app_test_instance):
+    """Configure and connect to test database for integration tests.
+
+    This helper function:
+    1. Updates database config with test credentials
+    2. Patches the database on the server
+    3. Reloads full config to get updated database status
+
+    Args:
+        app_test_instance: The AppTest instance from app_test fixture
+
+    Returns:
+        The updated AppTest instance with database configured
+    """
+    if not app_test_instance.session_state.database_configs:
+        return app_test_instance
+
+    # Update database config with test credentials
+    db_config = app_test_instance.session_state.database_configs[0]
+    db_config["user"] = TEST_DB_CONFIG["db_username"]
+    db_config["password"] = TEST_DB_CONFIG["db_password"]
+    db_config["dsn"] = TEST_DB_CONFIG["db_dsn"]
+
+    # Update the database on the server to establish connection
+    server_url = app_test_instance.session_state.server["url"]
+    server_port = app_test_instance.session_state.server["port"]
+    server_key = app_test_instance.session_state.server["key"]
+    db_name = db_config["name"]
+
+    response = requests.patch(
+        url=f"{server_url}:{server_port}/v1/databases/{db_name}",
+        headers={"Authorization": f"Bearer {server_key}", "client": TEST_CLIENT},
+        json={
+            "user": db_config["user"],
+            "password": db_config["password"],
+            "dsn": db_config["dsn"],
+        },
+        timeout=120,
+    )
+
+    if response.status_code != 200:
+        raise RuntimeError(f"Failed to update database: {response.text}")
+
+    # Reload the full config to get the updated database status
+    full_config = requests.get(
+        url=f"{server_url}:{server_port}/v1/settings",
+        headers={"Authorization": f"Bearer {server_key}", "client": TEST_CLIENT},
+        params={
+            "client": TEST_CLIENT,
+            "full_config": True,
+            "incl_sensitive": True,
+            "incl_readonly": True,
+        },
+        timeout=120,
+    ).json()
+
+    # Update session state with refreshed config
+    for key, value in full_config.items():
+        app_test_instance.session_state[key] = value
+
+    return app_test_instance
+
+
+def enable_test_models(app_test_instance):
+    """Enable at least one LL model for testing.
+
+    Args:
+        app_test_instance: The AppTest instance from app_test fixture
+
+    Returns:
+        The updated AppTest instance with models enabled
+    """
+    for model in app_test_instance.session_state.model_configs:
+        if model["type"] == "ll":
+            model["enabled"] = True
+            break
+
+    return app_test_instance
+
+
+def enable_test_embed_models(app_test_instance):
+    """Enable at least one embedding model for testing.
+
+    Args:
+        app_test_instance: The AppTest instance from app_test fixture
+
+    Returns:
+        The updated AppTest instance with embed models enabled
+    """
+    for model in app_test_instance.session_state.model_configs:
+        if model["type"] == "embed":
+            model["enabled"] = True
+            break
+
+    return app_test_instance
+
+
+def create_tabs_mock(monkeypatch):
+    """Create a mock for st.tabs that captures what tabs are created.
+
+    This is a helper function to reduce code duplication in tests that need
+    to verify which tabs are created by the application.
+
+    Args:
+        monkeypatch: pytest monkeypatch fixture
+
+    Returns:
+        A list that will be populated with tab names as they are created
+    """
+    import streamlit as st  # pylint: disable=import-outside-toplevel
+
+    tabs_created = []
+    original_tabs = st.tabs
+
+    def mock_tabs(tab_list):
+        tabs_created.extend(tab_list)
+        return original_tabs(tab_list)
+
+    monkeypatch.setattr(st, "tabs", mock_tabs)
+    return tabs_created
+
+
+@contextmanager
+def temporary_sys_path(path):
+    """Temporarily add a path to sys.path and remove it when done.
+
+    This context manager is useful for tests that need to temporarily modify
+    the Python path to import modules from specific locations.
+
+    Args:
+        path: Path to add to sys.path
+
+    Yields:
+        None
+    """
+    sys.path.insert(0, path)
+    try:
+        yield
+    finally:
+        if path in sys.path:
+            sys.path.remove(path)
+
+
+def run_streamlit_test(app_test_instance, run=True):
+    """Helper to run a Streamlit test and verify no exceptions.
+
+    This helper reduces code duplication in tests that follow the pattern:
+    1. Run the app test
+    2. Verify no exceptions occurred
+
+    Args:
+        app_test_instance: The AppTest instance to run
+        run: Whether to run the test (default: True)
+
+    Returns:
+        The AppTest instance (run or not based on the run parameter)
+    """
+    if run:
+        app_test_instance = app_test_instance.run()
+        assert not app_test_instance.exception
+    return app_test_instance
+
+
+def run_page_with_models_enabled(app_server, app_test_func, st_file):
+    """Helper to run a Streamlit page with models enabled and verify no exceptions.
+
+    Common test pattern that:
+    1. Verifies app_server is available
+    2. Creates app test instance
+    3. Enables test models
+    4. Runs the test
+    5. Verifies no exceptions occurred
+
+    Args:
+        app_server: The app_server fixture (asserted not None)
+        app_test_func: The app_test fixture function
+        st_file: The Streamlit file path to test
+
+    Returns:
+        The AppTest instance after running
+    """
+    assert app_server is not None
+    at = app_test_func(st_file)
+    at = enable_test_models(at)
+    at = at.run()
+    assert not at.exception
+    return at
+
+
+def get_test_db_payload():
+    """Get standard test database payload for integration tests.
+
+    Returns:
+        dict: Database configuration payload with test credentials
+    """
+    return {
+        "user": TEST_DB_CONFIG["db_username"],
+        "password": TEST_DB_CONFIG["db_password"],
+        "dsn": TEST_DB_CONFIG["db_dsn"],
+    }
+
+
+def get_sample_oci_config():
+    """Get sample OCI configuration for unit tests.
+
+    Returns:
+        OracleCloudSettings: Sample OCI configuration object
+    """
+    from common.schema import OracleCloudSettings  # pylint: disable=import-outside-toplevel
+
+    return OracleCloudSettings(
+        auth_profile="DEFAULT",
+        compartment_id="ocid1.compartment.oc1..test",
+        genai_region="us-ashburn-1",
+        user="ocid1.user.oc1..testuser",
+        fingerprint="test-fingerprint",
+        tenancy="ocid1.tenancy.oc1..testtenant",
+        key_file="/path/to/key.pem",
+    )
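
Taken together, the conftest above is consumed roughly like this in a test module (an assumed minimal illustration; the migrated tests below follow the same shape, and the page path is resolved to src/ by the fixture):

from test.integration.client.conftest import enable_test_models

ST_FILE = "../src/client/content/chatbot.py"  # example page path

def test_page_renders(app_server, app_test):
    assert app_server is not None  # session-scoped FastAPI server is running
    at = app_test(ST_FILE)         # AppTest with full config pre-loaded into session state
    at = enable_test_models(at)    # flip one "ll" model to enabled
    at = at.run()
    assert not at.exception
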
diff --git a/test/integration/client/content/__init__.py b/test/integration/client/content/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/integration/client/content/config/__init__.py b/test/integration/client/content/config/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/integration/client/content/config/tabs/__init__.py b/test/integration/client/content/config/tabs/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/client/integration/content/config/tabs/test_databases.py b/test/integration/client/content/config/tabs/test_databases.py
similarity index 84%
rename from tests/client/integration/content/config/tabs/test_databases.py
rename to test/integration/client/content/config/tabs/test_databases.py
index ab0d33ec..8b5e937b 100644
--- a/tests/client/integration/content/config/tabs/test_databases.py
+++ b/test/integration/client/content/config/tabs/test_databases.py
@@ -5,10 +5,9 @@
 """
 # spell-checker: disable
 
+from test.db_fixtures import TEST_DB_CONFIG
 import pytest
 
-from conftest import TEST_CONFIG
-
 
 #############################################################################
 # Test Streamlit UI
@@ -53,9 +52,9 @@ def test_no_database(self, app_server, app_test):
         assert app_server is not None
         at = app_test(self.ST_FILE).run()
         assert at.session_state.database_configs is not None
-        at.text_input(key="database_user").set_value(TEST_CONFIG["db_username"]).run()
-        at.text_input(key="database_password").set_value(TEST_CONFIG["db_password"]).run()
-        at.text_input(key="database_dsn").set_value(TEST_CONFIG["db_dsn"]).run()
+        at.text_input(key="database_user").set_value(TEST_DB_CONFIG["db_username"]).run()
+        at.text_input(key="database_password").set_value(TEST_DB_CONFIG["db_password"]).run()
+        at.text_input(key="database_dsn").set_value(TEST_DB_CONFIG["db_dsn"]).run()
         at.button(key="save_database").click().run()
         assert at.error[0].value == "Current Status: Disconnected"
 
@@ -67,9 +66,9 @@ def test_connected(self, app_server, app_test, db_container):
         assert db_container is not None
         at = app_test(self.ST_FILE).run()
         assert at.session_state.database_configs is not None
-        at.text_input(key="database_user").set_value(TEST_CONFIG["db_username"]).run()
-        at.text_input(key="database_password").set_value(TEST_CONFIG["db_password"]).run()
-        at.text_input(key="database_dsn").set_value(TEST_CONFIG["db_dsn"]).run()
+        at.text_input(key="database_user").set_value(TEST_DB_CONFIG["db_username"]).run()
+        at.text_input(key="database_password").set_value(TEST_DB_CONFIG["db_password"]).run()
+        at.text_input(key="database_dsn").set_value(TEST_DB_CONFIG["db_dsn"]).run()
         at.button(key="save_database").click().run()
         assert at.success[0].value == "Current Status: Connected"
 
@@ -77,17 +76,17 @@ def test_connected(self, app_server, app_test, db_container):
         at.button(key="save_database").click().run()
         assert at.toast[0].value == "No changes detected." and at.toast[0].icon == "ℹ️"
-        assert at.session_state.database_configs[0]["user"] == TEST_CONFIG["db_username"]
-        assert at.session_state.database_configs[0]["password"] == TEST_CONFIG["db_password"]
-        assert at.session_state.database_configs[0]["dsn"] == TEST_CONFIG["db_dsn"]
+        assert at.session_state.database_configs[0]["user"] == TEST_DB_CONFIG["db_username"]
+        assert at.session_state.database_configs[0]["password"] == TEST_DB_CONFIG["db_password"]
+        assert at.session_state.database_configs[0]["dsn"] == TEST_DB_CONFIG["db_dsn"]
 
     test_cases = [
         pytest.param(
             {
                 "alias": "DEFAULT",
                 "username": "",
-                "password": TEST_CONFIG["db_password"],
-                "dsn": TEST_CONFIG["db_dsn"],
+                "password": TEST_DB_CONFIG["db_password"],
+                "dsn": TEST_DB_CONFIG["db_dsn"],
                 "expected": "Update Failed - Database: DEFAULT missing connection details.",
             },
            id="missing_input",
        ),
@@ -96,8 +95,8 @@ def test_connected(self, app_server, app_test, db_container):
             {
                 "alias": "DEFAULT",
                 "username": "ADMIN",
-                "password": TEST_CONFIG["db_password"],
-                "dsn": TEST_CONFIG["db_dsn"],
+                "password": TEST_DB_CONFIG["db_password"],
+                "dsn": TEST_DB_CONFIG["db_dsn"],
                 "expected": "invalid credential or not authorized",
             },
             id="bad_user",
@@ -105,9 +104,9 @@ def test_connected(self, app_server, app_test, db_container):
         pytest.param(
             {
                 "alias": "DEFAULT",
-                "username": TEST_CONFIG["db_username"],
+                "username": TEST_DB_CONFIG["db_username"],
                 "password": "Wr0ng_P4ssW0rd",
-                "dsn": TEST_CONFIG["db_dsn"],
+                "dsn": TEST_DB_CONFIG["db_dsn"],
                 "expected": "invalid credential or not authorized",
             },
             id="bad_password",
@@ -115,8 +114,8 @@ def test_connected(self, app_server, app_test, db_container):
         pytest.param(
             {
                 "alias": "DEFAULT",
-                "username": TEST_CONFIG["db_username"],
-                "password": TEST_CONFIG["db_password"],
+                "username": TEST_DB_CONFIG["db_username"],
+                "password": TEST_DB_CONFIG["db_password"],
                 "dsn": "//localhost:1521/WRONG_TP",
                 "expected": "cannot connect to database",
             },
             id="bad_dsn",
@@ -125,8 +124,8 @@ def test_connected(self, app_server, app_test, db_container):
         pytest.param(
             {
                 "alias": "DEFAULT",
-                "username": TEST_CONFIG["db_username"],
-                "password": TEST_CONFIG["db_password"],
+                "username": TEST_DB_CONFIG["db_username"],
+                "password": TEST_DB_CONFIG["db_password"],
                 "dsn": "WRONG_TP",
                 "expected": "DPY-4",
             },
@@ -143,9 +142,9 @@ def test_disconnected(self, app_server, app_test, db_container, test_case):
         assert at.session_state.database_configs is not None
 
         # Input and save good database
-        at.text_input(key="database_user").set_value(TEST_CONFIG["db_username"]).run()
-        at.text_input(key="database_password").set_value(TEST_CONFIG["db_password"]).run()
-        at.text_input(key="database_dsn").set_value(TEST_CONFIG["db_dsn"]).run()
+        at.text_input(key="database_user").set_value(TEST_DB_CONFIG["db_username"]).run()
+        at.text_input(key="database_password").set_value(TEST_DB_CONFIG["db_password"]).run()
+        at.text_input(key="database_dsn").set_value(TEST_DB_CONFIG["db_dsn"]).run()
         at.button(key="save_database").click().run()
 
         # Update Database Details and Save
@@ -161,9 +160,9 @@ def test_disconnected(self, app_server, app_test, db_container, test_case):
         # Due to the connection error, the settings should NOT be updated and be set
         # to previous successful test connection; connected will be False for error handling
         assert at.session_state.database_configs[0]["name"] == "DEFAULT"
-        assert at.session_state.database_configs[0]["user"] == TEST_CONFIG["db_username"]
-        assert at.session_state.database_configs[0]["password"] == TEST_CONFIG["db_password"]
-        assert at.session_state.database_configs[0]["dsn"] == TEST_CONFIG["db_dsn"]
+        assert at.session_state.database_configs[0]["user"] == TEST_DB_CONFIG["db_username"]
+        assert at.session_state.database_configs[0]["password"] == TEST_DB_CONFIG["db_password"]
+        assert at.session_state.database_configs[0]["dsn"] == TEST_DB_CONFIG["db_dsn"]
         assert at.session_state.database_configs[0]["wallet_password"] is None
         assert at.session_state.database_configs[0]["wallet_location"] is None
         assert at.session_state.database_configs[0]["config_dir"] is not None
diff --git a/tests/client/integration/content/config/tabs/test_mcp.py b/test/integration/client/content/config/tabs/test_mcp.py
similarity index 100%
rename from tests/client/integration/content/config/tabs/test_mcp.py
rename to test/integration/client/content/config/tabs/test_mcp.py
diff --git a/tests/client/integration/content/config/tabs/test_models.py b/test/integration/client/content/config/tabs/test_models.py
similarity index 98%
rename from tests/client/integration/content/config/tabs/test_models.py
rename to test/integration/client/content/config/tabs/test_models.py
index f683b937..daec52d1 100644
--- a/tests/client/integration/content/config/tabs/test_models.py
+++ b/test/integration/client/content/config/tabs/test_models.py
@@ -7,7 +7,8 @@
 
 import os
 from unittest.mock import MagicMock, patch
-from conftest import temporary_sys_path
+
+from test.integration.client.conftest import temporary_sys_path
 
 # Streamlit File
 ST_FILE = "../src/client/content/config/tabs/models.py"
@@ -74,9 +75,6 @@ def test_model_display_both_types(self, app_server, app_test):
         assert hasattr(at.session_state, "model_configs")
         assert at.session_state.model_configs is not None
 
-        # Check that we have models of different types
-        # model_types = {model['type'] for model in at.session_state.model_configs}
-
         # Should have sections for both types even if no models exist
         headers = at.get("header")
         header_text = [h.value for h in headers]
@@ -503,7 +501,7 @@ class TestModelCRUD:
 
     def test_create_model_success(self, monkeypatch):
         """Test creating a new model"""
-        with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../../src")):
+        with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../../../../src")):
             from client.content.config.tabs import models
             from client.utils import api_call
             import streamlit as st
@@ -533,7 +531,7 @@ def test_create_model_success(self, monkeypatch):
 
     def test_patch_model_success(self, monkeypatch):
         """Test patching an existing model"""
-        with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../../src")):
+        with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../../../../src")):
             from client.content.config.tabs import models
             from client.utils import api_call
             import streamlit as st
@@ -577,7 +575,7 @@ def test_patch_model_success(self, monkeypatch):
 
     def test_delete_model_success(self, monkeypatch):
         """Test deleting a model"""
-        with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../../src")):
+        with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../../../../src")):
             from client.content.config.tabs import models
             from client.utils import api_call
             import streamlit as st
diff --git a/tests/client/integration/content/config/tabs/test_oci.py b/test/integration/client/content/config/tabs/test_oci.py
similarity index 100%
rename from tests/client/integration/content/config/tabs/test_oci.py
rename to test/integration/client/content/config/tabs/test_oci.py
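
In the test_models.py hunks above, the temporary_sys_path argument grows from four to six ../ segments because the module moved deeper into the tree (test/integration/client/content/config/tabs/ sits six directories below the repository root). A depth-independent alternative, shown as a hypothetical helper rather than part of this patch:

import os

def src_dir() -> str:
    """Walk upward from this file until a directory containing src/ is found."""
    here = os.path.dirname(os.path.abspath(__file__))
    while not os.path.isdir(os.path.join(here, "src")):
        parent = os.path.dirname(here)
        if parent == here:  # reached the filesystem root without a match
            raise FileNotFoundError("no 'src' directory above " + __file__)
        here = parent
    return os.path.join(here, "src")
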
diff --git a/tests/client/integration/content/config/tabs/test_settings.py b/test/integration/client/content/config/tabs/test_settings.py
similarity index 100%
rename from tests/client/integration/content/config/tabs/test_settings.py
rename to test/integration/client/content/config/tabs/test_settings.py
diff --git a/tests/client/integration/content/config/test_config.py b/test/integration/client/content/config/test_config.py
similarity index 98%
rename from tests/client/integration/content/config/test_config.py
rename to test/integration/client/content/config/test_config.py
index 532cabd5..6b997f61 100644
--- a/tests/client/integration/content/config/test_config.py
+++ b/test/integration/client/content/config/test_config.py
@@ -5,8 +5,8 @@
 """
 # spell-checker: disable
 
+from test.integration.client.conftest import create_tabs_mock, run_streamlit_test
 import streamlit as st
-from conftest import create_tabs_mock, run_streamlit_test
 
 
 #############################################################################
@@ -134,9 +134,11 @@ def test_get_functions_called(self, app_server, app_test, monkeypatch):
         # Create mock factory to reduce local variables
         def create_mock(module, func_name):
             original = getattr(module, func_name)
+
             def mock(*args, **kwargs):
                 calls[func_name] = True
                 return original(*args, **kwargs)
+
             return mock
 
         # Set up all mocks
@@ -145,7 +147,7 @@ def mock(*args, **kwargs):
             (databases, "get_databases"),
             (models, "get_models"),
             (oci, "get_oci"),
-            (mcp, "get_mcp")
+            (mcp, "get_mcp"),
         ]:
             monkeypatch.setattr(module, func_name, create_mock(module, func_name))
 
diff --git a/tests/client/integration/content/test_api_server.py b/test/integration/client/content/test_api_server.py
similarity index 73%
rename from tests/client/integration/content/test_api_server.py
rename to test/integration/client/content/test_api_server.py
index 3b4d1040..85cb7b11 100644
--- a/tests/client/integration/content/test_api_server.py
+++ b/test/integration/client/content/test_api_server.py
@@ -31,15 +31,27 @@ def test_copy_client_settings_success(self, app_test, app_server):
         # Store original value for cleanup
         original_auth_profile = at.session_state.client_settings["oci"]["auth_profile"]
 
-        # Check that Server/Client Identical
-        assert at.session_state.client_settings == at.session_state.server_settings
+        def settings_equal_ignoring_client(s1, s2):
+            """Compare settings while ignoring the 'client' field which is expected to differ."""
+            s1_copy = {k: v for k, v in s1.items() if k != "client"}
+            s2_copy = {k: v for k, v in s2.items() if k != "client"}
+            return s1_copy == s2_copy
+
+        # Check that Server/Client Identical (excluding 'client' field)
+        assert settings_equal_ignoring_client(
+            at.session_state.client_settings, at.session_state.server_settings
+        )
 
         # Update Client Settings
         at.session_state.client_settings["oci"]["auth_profile"] = "TESTING"
-        assert at.session_state.client_settings != at.session_state.server_settings
+        assert not settings_equal_ignoring_client(
+            at.session_state.client_settings, at.session_state.server_settings
+        )
         assert at.session_state.server_settings["oci"]["auth_profile"] != "TESTING"
 
         at.button(key="copy_client_settings").click().run()
 
         # Validate settings have been copied
-        assert at.session_state.client_settings == at.session_state.server_settings
+        assert settings_equal_ignoring_client(
+            at.session_state.client_settings, at.session_state.server_settings
+        )
         assert at.session_state.server_settings["oci"]["auth_profile"] == "TESTING"
 
         # Clean up: restore original value both in session state and on server to avoid polluting other tests
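
The inline settings_equal_ignoring_client helper above generalizes to any volatile key. A compact standalone version of the same idea (hypothetical, not part of the patch):

def dicts_equal_ignoring(d1: dict, d2: dict, ignore: frozenset = frozenset({"client"})) -> bool:
    """Compare two mappings while skipping keys expected to differ."""
    return {k: v for k, v in d1.items() if k not in ignore} == {
        k: v for k, v in d2.items() if k not in ignore
    }

assert dicts_equal_ignoring({"client": "a", "x": 1}, {"client": "b", "x": 1})
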
server to avoid polluting other tests
diff --git a/tests/client/integration/content/test_chatbot.py b/test/integration/client/content/test_chatbot.py
similarity index 98%
rename from tests/client/integration/content/test_chatbot.py
rename to test/integration/client/content/test_chatbot.py
index 5e0602a3..c7c46ab9 100644
--- a/tests/client/integration/content/test_chatbot.py
+++ b/test/integration/client/content/test_chatbot.py
@@ -5,7 +5,7 @@
 """
 # spell-checker: disable
 
-from conftest import enable_test_models
+from test.integration.client.conftest import enable_test_models, run_page_with_models_enabled
 
 
 #############################################################################
@@ -27,16 +27,7 @@ def test_disabled(self, app_server, app_test):
 
     def test_page_loads_with_enabled_model(self, app_server, app_test):
         """Test that chatbot page loads successfully when a language model is enabled"""
-        assert app_server is not None
-        at = app_test(self.ST_FILE)
-
-        # Enable at least one language model
-        at = enable_test_models(at)
-
-        at = at.run()
-
-        # Verify page loaded without errors
-        assert not at.exception
+        run_page_with_models_enabled(app_server, app_test, self.ST_FILE)
 
 
 #############################################################################
diff --git a/test/integration/client/content/test_testbed.py b/test/integration/client/content/test_testbed.py
new file mode 100644
index 00000000..eceb11ca
--- /dev/null
+++ b/test/integration/client/content/test_testbed.py
@@ -0,0 +1,32 @@
+# pylint: disable=protected-access,import-error,import-outside-toplevel
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+# spell-checker: disable
+
+from test.integration.client.conftest import run_page_with_models_enabled
+
+
+#############################################################################
+# Test Streamlit UI
+#############################################################################
+class TestStreamlit:
+    """Test the Streamlit UI"""
+
+    ST_FILE = "../src/client/content/testbed.py"
+
+    def test_disabled(self, app_server, app_test):
+        """Test everything is disabled when nothing is configured"""
+        assert app_server is not None
+        at = app_test(self.ST_FILE).run()
+        # When nothing is configured, one of these messages appears (depending on check order)
+        valid_messages = [
+            "No OpenAI compatible language models are configured and/or enabled. Disabling Testing Framework.",
+            "Database is not configured. 
Disabling Testbed.", + ] + assert at.error[0].value in valid_messages and at.error[0].icon == "🛑" + + def test_page_loads(self, app_server, app_test): + """Confirm page loads with model enabled""" + run_page_with_models_enabled(app_server, app_test, self.ST_FILE) diff --git a/test/integration/client/content/tools/__init__.py b/test/integration/client/content/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/integration/client/content/tools/tabs/__init__.py b/test/integration/client/content/tools/tabs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/client/integration/content/tools/tabs/test_prompt_eng.py b/test/integration/client/content/tools/tabs/test_prompt_eng.py similarity index 100% rename from tests/client/integration/content/tools/tabs/test_prompt_eng.py rename to test/integration/client/content/tools/tabs/test_prompt_eng.py diff --git a/tests/client/integration/content/tools/tabs/test_split_embed.py b/test/integration/client/content/tools/tabs/test_split_embed.py similarity index 99% rename from tests/client/integration/content/tools/tabs/test_split_embed.py rename to test/integration/client/content/tools/tabs/test_split_embed.py index d20998ba..f438ea9c 100644 --- a/tests/client/integration/content/tools/tabs/test_split_embed.py +++ b/test/integration/client/content/tools/tabs/test_split_embed.py @@ -6,8 +6,8 @@ # spell-checker: disable from unittest.mock import patch +from test.integration.client.conftest import enable_test_embed_models import pandas as pd -from conftest import enable_test_embed_models ############################################################################# @@ -15,6 +15,7 @@ ############################################################################# class MockState: """Mock session state for testing OCI-related functionality""" + def __init__(self): self.client_settings = {"oci": {"auth_profile": "DEFAULT"}} @@ -438,6 +439,7 @@ def test_update_functions(self, app_server, app_test, monkeypatch): class MockDynamicState: """Mock state with dynamically set attributes""" + def __init__(self): for key, value in mock_state.items(): setattr(self, key, value) @@ -460,7 +462,7 @@ def __getattr__(self, name): update_chunk_size_slider() assert state_mock.selected_chunk_size_slider == 800 - object.__setattr__(state_mock, 'selected_chunk_size_slider', 1200) + object.__setattr__(state_mock, "selected_chunk_size_slider", 1200) update_chunk_size_input() assert state_mock.selected_chunk_size_input == 1200 @@ -468,7 +470,7 @@ def __getattr__(self, name): update_chunk_overlap_slider() assert state_mock.selected_chunk_overlap_slider == 15 - object.__setattr__(state_mock, 'selected_chunk_overlap_slider', 25) + object.__setattr__(state_mock, "selected_chunk_overlap_slider", 25) update_chunk_overlap_input() assert state_mock.selected_chunk_overlap_input == 25 diff --git a/tests/client/integration/content/tools/test_tools.py b/test/integration/client/content/tools/test_tools.py similarity index 98% rename from tests/client/integration/content/tools/test_tools.py rename to test/integration/client/content/tools/test_tools.py index ba87c6de..76cc73c2 100644 --- a/tests/client/integration/content/tools/test_tools.py +++ b/test/integration/client/content/tools/test_tools.py @@ -5,7 +5,7 @@ """ # spell-checker: disable -from conftest import create_tabs_mock, run_streamlit_test +from test.integration.client.conftest import create_tabs_mock, run_streamlit_test 
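+# NOTE: shared Streamlit test helpers are now imported absolutely from the
+# relocated "test" package; the old bare "from conftest import ..." form relied
+# on pytest placing the suite directory on sys.path.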
############################################################################# diff --git a/test/integration/client/utils/__init__.py b/test/integration/client/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/client/integration/utils/test_st_footer.py b/test/integration/client/utils/test_st_footer.py similarity index 100% rename from tests/client/integration/utils/test_st_footer.py rename to test/integration/client/utils/test_st_footer.py diff --git a/tests/client/integration/utils/test_vs_options.py b/test/integration/client/utils/test_vs_options.py similarity index 100% rename from tests/client/integration/utils/test_vs_options.py rename to test/integration/client/utils/test_vs_options.py diff --git a/test/shared_fixtures.py b/test/shared_fixtures.py index 8c2a152d..cc4e634b 100644 --- a/test/shared_fixtures.py +++ b/test/shared_fixtures.py @@ -353,3 +353,55 @@ def clean_env(): os.environ[var] = value elif var in os.environ: del os.environ[var] + + +################################################# +# Vector Store Test Data +################################################# + +# Shared vector store test data used across client tests +SAMPLE_VECTOR_STORE_DATA = { + "alias": "test_alias", + "model": "openai/text-embed-3", + "chunk_size": 1000, + "chunk_overlap": 200, + "distance_metric": "cosine", + "index_type": "IVF", + "vector_store": "vs_test", +} + +SAMPLE_VECTOR_STORE_DATA_ALT = { + "alias": "alias2", + "model": "openai/text-embed-3", + "chunk_size": 500, + "chunk_overlap": 100, + "distance_metric": "euclidean", + "index_type": "HNSW", + "vector_store": "vs2", +} + + +@pytest.fixture +def sample_vector_store_data(): + """Sample vector store data for testing - standard configuration.""" + return SAMPLE_VECTOR_STORE_DATA.copy() + + +@pytest.fixture +def sample_vector_store_data_alt(): + """Alternative sample vector store data for testing - different configuration.""" + return SAMPLE_VECTOR_STORE_DATA_ALT.copy() + + +@pytest.fixture +def sample_vector_stores_list(): + """List of sample vector stores with different aliases for filtering tests.""" + vs1 = SAMPLE_VECTOR_STORE_DATA.copy() + vs1["alias"] = "vs1" + vs1.pop("vector_store", None) + + vs2 = SAMPLE_VECTOR_STORE_DATA_ALT.copy() + vs2["alias"] = "vs2" + vs2.pop("vector_store", None) + + return [vs1, vs2] diff --git a/test/unit/client/__init__.py b/test/unit/client/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/unit/client/conftest.py b/test/unit/client/conftest.py new file mode 100644 index 00000000..e4ffecb9 --- /dev/null +++ b/test/unit/client/conftest.py @@ -0,0 +1,44 @@ +# pylint: disable=import-error,redefined-outer-name,unused-import +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit test fixtures for client tests. Unit tests mock dependencies rather than +requiring a real server, but some fixtures help establish Streamlit session state. 
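+
+A minimal usage sketch (illustrative only - the test name is hypothetical, and
+the values mirror the shared vector store data re-exported below from
+test.shared_fixtures):
+
+    def test_vector_store_defaults(sample_vector_store_data):
+        assert sample_vector_store_data["alias"] == "test_alias"
+        assert sample_vector_store_data["chunk_size"] == 1000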
+""" +# spell-checker: disable + +import os +import sys + +# Re-export shared vector store fixtures for pytest discovery +from test.shared_fixtures import ( # noqa: F401 + sample_vector_store_data, + sample_vector_store_data_alt, + sample_vector_stores_list, +) + +import pytest +from streamlit import session_state as state + +# Add src to path for client imports +SRC_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "..", "src") +if SRC_PATH not in sys.path: + sys.path.insert(0, SRC_PATH) + + +@pytest.fixture(scope="session") +def app_server(): + """ + Minimal fixture for unit tests that just need session state initialized. + + Unlike integration tests, this doesn't actually start a server. + It just ensures Streamlit session state is available for testing. + """ + # Initialize basic state required by client modules + if not hasattr(state, "server"): + state.server = {"url": "http://localhost", "port": 8000, "key": "test-key"} + if not hasattr(state, "client_settings"): + state.client_settings = {"client": "test-client", "ll_model": {}} + + yield True # Just return True to indicate fixture is available diff --git a/test/unit/client/content/__init__.py b/test/unit/client/content/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/unit/client/content/config/__init__.py b/test/unit/client/content/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/unit/client/content/config/tabs/__init__.py b/test/unit/client/content/config/tabs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/client/unit/content/config/tabs/test_mcp_unit.py b/test/unit/client/content/config/tabs/test_mcp_unit.py similarity index 100% rename from tests/client/unit/content/config/tabs/test_mcp_unit.py rename to test/unit/client/content/config/tabs/test_mcp_unit.py diff --git a/tests/client/unit/content/config/tabs/test_models_unit.py b/test/unit/client/content/config/tabs/test_models_unit.py similarity index 100% rename from tests/client/unit/content/config/tabs/test_models_unit.py rename to test/unit/client/content/config/tabs/test_models_unit.py diff --git a/tests/client/unit/content/test_chatbot_unit.py b/test/unit/client/content/test_chatbot_unit.py similarity index 96% rename from tests/client/unit/content/test_chatbot_unit.py rename to test/unit/client/content/test_chatbot_unit.py index 9d3afd0c..a2683a23 100644 --- a/tests/client/unit/content/test_chatbot_unit.py +++ b/test/unit/client/content/test_chatbot_unit.py @@ -39,9 +39,9 @@ def test_show_vector_search_refs_with_metadata(self, monkeypatch): monkeypatch.setattr(st, "columns", mock_columns) monkeypatch.setattr(st, "subheader", mock_subheader) - # Create test context - context = [ - [ + # Create test context - now expects dict with "documents" key + context = { + "documents": [ { "page_content": "This is chunk 1 content", "metadata": {"filename": "doc1.pdf", "source": "/path/to/doc1.pdf", "page": 1}, @@ -55,8 +55,8 @@ def test_show_vector_search_refs_with_metadata(self, monkeypatch): "metadata": {"filename": "doc1.pdf", "source": "/path/to/doc1.pdf", "page": 3}, }, ], - "test query", - ] + "context_input": "test query", + } # Call function chatbot.show_vector_search_refs(context) @@ -64,9 +64,6 @@ def test_show_vector_search_refs_with_metadata(self, monkeypatch): # Verify References header was shown assert any("References" in str(call) for call in mock_markdown.call_args_list) - # Verify Notes with query shown - assert any("test query" in str(call) for call in 
mock_markdown.call_args_list) - def test_show_vector_search_refs_missing_metadata(self, monkeypatch): """Test showing vector search references when metadata is missing""" from client.content import chatbot @@ -88,16 +85,16 @@ def test_show_vector_search_refs_missing_metadata(self, monkeypatch): monkeypatch.setattr(st, "columns", mock_columns) monkeypatch.setattr(st, "subheader", mock_subheader) - # Create test context with missing metadata - context = [ - [ + # Create test context with missing metadata - now expects dict with "documents" key + context = { + "documents": [ { "page_content": "Content without metadata", "metadata": {}, # Empty metadata - will cause KeyError } ], - "test query", - ] + "context_input": "test query", + } # Call function - should handle KeyError gracefully chatbot.show_vector_search_refs(context) @@ -308,10 +305,10 @@ def test_display_chat_history_with_vector_search(self, monkeypatch): mock_show_refs = MagicMock() monkeypatch.setattr(chatbot, "show_vector_search_refs", mock_show_refs) - # Create history with tool message - vector_refs = [[{"page_content": "content", "metadata": {}}], "query"] + # Create history with tool message - use correct tool name "optimizer_vs-retriever" + vector_refs = {"documents": [{"page_content": "content", "metadata": {}}], "context_input": "query"} history = [ - {"role": "tool", "name": "oraclevs_tool", "content": json.dumps(vector_refs)}, + {"role": "tool", "name": "optimizer_vs-retriever", "content": json.dumps(vector_refs)}, {"role": "ai", "content": "Based on the documents..."}, ] diff --git a/tests/client/unit/content/test_testbed_unit.py b/test/unit/client/content/test_testbed_unit.py similarity index 100% rename from tests/client/unit/content/test_testbed_unit.py rename to test/unit/client/content/test_testbed_unit.py diff --git a/test/unit/client/content/tools/__init__.py b/test/unit/client/content/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/unit/client/content/tools/tabs/__init__.py b/test/unit/client/content/tools/tabs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/client/unit/content/tools/tabs/test_split_embed_unit.py b/test/unit/client/content/tools/tabs/test_split_embed_unit.py similarity index 100% rename from tests/client/unit/content/tools/tabs/test_split_embed_unit.py rename to test/unit/client/content/tools/tabs/test_split_embed_unit.py diff --git a/test/unit/client/utils/__init__.py b/test/unit/client/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/client/unit/utils/test_client_unit.py b/test/unit/client/utils/test_client_unit.py similarity index 100% rename from tests/client/unit/utils/test_client_unit.py rename to test/unit/client/utils/test_client_unit.py diff --git a/tests/client/unit/utils/test_st_common_unit.py b/test/unit/client/utils/test_st_common_unit.py similarity index 100% rename from tests/client/unit/utils/test_st_common_unit.py rename to test/unit/client/utils/test_st_common_unit.py diff --git a/tests/client/unit/utils/test_vs_options_unit.py b/test/unit/client/utils/test_vs_options_unit.py similarity index 100% rename from tests/client/unit/utils/test_vs_options_unit.py rename to test/unit/client/utils/test_vs_options_unit.py diff --git a/test/unit/server/api/conftest.py b/test/unit/server/api/conftest.py index d10ed367..858dca33 100644 --- a/test/unit/server/api/conftest.py +++ b/test/unit/server/api/conftest.py @@ -24,6 +24,11 @@ TEST_DB_DSN, ) +# Import TEST_DB_CONFIG for use in this file +# 
Note: db_container, db_connection, db_transaction fixtures are inherited +# from test/conftest.py - do not re-import here to avoid multiple containers +from test.db_fixtures import TEST_DB_CONFIG + import pytest from common.schema import ( diff --git a/tests/README.md b/tests/README.md deleted file mode 100644 index 40dabc8f..00000000 --- a/tests/README.md +++ /dev/null @@ -1,77 +0,0 @@ -# AI Optimizer for Apps Tests - - -This directory contains Tests for the AI Optimizer for Apps. Tests are automatically -run as part of opening a new Pull Requests. All tests must pass to enable merging. - -## Installing Test Dependencies - -1. Create and activate a Python Virtual Environment: - - ```bash - python3.11 -m venv .venv --copies - source .venv/bin/activate - pip3.11 install --upgrade pip wheel setuptools uv - ``` - -1. Install the Python modules: - - ```bash - uv pip install -e ".[all-test]" - ``` - -## Running Tests - -All tests can be run by using the following command from the **project root**: - -```bash -pytest tests -v [--log-cli-level=DEBUG] -``` - -### Server Endpoint Tests - -To run the server endpoint tests, use the following command from the **project root**: - -```bash -pytest tests/server -v [--log-cli-level=DEBUG] -``` - -These tests verify the functionality of the endpoints by establishing: -- A real FastAPI server -- A Docker container used for database tests -- Mocks for external dependencies (OCI) - -### Streamlit Tests - -To run the Streamlit page tests, use the following command from the **project root**: - -```bash -pytest tests/client -v [--log-cli-level=DEBUG] -``` - -These tests verify the functionality of the Streamlit app by establishing: -- A real AI Optimizer API server -- A Docker container used for database tests - -## Test Structure - -### Server Endpoint Tests - -The server endpoint tests are organized into two classes: -- `TestNoAuthEndpoints`: Tests that verify authentication is required -- `TestEndpoints`: Tests that verify the functionality of the endpoints - -### Streamlit Settings Page Tests - -The Streamlit settings page tests are organized into two classes: -- `TestFunctions`: Tests for the utility functions -- `TestUI`: Tests for the Streamlit UI components - -## Test Environment - -The tests use a combination of real and mocked components: -- A real FastAPI server is started for the endpoint tests -- A Docker container is used for database tests -- Streamlit components are tested using the AppTest framework -- External dependencies are mocked where appropriate -- To see the elements in the page for testing; use: `print([el for el in at.main])` diff --git a/tests/client/integration/content/test_testbed.py b/tests/client/integration/content/test_testbed.py deleted file mode 100644 index 0c3e4d44..00000000 --- a/tests/client/integration/content/test_testbed.py +++ /dev/null @@ -1,665 +0,0 @@ -# pylint: disable=protected-access,import-error,import-outside-toplevel -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
-""" -# spell-checker: disable - -import os -from unittest.mock import patch -from conftest import setup_test_database, enable_test_models, temporary_sys_path - - -############################################################################# -# Test Streamlit UI -############################################################################# -class TestStreamlit: - """Test the Streamlit UI""" - - # Streamlit File path - ST_FILE = "../src/client/content/testbed.py" - - def test_initialization(self, app_server, app_test, db_container): - """Test initialization of the testbed component with real server data and database""" - assert app_server is not None - assert db_container is not None - - # Initialize app_test - now loads full config from server - at = app_test(self.ST_FILE) - - # Set up prerequisites using helper functions - at = setup_test_database(at) - at = enable_test_models(at) - - # Now run the app - at.run() - - # Verify specific widgets that should exist - # The testbed page should render these widgets when initialized - radio_widgets = at.get("radio") - assert len(radio_widgets) >= 1, ( - f"Expected at least 1 radio widget for testset source selection. Errors: {[e.value for e in at.error]}" - ) - - button_widgets = at.get("button") - assert len(button_widgets) >= 1, "Expected at least 1 button widget" - - file_uploader_widgets = at.get("file_uploader") - assert len(file_uploader_widgets) >= 1, "Expected at least 1 file uploader widget" - - # Test passes if the expected widgets are rendered - - def test_testset_source_selection(self, app_server, app_test, db_container): - """Test selection of test sets from different sources with real server data""" - assert app_server is not None - assert db_container is not None - - # Initialize app_test - now loads full config from server - at = app_test(self.ST_FILE) - - # Set up prerequisites using helper functions - at = setup_test_database(at) - at = enable_test_models(at) - - # Run the app to initialize all widgets - at.run() - - # Verify the expected widgets are present - radio_widgets = at.get("radio") - assert len(radio_widgets) > 0, f"Expected radio widgets. Errors: {[e.value for e in at.error]}" - - file_uploader_widgets = at.get("file_uploader") - assert len(file_uploader_widgets) > 0, "Expected file uploader widgets" - - # Test passes if the expected widgets are rendered - - def test_testset_generation_with_saved_ll_model(self, app_server, app_test, db_container): - """Test that testset generation UI correctly restores saved language model preferences - - This test verifies that when a user has a saved language model preference, - the UI correctly looks up the model's index from the language models list - (not the embedding models list). 
- """ - assert app_server is not None - assert db_container is not None - - # Initialize app_test - at = app_test(self.ST_FILE) - - # Set up prerequisites using helper functions - at = setup_test_database(at) - - # Create realistic model configurations with distinct LLM and embedding models - at.session_state.model_configs = [ - { - "id": "gpt-4o-mini", - "type": "ll", - "enabled": True, - "provider": "openai", - "openai_compat": True, - }, - { - "id": "gpt-4o", - "type": "ll", - "enabled": True, - "provider": "openai", - "openai_compat": True, - }, - { - "id": "text-embedding-3-small", - "type": "embed", - "enabled": True, - "provider": "openai", - "openai_compat": True, - }, - { - "id": "embed-english-v3.0", - "type": "embed", - "enabled": True, - "provider": "cohere", - "openai_compat": True, - }, - ] - - # Initialize client_settings with a saved LLM preference - # This simulates a user who previously selected a language model - if "client_settings" not in at.session_state: - at.session_state.client_settings = {} - if "testbed" not in at.session_state.client_settings: - at.session_state.client_settings["testbed"] = {} - - # Set a language model preference that exists in LL list but NOT in embed list - at.session_state.client_settings["testbed"]["qa_ll_model"] = "openai/gpt-4o-mini" - - # Run the app - should render without error - at.run() - - # Toggle to "Generate Q&A Test Set" mode - generate_toggle = at.get("toggle") - assert len(generate_toggle) > 0, "Expected toggle widget for 'Generate Q&A Test Set'" - - # This should not raise ValueError about model not being in list - generate_toggle[0].set_value(True).run() - - # Verify no exceptions occurred during rendering - assert not at.exception, f"Rendering failed with exception: {at.exception}" - - # Verify the selectboxes rendered correctly - selectboxes = at.get("selectbox") - assert len(selectboxes) >= 2, "Should have at least 2 selectboxes (LLM and embed model)" - - # Verify no errors were thrown - errors = at.get("error") - assert len(errors) == 0, f"Expected no errors, but got: {[e.value for e in errors]}" - - def test_testset_generation_default_ll_model(self, app_server, app_test, db_container): - """Test that testset generation UI sets correct default language model - - This test verifies that when no saved language model preference exists, - the UI correctly initializes the default from the language models list - (not the embedding models list). 
- """ - assert app_server is not None - assert db_container is not None - - # Initialize app_test - at = app_test(self.ST_FILE) - - # Set up prerequisites using helper functions - at = setup_test_database(at) - - # Create realistic model configurations with distinct LLM and embedding models - at.session_state.model_configs = [ - { - "id": "gpt-4o-mini", - "type": "ll", - "enabled": True, - "provider": "openai", - "openai_compat": True, - }, - { - "id": "gpt-4o", - "type": "ll", - "enabled": True, - "provider": "openai", - "openai_compat": True, - }, - { - "id": "text-embedding-3-small", - "type": "embed", - "enabled": True, - "provider": "openai", - "openai_compat": True, - }, - { - "id": "embed-english-v3.0", - "type": "embed", - "enabled": True, - "provider": "cohere", - "openai_compat": True, - }, - ] - - # Initialize client_settings but DON'T set saved preferences - # This triggers the default initialization code path - if "client_settings" not in at.session_state: - at.session_state.client_settings = {} - if "testbed" not in at.session_state.client_settings: - at.session_state.client_settings["testbed"] = {} - - # Run the app - should render without error - at.run() - - # Toggle to "Generate Q&A Test Set" mode - generate_toggle = at.get("toggle") - assert len(generate_toggle) > 0, "Expected toggle widget for 'Generate Q&A Test Set'" - - # This should not crash - defaults should be set correctly - generate_toggle[0].set_value(True).run() - - # Verify no exceptions occurred during rendering - assert not at.exception, f"Rendering failed with exception: {at.exception}" - - # Verify the selectboxes rendered correctly - selectboxes = at.get("selectbox") - assert len(selectboxes) >= 2, "Should have at least 2 selectboxes (LLM and embed model)" - - # Verify the default qa_ll_model is actually a language model, not an embedding model - qa_ll_model = at.session_state.client_settings["testbed"]["qa_ll_model"] - assert qa_ll_model in ["openai/gpt-4o-mini", "openai/gpt-4o"], ( - f"Default qa_ll_model should be a language model, got: {qa_ll_model}" - ) - - # Verify no errors were thrown - errors = at.get("error") - assert len(errors) == 0, f"Expected no errors, but got: {[e.value for e in errors]}" - - @patch("client.utils.api_call.post") - def test_evaluate_testset(self, mock_post, app_test, monkeypatch): - """Test evaluation of a test set""" - - # Mock the API responses for get_models - def mock_get(endpoint=None, **_kwargs): - if endpoint == "v1/models": - return [ - { - "id": "test-ll-model", - "type": "ll", - "enabled": True, - "url": "http://test.url", - "openai_compat": True, - }, - { - "id": "test-embed-model", - "type": "embed", - "enabled": True, - "url": "http://test.url", - "openai_compat": True, - }, - ] - return {} - - monkeypatch.setattr("client.utils.api_call.get", mock_get) - - # Mock API post response for evaluation - mock_post.return_value = { - "id": "eval123", - "score": 0.85, - "results": [{"question": "Test question 1", "score": 0.9}, {"question": "Test question 2", "score": 0.8}], - } - - # Mock functions that make external calls - monkeypatch.setattr("common.functions.is_url_accessible", lambda url: (True, "")) - monkeypatch.setattr("streamlit.cache_resource", lambda *args, **kwargs: lambda func: func) - monkeypatch.setattr("client.utils.st_common.is_db_configured", lambda: True) - - # Initialize app_test - at = app_test(self.ST_FILE) - - # Set up session state requirements - at.session_state.user_settings = { - "client": "test_client", - "oci": {"auth_profile": "DEFAULT"}, 
- "vector_search": {"database": "DEFAULT"}, - } - - at.session_state.ll_model_enabled = { - "test-ll-model": {"url": "http://test.url", "openai_compat": True, "enabled": True} - } - - at.session_state.embed_model_enabled = { - "test-embed-model": {"url": "http://test.url", "openai_compat": True, "enabled": True} - } - - # Run the app to initialize all widgets - at = at.run() - - # For this minimal test, just verify the app runs without error - # This test is valuable to ensure mocking works properly - assert True - - # Test passes if the app runs without errors - - @patch("client.content.testbed.st_common") - @patch("client.content.testbed.get_testbed_db_testsets") - def test_reset_testset_function(self, mock_get_testbed, mock_st_common): - """Test the reset_testset function""" - # Import the module to test the function directly - with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../src")): - from client.content import testbed - - # Test reset_testset without cache - testbed.reset_testset(cache=False) - - # Verify clear_state_key was called for all expected keys - expected_calls = [ - "testbed", - "selected_testset_name", - "testbed_qa", - "testbed_db_testsets", - "testbed_evaluations", - ] - - for key in expected_calls: - mock_st_common.clear_state_key.assert_any_call(key) - - # Test reset_testset with cache - mock_st_common.reset_mock() - testbed.reset_testset(cache=True) - - # Should still call clear_state_key for all keys - for key in expected_calls: - mock_st_common.clear_state_key.assert_any_call(key) - - # Should also call clear on get_testbed_db_testsets - mock_get_testbed.clear.assert_called_once() - - def test_download_file_fragment(self): - """Test the download_file fragment function""" - # Since the download_file function is a streamlit fragment, - # we can only test that it exists and is callable - with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../src")): - from client.content import testbed - - # Verify function exists and is callable - assert hasattr(testbed, "download_file") - assert callable(testbed.download_file) - - # Note: The actual streamlit fragment functionality - # is tested through the integration tests - - def test_update_record_function_logic(self): - """Test the update_record function logic""" - # Test that the function exists and is callable - with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../src")): - from client.content import testbed - - assert hasattr(testbed, "update_record") - assert callable(testbed.update_record) - - # Note: The actual functionality is tested in integration tests - # since it depends heavily on Streamlit's session state - - def test_delete_record_function_exists(self): - """Test the delete_record function exists""" - # Test that the function exists and is callable - with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../src")): - from client.content import testbed - - assert hasattr(testbed, "delete_record") - assert callable(testbed.delete_record) - - # Note: The actual functionality is tested in integration tests - # since it depends heavily on Streamlit's session state - - @patch("client.utils.api_call.get") - def test_get_testbed_db_testsets(self, mock_get, app_test): - """Test the get_testbed_db_testsets cached function""" - # Ensure app_test fixture is available for proper test context - assert app_test is not None - - with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../src")): - from client.content import testbed - - # 
Mock API response - expected_response = { - "testsets": [ - {"tid": "test1", "name": "Test Set 1", "created": "2024-01-01"}, - {"tid": "test2", "name": "Test Set 2", "created": "2024-01-02"}, - ] - } - mock_get.return_value = expected_response - - # Test function call - result = testbed.get_testbed_db_testsets() - - # Verify API was called correctly - mock_get.assert_called_once_with(endpoint="v1/testbed/testsets") - assert result == expected_response - - def test_qa_delete_function_exists(self): - """Test the qa_delete function exists and is callable""" - with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../src")): - from client.content import testbed - - assert hasattr(testbed, "qa_delete") - assert callable(testbed.qa_delete) - - # Note: Full functionality testing requires Streamlit session state - # and is covered by integration tests - - def test_qa_update_db_function_exists(self): - """Test the qa_update_db function exists and is callable""" - with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../src")): - from client.content import testbed - - assert hasattr(testbed, "qa_update_db") - assert callable(testbed.qa_update_db) - - # Note: Full functionality testing requires Streamlit session state - # and is covered by integration tests - - def test_qa_update_gui_function_exists(self): - """Test the qa_update_gui function exists and is callable""" - with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../src")): - from client.content import testbed - - assert hasattr(testbed, "qa_update_gui") - assert callable(testbed.qa_update_gui) - - # Note: Full UI functionality testing is covered by integration tests - - def test_evaluation_report_function_exists(self): - """Test the evaluation_report function exists and is callable""" - with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../src")): - from client.content import testbed - - assert hasattr(testbed, "evaluation_report") - assert callable(testbed.evaluation_report) - - # Note: Full functionality testing with Streamlit dialogs - # is covered by integration tests - - def test_evaluation_report_with_eid_parameter(self): - """Test evaluation_report function accepts eid parameter""" - import inspect - - with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../src")): - from client.content import testbed - - # Get function signature and verify eid parameter exists - sig = inspect.signature(testbed.evaluation_report) - assert "eid" in sig.parameters - assert "report" in sig.parameters - - # Verify function is callable - assert callable(testbed.evaluation_report) - - # Note: Full API integration testing is covered by integration tests - - def test_generate_qa_button_regression(self, app_server, app_test, db_container): - """Test that Generate Q&A button logic correctly handles testset_id check""" - assert app_server is not None - assert db_container is not None - - # Initialize app_test - at = app_test(self.ST_FILE) - - # Set up prerequisites using helper functions - at = setup_test_database(at) - - # Create model configurations - at.session_state.model_configs = [ - { - "id": "gpt-4o-mini", - "type": "ll", - "enabled": True, - "provider": "openai", - "openai_compat": True, - }, - { - "id": "text-embedding-3-small", - "type": "embed", - "enabled": True, - "provider": "openai", - "openai_compat": True, - }, - ] - - # Initialize client_settings - if "client_settings" not in at.session_state: - at.session_state.client_settings = {} - if "testbed" not in 
at.session_state.client_settings: - at.session_state.client_settings["testbed"] = {} - - # Run the app in default mode (loading existing test sets) - at.run() - - # In this mode, button should be disabled if testset_id is None - # (which it is initially) - load_button_default = at.button(key="load_tests") - assert load_button_default is not None, "Expected button with key 'load_tests' in default mode" - # Button should be disabled because we're in load mode with no testset_id - assert load_button_default.disabled, "Load Q&A button should be disabled without testset_id in load mode" - - # Now toggle to "Generate Q&A Test Set" mode - generate_toggle = at.toggle(key="selected_generate_test") - assert generate_toggle is not None, "Expected toggle with key 'selected_generate_test'" - generate_toggle.set_value(True).run() - - # In generate mode, testset_id should NOT affect button state - # The button should only be disabled if no file is uploaded - load_button_generate = at.button(key="load_tests") - assert load_button_generate is not None, "Expected button with key 'load_tests' in generate mode" - - # The button should be disabled because no file is uploaded yet, - # NOT because testset_id is None (which was the regression) - assert load_button_generate.disabled, "Generate Q&A button should be disabled without a file" - - # Verify we have a file uploader in generate mode - file_uploaders = at.get("file_uploader") - assert len(file_uploaders) > 0, "Expected at least one file uploader in generate mode" - - # The test passes if: - # 1. In load mode, button is disabled when testset_id is None - # 2. In generate mode, button state depends on file upload, not testset_id - # This confirms the regression fix is working correctly - - -############################################################################# -# Integration Tests with Real Database -############################################################################# -class TestTestbedDatabaseIntegration: - """Integration tests using real database container""" - - # Streamlit File path - ST_FILE = "../src/client/content/testbed.py" - - def test_testbed_with_real_database_simplified(self, app_server, db_container): - """Test basic testbed functionality with real database container (simplified)""" - assert app_server is not None - assert db_container is not None - - # Verify the database container exists and is not stopped - assert db_container.status in ["running", "created"] - - # This test verifies that: - # 1. The app server is running - # 2. The database container is available - # 3. 
The testbed module can be imported and has expected functions - with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../src")): - from client.content import testbed - - # Verify key testbed functions exist - testbed_functions = [ - "main", - "reset_testset", - "get_testbed_db_testsets", - "qa_update_gui", - "evaluation_report", - ] - - for func_name in testbed_functions: - assert hasattr(testbed, func_name), f"Function {func_name} not found" - assert callable(getattr(testbed, func_name)), f"Function {func_name} is not callable" - - def test_testset_functions_callable(self, app_server, db_container): - """Test testset functions are callable (simplified)""" - assert app_server is not None - assert db_container is not None - - # Test that testbed functions can be imported and are callable - with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../src")): - from client.content import testbed - - # Test functions that interact with the API/database - api_functions = ["get_testbed_db_testsets", "qa_delete", "qa_update_db"] - - for func_name in api_functions: - assert hasattr(testbed, func_name), f"Function {func_name} not found" - assert callable(getattr(testbed, func_name)), f"Function {func_name} is not callable" - - def test_database_integration_basic(self, app_server, db_container): - """Test basic database integration functionality""" - assert app_server is not None - assert db_container is not None - - # Verify the database container exists and is not stopped - assert db_container.status in ["running", "created"] - - # This is a simplified integration test that verifies: - # 1. The app server is running - # 2. The database container is running - # 3. The testbed module can be imported - with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../src")): - from client.content import testbed - - # Verify all main functions are present and callable - main_functions = [ - "reset_testset", - "download_file", - "evaluation_report", - "get_testbed_db_testsets", - "qa_delete", - "qa_update_db", - "update_record", - "delete_record", - "qa_update_gui", - "main", - ] - - for func_name in main_functions: - assert hasattr(testbed, func_name), f"Function {func_name} not found" - assert callable(getattr(testbed, func_name)), f"Function {func_name} is not callable" - - def test_load_button_enabled_with_database_testset(self, app_server, app_test, db_container): - """Test that Load Q&A button is enabled when a database test set is selected""" - assert app_server is not None - assert db_container is not None - - # Initialize app_test - at = app_test(self.ST_FILE) - - # Set up prerequisites using helper functions - at = setup_test_database(at) - at = enable_test_models(at) - - # Mock database test sets to ensure we have some available - mock_testsets = [ - {"tid": "test1", "name": "Test Set 1", "created": "2024-01-01 10:00:00"}, - {"tid": "test2", "name": "Test Set 2", "created": "2024-01-02 11:00:00"}, - ] - at.session_state.testbed_db_testsets = mock_testsets - - # Run the app with "Generate Q&A Test Set" toggled OFF (default) - at.run() - - # Verify the toggle is in the correct state - generate_toggle = at.toggle(key="selected_generate_test") - assert generate_toggle is not None, "Expected toggle widget for 'Generate Q&A Test Set'" - assert generate_toggle.value is False, "Toggle should be OFF by default (existing test set mode)" - - # Verify we have a radio button for TestSet Source - radio_widgets = at.radio(key="radio_test_source") - assert radio_widgets 
is not None, "Expected radio widget for testset source selection" - - # Verify we have a selectbox for database test sets - selectbox = at.selectbox(key="selected_db_testset") - assert selectbox is not None, "Expected selectbox for database test set selection" - - # The selectbox should have our mock test sets as options - expected_options = ["Test Set 1 -- Created: 2024-01-01 10:00:00", "Test Set 2 -- Created: 2024-01-02 11:00:00"] - assert selectbox.options == expected_options, f"Expected options {expected_options}, got {selectbox.options}" - - # Select a test set - selectbox.set_value(expected_options[0]).run() - - # Get the Load Q&A button - load_button = at.button(key="load_tests") - assert load_button is not None, "Expected button with key 'load_tests'" - - # CRITICAL TEST: Button should be ENABLED when a database test set is selected - assert not load_button.disabled, ( - "Load Q&A button should be ENABLED when a database test set is selected. " - "This indicates the bug fix is not working correctly." - ) diff --git a/tests/client/integration/utils/test_st_common.py b/tests/client/integration/utils/test_st_common.py deleted file mode 100644 index 0103607a..00000000 --- a/tests/client/integration/utils/test_st_common.py +++ /dev/null @@ -1,9 +0,0 @@ -# pylint: disable=protected-access,import-error,import-outside-toplevel,redefined-outer-name -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Integration tests for st_common utilities. -Vector store selection tests have been moved to test_vs_options.py -""" -# spell-checker: disable diff --git a/tests/common/test_functions_sql.py b/tests/common/test_functions_sql.py deleted file mode 100644 index 44ad9d3b..00000000 --- a/tests/common/test_functions_sql.py +++ /dev/null @@ -1,139 +0,0 @@ -# pylint: disable=protected-access,import-error,import-outside-toplevel,redefined-outer-name -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Unit tests for client-side SQL validation integration - -Note: Tests for common.functions.is_sql_accessible have been migrated to -test/unit/common/test_functions.py. This file contains only client-side tests -for FileSourceData and UI error display logic. -""" -# spell-checker: disable - -from unittest.mock import patch -import pytest - -from common import functions - - -class TestFileSourceDataSQLValidation: - """ - Tests for FileSourceData.is_valid() method with SQL source - - These tests verify that the is_valid() method correctly uses the return value - from is_sql_accessible() function. The fix ensures that when is_sql_accessible - returns (True, ""), is_valid() should return True, and vice versa. - """ - - def test_is_valid_returns_true_when_sql_accessible_succeeds(self): - """Test that is_valid() returns True when SQL validation succeeds""" - from client.content.tools.tabs.split_embed import FileSourceData - - # Mock is_sql_accessible to return success (True, "") - with patch.object(functions, "is_sql_accessible", return_value=(True, "")): - data = FileSourceData( - file_source="SQL", - sql_connection="user/pass@dsn", - sql_query="SELECT text FROM docs" - ) - - result = data.is_valid() - - # The fix ensures this assertion passes - assert result is True, ( - "FileSourceData.is_valid() should return True when is_sql_accessible returns (True, ''). 
" - "This test will fail until the bug fix is applied." - ) - - def test_is_valid_returns_false_when_sql_accessible_fails(self): - """Test that is_valid() returns False when SQL validation fails""" - from client.content.tools.tabs.split_embed import FileSourceData - - # Mock is_sql_accessible to return failure (False, "error message") - with patch.object(functions, "is_sql_accessible", return_value=(False, "Connection failed")): - data = FileSourceData( - file_source="SQL", - sql_connection="user/pass@dsn", - sql_query="INVALID SQL" - ) - - result = data.is_valid() - - assert result is False, ( - "FileSourceData.is_valid() should return False when is_sql_accessible returns (False, msg)" - ) - - def test_is_valid_with_various_error_conditions(self): - """Test is_valid() with various SQL error conditions""" - from client.content.tools.tabs.split_embed import FileSourceData - - test_cases = [ - ((False, "Empty table"), False, "Empty result set"), - ((False, "Wrong connection"), False, "Invalid connection string"), - ((False, "2 columns"), False, "Multiple columns"), - ((False, "VARCHAR expected"), False, "Wrong column type"), - ] - - for sql_result, expected_valid, description in test_cases: - with patch.object(functions, "is_sql_accessible", return_value=sql_result): - data = FileSourceData( - file_source="SQL", - sql_connection="user/pass@dsn", - sql_query="SELECT text FROM docs" - ) - - result = data.is_valid() - - assert result == expected_valid, f"Failed for case: {description}" - - -class TestRenderLoadKBSectionErrorDisplay: - """ - Tests for the error display logic in _render_load_kb_section - - The fix changes line 272 from: - if is_invalid or msg: - to: - if not(is_invalid) or msg: - - This ensures errors are displayed when SQL validation actually fails. - """ - - def test_error_displayed_when_sql_validation_fails(self): - """Test that error is displayed when is_sql_accessible returns (False, msg)""" - # When is_sql_accessible returns (False, "Error message") - # The unpacked values are: is_invalid=False, msg="Error message" - # The condition should display error: not(False) or "Error message" = True or True = True - - is_invalid, msg = False, "Connection failed" - - # Simulate the logic in line 272 after the fix - should_display_error = not(is_invalid) or bool(msg) - - assert should_display_error is True, ( - "Error should be displayed when SQL validation fails. " - "is_sql_accessible returned (False, 'Connection failed'), " - "which should trigger error display." - ) - - def test_no_error_displayed_when_sql_validation_succeeds(self): - """Test that no error is displayed when is_sql_accessible returns (True, '')""" - # When is_sql_accessible returns (True, "") - # The unpacked values are: is_invalid=True, msg="" - # The condition should NOT display error: not(True) or "" = False or False = False - - is_invalid, msg = True, "" - - # Simulate the logic in line 272 after the fix - should_display_error = not(is_invalid) or bool(msg) - - assert should_display_error is False, ( - "Error should NOT be displayed when SQL validation succeeds. " - "is_sql_accessible returned (True, ''), " - "which should NOT trigger error display." - ) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index e70ed3f6..00000000 --- a/tests/conftest.py +++ /dev/null @@ -1,509 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. 
-Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" -# spell-checker: disable -# pylint: disable=protected-access import-error import-outside-toplevel consider-using-with - -import os -import sys -import time -import socket -import shutil -import subprocess -from pathlib import Path -from typing import Generator, Optional -from contextlib import contextmanager - -import requests -import numpy as np -import pytest -import docker -from docker.errors import DockerException -from docker.models.containers import Container - -# This contains all the environment variables we consume on startup (add as required) -# Used to clear testing environment from users env; Do before any additional imports -API_VARS = ["API_SERVER_KEY", "API_SERVER_URL", "API_SERVER_PORT"] -DB_VARS = ["DB_USERNAME", "DB_PASSWORD", "DB_DSN", "DB_WALLET_PASSWORD", "TNS_ADMIN"] -MODEL_VARS = ["ON_PREM_OLLAMA_URL", "ON_PREM_HF_URL", "OPENAI_API_KEY", "PPLX_API_KEY", "COHERE_API_KEY"] -for env_var in [*API_VARS, *DB_VARS, *MODEL_VARS, *[var for var in os.environ if var.startswith("OCI_")]]: - os.environ.pop(env_var, None) - -# Setup a Test Configurations -TEST_CONFIG = { - "client": "server", - "auth_token": "testing-token", - "db_username": "PYTEST", - "db_password": "OrA_41_3xPl0d3r", - "db_dsn": "//localhost:1525/FREEPDB1", -} - -# Environments for Client/Server -os.environ["CONFIG_FILE"] = "/non/existant/path/config.json" # Prevent picking up an exported settings file -os.environ["OCI_CLI_CONFIG_FILE"] = "/non/existant/path" # Prevent picking up default OCI config file -os.environ["API_SERVER_KEY"] = TEST_CONFIG["auth_token"] -os.environ["API_SERVER_URL"] = "http://localhost" -os.environ["API_SERVER_PORT"] = "8015" - -# Import rest of required modules -from fastapi.testclient import TestClient # pylint: disable=wrong-import-position -from streamlit.testing.v1 import AppTest # pylint: disable=wrong-import-position - - -################################################# -# Fixures for tests/server -################################################# -@pytest.fixture(name="auth_headers") -def _auth_headers(): - """Return common header configurations for testing.""" - return { - "no_auth": {}, - "invalid_auth": {"Authorization": "Bearer invalid-token", "client": TEST_CONFIG["client"]}, - "valid_auth": {"Authorization": f"Bearer {TEST_CONFIG['auth_token']}", "client": TEST_CONFIG["client"]}, - } - - -@pytest.fixture(scope="session") -def client(): - """Create a test client for the FastAPI app.""" - # Lazy Load - import asyncio - from launch_server import create_app - - app = asyncio.run(create_app()) - return TestClient(app) - - -@pytest.fixture -def mock_embedding_model(): - """ - This fixture provides a mock embedding model for testing. - It returns a function that simulates embedding generation by returning random vectors. - """ - - def mock_embed_documents(texts: list[str]) -> list[list[float]]: - """Mock function that returns random embeddings for testing""" - return [np.random.rand(384).tolist() for _ in texts] # 384 is a common embedding dimension - - return mock_embed_documents - - -@pytest.fixture -def db_objects_manager(): - """ - Fixture to manage DATABASE_OBJECTS save/restore operations. - This reduces code duplication across tests that need to manipulate DATABASE_OBJECTS. 
- """ - from server.bootstrap.bootstrap import DATABASE_OBJECTS - - original_db_objects = DATABASE_OBJECTS.copy() - yield DATABASE_OBJECTS - DATABASE_OBJECTS.clear() - DATABASE_OBJECTS.extend(original_db_objects) - - -################################################# -# Fixures for tests/client -################################################# -@pytest.fixture(scope="session") -def app_server(request): - """Start the FastAPI server for Streamlit and wait for it to be ready""" - - def is_port_in_use(port): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(("localhost", port)) == 0 - - config_file = getattr(request, "param", None) - - # If config_file is passed, include it in the subprocess command - cmd = ["python", "launch_server.py"] - if config_file: - cmd.extend(["-c", config_file]) - - server_process = subprocess.Popen(cmd, cwd="src") - - try: - # Wait for server to be ready (up to 30 seconds) - max_wait = 30 - start_time = time.time() - while not is_port_in_use(8015): - if time.time() - start_time > max_wait: - raise TimeoutError("Server failed to start within 30 seconds") - time.sleep(0.5) - - yield server_process - - finally: - # Terminate the server after tests - server_process.terminate() - server_process.wait() - - -@pytest.fixture -def app_test(auth_headers): - """Establish Streamlit State for Client to Operate - - This fixture mimics what launch_client.py does in init_configs_state(), - loading the full configuration including all *_configs (database_configs, model_configs, - oci_configs, etc.) into session state, just like the real application does. - """ - - def _app_test(page): - at = AppTest.from_file(page, default_timeout=30) - at.session_state.server = { - "key": os.environ.get("API_SERVER_KEY"), - "url": os.environ.get("API_SERVER_URL"), - "port": int(os.environ.get("API_SERVER_PORT")), - "control": True, - } - # Load full config like launch_client.py does in init_configs_state() - full_config = requests.get( - url=f"{at.session_state.server['url']}:{at.session_state.server['port']}/v1/settings", - headers=auth_headers["valid_auth"], - params={ - "client": TEST_CONFIG["client"], - "full_config": True, - "incl_sensitive": True, - "incl_readonly": True, - }, - timeout=120, - ).json() - # Load all config items into session state (database_configs, model_configs, oci_configs, etc.) - for key, value in full_config.items(): - at.session_state[key] = value - return at - - return _app_test - - -def setup_test_database(app_test_instance): - """Configure and connect to test database for integration tests - - This helper function: - 1. Updates database config with test credentials - 2. Patches the database on the server - 3. 
Reloads full config to get updated database status - - Args: - app_test_instance: The AppTest instance from app_test fixture - - Returns: - The updated AppTest instance with database configured - """ - if not app_test_instance.session_state.database_configs: - return app_test_instance - - # Update database config with test credentials - db_config = app_test_instance.session_state.database_configs[0] - db_config["user"] = TEST_CONFIG["db_username"] - db_config["password"] = TEST_CONFIG["db_password"] - db_config["dsn"] = TEST_CONFIG["db_dsn"] - - # Update the database on the server to establish connection - server_url = app_test_instance.session_state.server["url"] - server_port = app_test_instance.session_state.server["port"] - server_key = app_test_instance.session_state.server["key"] - db_name = db_config["name"] - - response = requests.patch( - url=f"{server_url}:{server_port}/v1/databases/{db_name}", - headers={"Authorization": f"Bearer {server_key}", "client": "server"}, - json={"user": db_config["user"], "password": db_config["password"], "dsn": db_config["dsn"]}, - timeout=120, - ) - - if response.status_code != 200: - raise RuntimeError(f"Failed to update database: {response.text}") - - # Reload the full config to get the updated database status - full_config = requests.get( - url=f"{server_url}:{server_port}/v1/settings", - headers={"Authorization": f"Bearer {server_key}", "client": TEST_CONFIG["client"]}, - params={ - "client": TEST_CONFIG["client"], - "full_config": True, - "incl_sensitive": True, - "incl_readonly": True, - }, - timeout=120, - ).json() - - # Update session state with refreshed config - for key, value in full_config.items(): - app_test_instance.session_state[key] = value - - return app_test_instance - - -def enable_test_models(app_test_instance): - """Enable at least one LL model for testing - - Args: - app_test_instance: The AppTest instance from app_test fixture - - Returns: - The updated AppTest instance with models enabled - """ - for model in app_test_instance.session_state.model_configs: - if model["type"] == "ll": - model["enabled"] = True - break - - return app_test_instance - - -def enable_test_embed_models(app_test_instance): - """Enable at least one embedding model for testing - - Args: - app_test_instance: The AppTest instance from app_test fixture - - Returns: - The updated AppTest instance with embed models enabled - """ - for model in app_test_instance.session_state.model_configs: - if model["type"] == "embed": - model["enabled"] = True - break - - return app_test_instance - - -def create_tabs_mock(monkeypatch): - """Create a mock for st.tabs that captures what tabs are created - - This is a helper function to reduce code duplication in tests that need - to verify which tabs are created by the application. - - Args: - monkeypatch: pytest monkeypatch fixture - - Returns: - A list that will be populated with tab names as they are created - """ - import streamlit as st - - tabs_created = [] - original_tabs = st.tabs - - def mock_tabs(tab_list): - tabs_created.extend(tab_list) - return original_tabs(tab_list) - - monkeypatch.setattr(st, "tabs", mock_tabs) - return tabs_created - - -@contextmanager -def temporary_sys_path(path): - """Temporarily add a path to sys.path and remove it when done - - This context manager is useful for tests that need to temporarily modify - the Python path to import modules from specific locations. 
- - Args: - path: Path to add to sys.path - - Yields: - None - """ - sys.path.insert(0, path) - try: - yield - finally: - if path in sys.path: - sys.path.remove(path) - - -def run_streamlit_test(app_test_instance, run=True): - """Helper to run a Streamlit test and verify no exceptions - - This helper reduces code duplication in tests that follow the pattern: - 1. Run the app test - 2. Verify no exceptions occurred - - Args: - app_test_instance: The AppTest instance to run - run: Whether to run the test (default: True) - - Returns: - The AppTest instance (run or not based on the run parameter) - """ - if run: - app_test_instance = app_test_instance.run() - assert not app_test_instance.exception - return app_test_instance - - -def get_test_db_payload(): - """Get standard test database payload for integration tests - - Returns: - dict: Database configuration payload with test credentials - """ - return { - "user": TEST_CONFIG["db_username"], - "password": TEST_CONFIG["db_password"], - "dsn": TEST_CONFIG["db_dsn"], - } - - -def get_sample_oci_config(): - """Get sample OCI configuration for unit tests - - Returns: - OracleCloudSettings: Sample OCI configuration object - """ - from common.schema import OracleCloudSettings - - return OracleCloudSettings( - auth_profile="DEFAULT", - compartment_id="ocid1.compartment.oc1..test", - genai_region="us-ashburn-1", - user="ocid1.user.oc1..testuser", - fingerprint="test-fingerprint", - tenancy="ocid1.tenancy.oc1..testtenant", - key_file="/path/to/key.pem", - ) - - -################################################# -# Container for DB Tests -################################################# -def wait_for_container_ready(container: Container, ready_output: str, since: Optional[int] = None) -> None: - """Wait for container to be ready by checking its logs with exponential backoff.""" - start_time = time.time() - retry_interval = 2 - - while time.time() - start_time < 60: - try: - logs = container.logs(tail=100, since=since).decode("utf-8") - if ready_output in logs: - return - except DockerException as e: - container.remove(force=True) - raise DockerException(f"Failed to get container logs: {str(e)}") from e - - time.sleep(retry_interval) - retry_interval = min(retry_interval * 2, 60) # Exponential backoff, max 10 seconds - - container.remove(force=True) - raise TimeoutError("Container did not become ready timeout") - - -@contextmanager -def temp_sql_setup(): - """Context manager for temporary SQL setup files.""" - temp_dir = Path("tests/db_startup_temp") - try: - temp_dir.mkdir(exist_ok=True) - sql_content = f""" - alter system set vector_memory_size=512M scope=spfile; - - alter session set container=FREEPDB1; - CREATE TABLESPACE IF NOT EXISTS USERS DATAFILE '/opt/oracle/oradata/FREE/FREEPDB1/users_01.dbf' SIZE 100M; - CREATE USER IF NOT EXISTS "{TEST_CONFIG["db_username"]}" IDENTIFIED BY {TEST_CONFIG["db_password"]} - DEFAULT TABLESPACE "USERS" - TEMPORARY TABLESPACE "TEMP"; - GRANT "DB_DEVELOPER_ROLE" TO "{TEST_CONFIG["db_username"]}"; - ALTER USER "{TEST_CONFIG["db_username"]}" DEFAULT ROLE ALL; - ALTER USER "{TEST_CONFIG["db_username"]}" QUOTA UNLIMITED ON USERS; - - EXIT; - """ - - temp_sql_file = temp_dir / "01_db_user.sql" - temp_sql_file.write_text(sql_content, encoding="UTF-8") - yield temp_dir - finally: - if temp_dir.exists(): - shutil.rmtree(temp_dir) - - -@pytest.fixture(scope="session") -def db_container() -> Generator[Container, None, None]: - """Create and manage an Oracle database container for testing.""" - db_client = docker.from_env() - 
container = None - - try: - with temp_sql_setup() as temp_dir: - container = db_client.containers.run( - "container-registry.oracle.com/database/free:latest-lite", - environment={ - "ORACLE_PWD": TEST_CONFIG["db_password"], - "ORACLE_PDB": TEST_CONFIG["db_dsn"].split("/")[3], - }, - ports={"1521/tcp": int(TEST_CONFIG["db_dsn"].split("/")[2].split(":")[1])}, - volumes={str(temp_dir.absolute()): {"bind": "/opt/oracle/scripts/startup", "mode": "ro"}}, - detach=True, - ) - - # Wait for database to be ready - wait_for_container_ready(container, "DATABASE IS READY TO USE!") - - # Restart container to apply vector_memory_size - container.restart() - restart_time = int(time.time()) - wait_for_container_ready(container, "DATABASE IS READY TO USE!", since=restart_time) - - yield container - - except DockerException as e: - if container: - container.remove(force=True) - raise DockerException(f"Docker operation failed: {str(e)}") from e - - finally: - if container: - try: - container.stop(timeout=30) - container.remove() - except DockerException as e: - print(f"Warning: Failed to cleanup database container: {str(e)}") - - -################################################# -# Shared Test Data for Vector Store Tests -################################################# -@pytest.fixture -def sample_vector_store_data(): - """Sample vector store data for testing - standard configuration""" - return { - "alias": "test_alias", - "model": "openai/text-embed-3", - "chunk_size": 1000, - "chunk_overlap": 200, - "distance_metric": "cosine", - "index_type": "IVF", - "vector_store": "vs_test" - } - - -@pytest.fixture -def sample_vector_store_data_alt(): - """Alternative sample vector store data for testing - different configuration""" - return { - "alias": "alias2", - "model": "openai/text-embed-3", - "chunk_size": 500, - "chunk_overlap": 100, - "distance_metric": "euclidean", - "index_type": "HNSW", - "vector_store": "vs2" - } - - -@pytest.fixture -def sample_vector_stores_list(sample_vector_store_data, sample_vector_store_data_alt): # pylint: disable=redefined-outer-name - """List of sample vector stores with different aliases for filtering tests""" - vs1 = sample_vector_store_data.copy() - vs1["alias"] = "vs1" - vs1.pop("vector_store", None) # Remove vector_store field for filtering tests - - vs2 = sample_vector_store_data_alt.copy() - vs2["alias"] = "vs2" - vs2.pop("vector_store", None) # Remove vector_store field for filtering tests - - return [vs1, vs2] From 981fbc85cd143193e971843874dbcc5afebc7bc9 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Sun, 30 Nov 2025 12:46:12 +0000 Subject: [PATCH 10/20] shift back to tests --- pytest.ini | 16 +- test/conftest.py | 27 --- {test => tests}/__init__.py | 0 tests/conftest.py | 23 ++ {test => tests}/db_fixtures.py | 208 +++++++++++++++++- {test => tests}/integration/__init__.py | 0 .../integration/client/__init__.py | 0 .../integration/client/conftest.py | 123 ++++++----- .../integration/client/content/__init__.py | 0 .../client/content/config/__init__.py | 0 .../client/content/config/tabs/__init__.py | 0 .../content/config/tabs/test_databases.py | 2 +- .../client/content/config/tabs/test_mcp.py | 0 .../client/content/config/tabs/test_models.py | 0 .../client/content/config/tabs/test_oci.py | 0 .../content/config/tabs/test_settings.py | 0 .../client/content/config/test_config.py | 0 .../client/content/test_api_server.py | 0 .../client/content/test_chatbot.py | 0 .../client/content/test_testbed.py | 0 .../client/content/tools/__init__.py | 0 
.../client/content/tools/tabs/__init__.py | 0 .../content/tools/tabs/test_prompt_eng.py | 0 .../content/tools/tabs/test_split_embed.py | 0 .../client/content/tools/test_tools.py | 0 .../integration/client/utils/__init__.py | 0 .../client/utils/test_st_footer.py | 0 .../client/utils/test_vs_options.py | 0 .../integration/common/__init__.py | 0 .../integration/common/test_functions.py | 12 +- tests/integration/conftest.py | 23 ++ .../integration/server/__init__.py | 0 .../integration/server/api/__init__.py | 0 .../integration/server/api/conftest.py | 84 +++++-- .../integration/server/api/v1/__init__.py | 0 .../integration/server/api/v1/test_chat.py | 0 .../server/api/v1/test_databases.py | 2 +- .../integration/server/api/v1/test_embed.py | 0 .../server/api/v1/test_mcp_prompts.py | 0 .../integration/server/api/v1/test_models.py | 0 .../integration/server/api/v1/test_oci.py | 0 .../integration/server/api/v1/test_probes.py | 0 .../server/api/v1/test_settings.py | 0 .../integration/server/api/v1/test_testbed.py | 0 .../integration/server/bootstrap/__init__.py | 0 .../integration/server/bootstrap/conftest.py | 23 +- .../bootstrap/test_bootstrap_configfile.py | 0 .../bootstrap/test_bootstrap_databases.py | 2 +- .../server/bootstrap/test_bootstrap_models.py | 2 +- .../server/bootstrap/test_bootstrap_oci.py | 0 .../bootstrap/test_bootstrap_settings.py | 0 {test => tests}/opentofu/OMRMetaSchema.yaml | 0 .../opentofu/validate_omr_schema.py | 0 {test => tests}/shared_fixtures.py | 184 +++++++++++++++- {test => tests}/unit/__init__.py | 0 {test => tests}/unit/client/__init__.py | 0 {test => tests}/unit/client/conftest.py | 13 +- .../unit/client/content/__init__.py | 0 .../unit/client/content/config/__init__.py | 0 .../client/content/config/tabs/__init__.py | 0 .../content/config/tabs/test_mcp_unit.py | 0 .../content/config/tabs/test_models_unit.py | 0 .../unit/client/content/test_chatbot_unit.py | 0 .../unit/client/content/test_testbed_unit.py | 0 .../unit/client/content/tools/__init__.py | 0 .../client/content/tools/tabs/__init__.py | 0 .../tools/tabs/test_split_embed_unit.py | 0 {test => tests}/unit/client/utils/__init__.py | 0 .../unit/client/utils/test_client_unit.py | 0 .../unit/client/utils/test_st_common_unit.py | 0 .../unit/client/utils/test_vs_options_unit.py | 0 {test => tests}/unit/common/__init__.py | 0 {test => tests}/unit/common/test_functions.py | 0 {test => tests}/unit/common/test_help_text.py | 0 .../unit/common/test_logging_config.py | 0 {test => tests}/unit/common/test_schema.py | 0 {test => tests}/unit/common/test_version.py | 0 tests/unit/conftest.py | 23 ++ {test => tests}/unit/server/__init__.py | 0 {test => tests}/unit/server/api/__init__.py | 0 {test => tests}/unit/server/api/conftest.py | 23 +- .../unit/server/api/utils/__init__.py | 0 .../unit/server/api/utils/test_utils_chat.py | 0 .../server/api/utils/test_utils_databases.py | 16 +- .../unit/server/api/utils/test_utils_embed.py | 0 .../unit/server/api/utils/test_utils_mcp.py | 2 +- .../server/api/utils/test_utils_models.py | 0 .../unit/server/api/utils/test_utils_oci.py | 0 .../server/api/utils/test_utils_settings.py | 0 .../server/api/utils/test_utils_testbed.py | 0 .../api/utils/test_utils_testbed_metrics.py | 0 .../server/api/utils/test_utils_webscrape.py | 0 .../unit/server/api/v1/__init__.py | 0 .../unit/server/api/v1/test_v1_chat.py | 0 .../unit/server/api/v1/test_v1_databases.py | 0 .../unit/server/api/v1/test_v1_embed.py | 0 .../unit/server/api/v1/test_v1_mcp.py | 2 +- .../unit/server/api/v1/test_v1_mcp_prompts.py | 
0 .../unit/server/api/v1/test_v1_models.py | 0 .../unit/server/api/v1/test_v1_oci.py | 0 .../unit/server/api/v1/test_v1_probes.py | 0 .../unit/server/api/v1/test_v1_settings.py | 0 .../unit/server/api/v1/test_v1_testbed.py | 0 .../unit/server/bootstrap/__init__.py | 0 .../unit/server/bootstrap/conftest.py | 19 +- .../bootstrap/test_bootstrap_bootstrap.py | 0 .../bootstrap/test_bootstrap_configfile.py | 0 .../bootstrap/test_bootstrap_databases.py | 2 +- .../server/bootstrap/test_bootstrap_models.py | 2 +- .../server/bootstrap/test_bootstrap_oci.py | 0 .../bootstrap/test_bootstrap_settings.py | 0 111 files changed, 640 insertions(+), 193 deletions(-) delete mode 100644 test/conftest.py rename {test => tests}/__init__.py (100%) create mode 100644 tests/conftest.py rename {test => tests}/db_fixtures.py (50%) rename {test => tests}/integration/__init__.py (100%) rename {test => tests}/integration/client/__init__.py (100%) rename {test => tests}/integration/client/conftest.py (81%) rename {test => tests}/integration/client/content/__init__.py (100%) rename {test => tests}/integration/client/content/config/__init__.py (100%) rename {test => tests}/integration/client/content/config/tabs/__init__.py (100%) rename {test => tests}/integration/client/content/config/tabs/test_databases.py (99%) rename {test => tests}/integration/client/content/config/tabs/test_mcp.py (100%) rename {test => tests}/integration/client/content/config/tabs/test_models.py (100%) rename {test => tests}/integration/client/content/config/tabs/test_oci.py (100%) rename {test => tests}/integration/client/content/config/tabs/test_settings.py (100%) rename {test => tests}/integration/client/content/config/test_config.py (100%) rename {test => tests}/integration/client/content/test_api_server.py (100%) rename {test => tests}/integration/client/content/test_chatbot.py (100%) rename {test => tests}/integration/client/content/test_testbed.py (100%) rename {test => tests}/integration/client/content/tools/__init__.py (100%) rename {test => tests}/integration/client/content/tools/tabs/__init__.py (100%) rename {test => tests}/integration/client/content/tools/tabs/test_prompt_eng.py (100%) rename {test => tests}/integration/client/content/tools/tabs/test_split_embed.py (100%) rename {test => tests}/integration/client/content/tools/test_tools.py (100%) rename {test => tests}/integration/client/utils/__init__.py (100%) rename {test => tests}/integration/client/utils/test_st_footer.py (100%) rename {test => tests}/integration/client/utils/test_vs_options.py (100%) rename {test => tests}/integration/common/__init__.py (100%) rename {test => tests}/integration/common/test_functions.py (92%) create mode 100644 tests/integration/conftest.py rename {test => tests}/integration/server/__init__.py (100%) rename {test => tests}/integration/server/api/__init__.py (100%) rename {test => tests}/integration/server/api/conftest.py (71%) rename {test => tests}/integration/server/api/v1/__init__.py (100%) rename {test => tests}/integration/server/api/v1/test_chat.py (100%) rename {test => tests}/integration/server/api/v1/test_databases.py (99%) rename {test => tests}/integration/server/api/v1/test_embed.py (100%) rename {test => tests}/integration/server/api/v1/test_mcp_prompts.py (100%) rename {test => tests}/integration/server/api/v1/test_models.py (100%) rename {test => tests}/integration/server/api/v1/test_oci.py (100%) rename {test => tests}/integration/server/api/v1/test_probes.py (100%) rename {test => tests}/integration/server/api/v1/test_settings.py 
(100%) rename {test => tests}/integration/server/api/v1/test_testbed.py (100%) rename {test => tests}/integration/server/bootstrap/__init__.py (100%) rename {test => tests}/integration/server/bootstrap/conftest.py (87%) rename {test => tests}/integration/server/bootstrap/test_bootstrap_configfile.py (100%) rename {test => tests}/integration/server/bootstrap/test_bootstrap_databases.py (99%) rename {test => tests}/integration/server/bootstrap/test_bootstrap_models.py (99%) rename {test => tests}/integration/server/bootstrap/test_bootstrap_oci.py (100%) rename {test => tests}/integration/server/bootstrap/test_bootstrap_settings.py (100%) rename {test => tests}/opentofu/OMRMetaSchema.yaml (100%) rename {test => tests}/opentofu/validate_omr_schema.py (100%) rename {test => tests}/shared_fixtures.py (63%) rename {test => tests}/unit/__init__.py (100%) rename {test => tests}/unit/client/__init__.py (100%) rename {test => tests}/unit/client/conftest.py (82%) rename {test => tests}/unit/client/content/__init__.py (100%) rename {test => tests}/unit/client/content/config/__init__.py (100%) rename {test => tests}/unit/client/content/config/tabs/__init__.py (100%) rename {test => tests}/unit/client/content/config/tabs/test_mcp_unit.py (100%) rename {test => tests}/unit/client/content/config/tabs/test_models_unit.py (100%) rename {test => tests}/unit/client/content/test_chatbot_unit.py (100%) rename {test => tests}/unit/client/content/test_testbed_unit.py (100%) rename {test => tests}/unit/client/content/tools/__init__.py (100%) rename {test => tests}/unit/client/content/tools/tabs/__init__.py (100%) rename {test => tests}/unit/client/content/tools/tabs/test_split_embed_unit.py (100%) rename {test => tests}/unit/client/utils/__init__.py (100%) rename {test => tests}/unit/client/utils/test_client_unit.py (100%) rename {test => tests}/unit/client/utils/test_st_common_unit.py (100%) rename {test => tests}/unit/client/utils/test_vs_options_unit.py (100%) rename {test => tests}/unit/common/__init__.py (100%) rename {test => tests}/unit/common/test_functions.py (100%) rename {test => tests}/unit/common/test_help_text.py (100%) rename {test => tests}/unit/common/test_logging_config.py (100%) rename {test => tests}/unit/common/test_schema.py (100%) rename {test => tests}/unit/common/test_version.py (100%) create mode 100644 tests/unit/conftest.py rename {test => tests}/unit/server/__init__.py (100%) rename {test => tests}/unit/server/api/__init__.py (100%) rename {test => tests}/unit/server/api/conftest.py (89%) rename {test => tests}/unit/server/api/utils/__init__.py (100%) rename {test => tests}/unit/server/api/utils/test_utils_chat.py (100%) rename {test => tests}/unit/server/api/utils/test_utils_databases.py (98%) rename {test => tests}/unit/server/api/utils/test_utils_embed.py (100%) rename {test => tests}/unit/server/api/utils/test_utils_mcp.py (99%) rename {test => tests}/unit/server/api/utils/test_utils_models.py (100%) rename {test => tests}/unit/server/api/utils/test_utils_oci.py (100%) rename {test => tests}/unit/server/api/utils/test_utils_settings.py (100%) rename {test => tests}/unit/server/api/utils/test_utils_testbed.py (100%) rename {test => tests}/unit/server/api/utils/test_utils_testbed_metrics.py (100%) rename {test => tests}/unit/server/api/utils/test_utils_webscrape.py (100%) rename {test => tests}/unit/server/api/v1/__init__.py (100%) rename {test => tests}/unit/server/api/v1/test_v1_chat.py (100%) rename {test => tests}/unit/server/api/v1/test_v1_databases.py (100%) rename {test => 
tests}/unit/server/api/v1/test_v1_embed.py (100%) rename {test => tests}/unit/server/api/v1/test_v1_mcp.py (99%) rename {test => tests}/unit/server/api/v1/test_v1_mcp_prompts.py (100%) rename {test => tests}/unit/server/api/v1/test_v1_models.py (100%) rename {test => tests}/unit/server/api/v1/test_v1_oci.py (100%) rename {test => tests}/unit/server/api/v1/test_v1_probes.py (100%) rename {test => tests}/unit/server/api/v1/test_v1_settings.py (100%) rename {test => tests}/unit/server/api/v1/test_v1_testbed.py (100%) rename {test => tests}/unit/server/bootstrap/__init__.py (100%) rename {test => tests}/unit/server/bootstrap/conftest.py (74%) rename {test => tests}/unit/server/bootstrap/test_bootstrap_bootstrap.py (100%) rename {test => tests}/unit/server/bootstrap/test_bootstrap_configfile.py (100%) rename {test => tests}/unit/server/bootstrap/test_bootstrap_databases.py (99%) rename {test => tests}/unit/server/bootstrap/test_bootstrap_models.py (99%) rename {test => tests}/unit/server/bootstrap/test_bootstrap_oci.py (100%) rename {test => tests}/unit/server/bootstrap/test_bootstrap_settings.py (100%) diff --git a/pytest.ini b/pytest.ini index 2a5ffb75..cbfba509 100644 --- a/pytest.ini +++ b/pytest.ini @@ -7,4 +7,18 @@ pythonpath = src filterwarnings = ignore::DeprecationWarning -asyncio_default_fixture_loop_scope = function \ No newline at end of file +asyncio_default_fixture_loop_scope = function + +; Test markers for selective test execution +; Usage examples: +; pytest -m "unit" # Run only unit tests +; pytest -m "integration" # Run only integration tests +; pytest -m "not slow" # Skip slow tests +; pytest -m "not db" # Skip tests requiring database +; pytest -m "unit and not slow" # Fast unit tests only +markers = + unit: Unit tests (mocked dependencies, fast execution) + integration: Integration tests (real components, may require external services) + slow: Slow tests (deselect with '-m "not slow"') + db: Tests requiring Oracle database container (deselect with '-m "not db"') + db_container: Alias for db marker - tests requiring database container \ No newline at end of file diff --git a/test/conftest.py b/test/conftest.py deleted file mode 100644 index c25db9f6..00000000 --- a/test/conftest.py +++ /dev/null @@ -1,27 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Pytest fixtures for unit tests with real Oracle database. - -Re-exports shared database fixtures from test.db_fixtures. -""" - -# Re-export shared fixtures for pytest discovery -from test.db_fixtures import ( - TEST_DB_CONFIG, - db_container, - db_connection, - db_transaction, -) - -# Expose TEST_CONFIG alias for backwards compatibility -TEST_CONFIG = TEST_DB_CONFIG - -__all__ = [ - "TEST_CONFIG", - "TEST_DB_CONFIG", - "db_container", - "db_connection", - "db_transaction", -] diff --git a/test/__init__.py b/tests/__init__.py similarity index 100% rename from test/__init__.py rename to tests/__init__.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..8935079e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,23 @@ +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Root pytest configuration for the test suite. 
+ +This conftest.py uses pytest_plugins to automatically load fixtures from: +- tests.shared_fixtures: Factory fixtures (make_database, make_model, etc.) +- tests.db_fixtures: Database container fixtures (db_container, db_connection, etc.) + +All fixtures defined in these modules are automatically available to all tests +without needing explicit imports in child conftest.py files. + +Constants and helper functions (e.g., TEST_DB_CONFIG, assert_model_list_valid) +still require explicit imports in the test files that use them. +""" + +# pytest_plugins automatically loads fixtures from these modules +# This replaces scattered "from tests.shared_fixtures import ..." across conftest files +pytest_plugins = [ + "tests.shared_fixtures", + "tests.db_fixtures", +] diff --git a/test/db_fixtures.py b/tests/db_fixtures.py similarity index 50% rename from test/db_fixtures.py rename to tests/db_fixtures.py index 51609173..293e5d80 100644 --- a/test/db_fixtures.py +++ b/tests/db_fixtures.py @@ -4,8 +4,52 @@ Shared database fixtures and utilities for tests. -This module provides common database container management functions -used by both unit and integration tests. +This module provides Oracle database container management and connection +fixtures for both unit and integration tests. + +FIXTURES (choose based on your test needs): + + db_container (session) + Raw Docker container. Rarely needed directly. + + db_connection (session) + Shared connection for the entire test session. + Use when you need low-level connection access. + + db_transaction (function) + Connection with savepoint-based isolation. + Best for DML tests (INSERT/UPDATE/DELETE). + WARNING: DDL operations (CREATE TABLE) invalidate savepoints! + + db_cursor (function) + Convenience cursor with automatic cleanup. + No transaction isolation - use for simple queries. + + db_clean (function) + For tests with DDL operations (CREATE TABLE, etc.). + Tracks and drops tables after test completes. + Usage: table = db_clean.register("MY_TABLE") + + db_module_connection (module) + Module-scoped connection for tests sharing state. + Use when multiple tests need the same tables/data. + +FIXTURE SELECTION GUIDE: + + Test Type | Recommended Fixture + -----------------------------|-------------------- + Simple SELECT queries | db_cursor + INSERT/UPDATE/DELETE | db_transaction + CREATE TABLE / DDL | db_clean + Multiple related tests | db_module_connection + Custom connection handling | db_connection + +Tests using any of these fixtures are automatically marked with +'db' and 'slow' markers, enabling: + + pytest -m "not db" # Skip database tests + pytest -m "not slow" # Skip slow tests (includes DB tests) + pytest -m "db" # Run only database tests """ # pylint: disable=redefined-outer-name @@ -24,6 +68,36 @@ from docker.models.containers import Container +# Database fixture names that trigger auto-marking +DB_FIXTURE_NAMES = { + "db_container", + "db_connection", + "db_transaction", + "db_cursor", + "db_clean", + "db_module_connection", +} + + +def pytest_collection_modifyitems(config, items): + """Automatically mark tests using database fixtures with 'db' and 'slow' markers. + + This hook inspects each test's fixture requirements and adds markers + to tests that use db_container, db_connection, or db_transaction fixtures. 
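+
+    For example (a sketch): a test declared as `def test_q(db_cursor): ...`
+    uses a fixture in DB_FIXTURE_NAMES, so it is collected with the `db` and
+    `slow` markers added and can be deselected with `pytest -m "not db"`.
+    """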
+ """ + for item in items: + # Get the fixture names this test uses + try: + fixture_names = set(item.fixturenames) + except AttributeError: + continue + + # If test uses any DB fixture, mark it + if fixture_names & DB_FIXTURE_NAMES: + item.add_marker(pytest.mark.db) + item.add_marker(pytest.mark.slow) + + # Test database configuration - shared across all tests TEST_DB_CONFIG = { "db_username": "PYTEST", @@ -198,7 +272,13 @@ def db_transaction(db_connection) -> Generator[oracledb.Connection, None, None]: Note: DDL operations (CREATE TABLE, etc.) cause implicit commits in Oracle, which will invalidate the savepoint. Tests with DDL - should use mocks or handle cleanup manually. + should use db_clean instead. + + Usage: + def test_something(db_transaction): + cursor = db_transaction.cursor() + cursor.execute("INSERT INTO ...") + # Changes are automatically rolled back after test """ cursor = db_connection.cursor() cursor.execute("SAVEPOINT test_savepoint") @@ -207,3 +287,125 @@ def db_transaction(db_connection) -> Generator[oracledb.Connection, None, None]: cursor.execute("ROLLBACK TO SAVEPOINT test_savepoint") cursor.close() + + +@pytest.fixture +def db_cursor(db_connection) -> Generator[oracledb.Cursor, None, None]: + """Provides a database cursor with automatic cleanup. + + Convenience fixture that creates a cursor and ensures it's closed + after the test completes. Does not provide transaction isolation - + use db_transaction for that. + + Usage: + def test_something(db_cursor): + db_cursor.execute("SELECT * FROM dual") + result = db_cursor.fetchone() + """ + cursor = db_connection.cursor() + yield cursor + cursor.close() + + +class TableTracker: + """Helper class to track and clean up tables created during tests. + + This class provides methods to register tables for cleanup and + automatically drops them when cleanup() is called. + """ + + def __init__(self, connection: oracledb.Connection): + self.connection = connection + self.tables: list[str] = [] + + def register(self, table_name: str) -> str: + """Register a table for cleanup after test. + + Args: + table_name: Name of the table to track + + Returns: + The table name (for convenience in chaining) + """ + if table_name.upper() not in [t.upper() for t in self.tables]: + self.tables.append(table_name) + return table_name + + def cleanup(self) -> None: + """Drop all registered tables. + + Silently ignores errors if tables don't exist. + """ + cursor = self.connection.cursor() + for table_name in reversed(self.tables): # Reverse order for dependencies + try: + cursor.execute(f"DROP TABLE {table_name} PURGE") + except oracledb.DatabaseError: + pass # Table might not exist or have dependencies + try: + self.connection.commit() + except oracledb.DatabaseError: + pass + cursor.close() + self.tables.clear() + + +@pytest.fixture +def db_clean(db_connection) -> Generator[TableTracker, None, None]: + """Fixture for tests that perform DDL operations (CREATE TABLE, etc.). + + Unlike db_transaction which uses savepoints (invalidated by DDL), + this fixture tracks tables created during the test and drops them + during cleanup. + + Usage: + def test_create_table(db_clean): + cursor = db_clean.connection.cursor() + + # Register table BEFORE creating it + table_name = db_clean.register("MY_TEST_TABLE") + cursor.execute(f"CREATE TABLE {table_name} (id NUMBER)") + + # ... test logic ... 
+
+        # Table is automatically dropped after test
+
+    For multiple tables with dependencies, register parent tables first:
+        def test_with_foreign_key(db_clean):
+            db_clean.register("PARENT_TABLE")
+            db_clean.register("CHILD_TABLE")  # Will be dropped first
+            # ...
+    """
+    tracker = TableTracker(db_connection)
+    yield tracker
+    tracker.cleanup()
+
+
+@pytest.fixture(scope="module")
+def db_module_connection(db_container) -> Generator[oracledb.Connection, None, None]:
+    """Module-scoped database connection.
+
+    Use this when multiple tests in a module need to share database state
+    or when connection setup is expensive. Each module gets its own connection.
+
+    Note: Tests using this fixture should be careful about state isolation.
+    Consider using unique table names or cleaning up after each test.
+
+    Usage:
+        # In a test module
+        def test_first(db_module_connection):
+            # Uses shared connection for this module
+            pass
+
+        def test_second(db_module_connection):
+            # Same connection as test_first
+            pass
+    """
+    _ = db_container  # Ensure container is running
+    conn = oracledb.connect(
+        user=TEST_DB_CONFIG["db_username"],
+        password=TEST_DB_CONFIG["db_password"],
+        dsn=TEST_DB_CONFIG["db_dsn"],
+    )
+    yield conn
+    conn.close()
diff --git a/test/integration/__init__.py b/tests/integration/__init__.py
similarity index 100%
rename from test/integration/__init__.py
rename to tests/integration/__init__.py
diff --git a/test/integration/client/__init__.py b/tests/integration/client/__init__.py
similarity index 100%
rename from test/integration/client/__init__.py
rename to tests/integration/client/__init__.py
diff --git a/test/integration/client/conftest.py b/tests/integration/client/conftest.py
similarity index 81%
rename from test/integration/client/conftest.py
rename to tests/integration/client/conftest.py
index 31daa5ea..e3780edd 100644
--- a/test/integration/client/conftest.py
+++ b/tests/integration/client/conftest.py
@@ -6,6 +6,13 @@
 These fixtures provide Streamlit AppTest and FastAPI server management
 for testing the client UI components.
+
+Note: Shared fixtures (make_database, make_model, sample_vector_store_data, etc.)
+are automatically available via pytest_plugins in tests/conftest.py.
+
+Environment Setup:
+    Environment variables are managed via the session-scoped `client_test_env` fixture.
+    The `app_server` fixture depends on this to ensure proper configuration.
""" # pylint: disable=redefined-outer-name @@ -17,27 +24,9 @@ import subprocess from contextlib import contextmanager -# Re-export shared fixtures for pytest discovery -from test.shared_fixtures import ( # noqa: F401 pylint: disable=unused-import - make_database, - make_model, - make_oci_config, - make_ll_settings, - make_settings, - make_configuration, - TEST_DB_USER, - TEST_DB_PASSWORD, - TEST_DB_DSN, - TEST_AUTH_TOKEN, - sample_vector_store_data, - sample_vector_store_data_alt, - sample_vector_stores_list, -) - -# Import TEST_DB_CONFIG for use in helper functions -# Note: db_container, db_connection, db_transaction fixtures are inherited -# from test/conftest.py - do not re-import here to avoid multiple containers -from test.db_fixtures import TEST_DB_CONFIG +# Import constants needed by fixtures and helper functions in this file +from tests.shared_fixtures import TEST_AUTH_TOKEN, ALL_TEST_ENV_VARS +from tests.db_fixtures import TEST_DB_CONFIG import pytest import requests @@ -57,38 +46,57 @@ def get_app_test(): ################################################# -# Environment Setup for Client Tests +# Test Configuration Constants ################################################# - -# Clear environment variables that might interfere with tests -API_VARS = ["API_SERVER_KEY", "API_SERVER_URL", "API_SERVER_PORT"] -DB_VARS = ["DB_USERNAME", "DB_PASSWORD", "DB_DSN", "DB_WALLET_PASSWORD", "TNS_ADMIN"] -MODEL_VARS = [ - "ON_PREM_OLLAMA_URL", - "ON_PREM_HF_URL", - "OPENAI_API_KEY", - "PPLX_API_KEY", - "COHERE_API_KEY", -] - -for env_var in [ - *API_VARS, - *DB_VARS, - *MODEL_VARS, - *[var for var in os.environ if var.startswith("OCI_")], -]: - os.environ.pop(env_var, None) - -# Test client configuration TEST_CLIENT = "client_test" TEST_SERVER_PORT = 8015 -# Set up test environment -os.environ["CONFIG_FILE"] = "/non/existent/path/config.json" -os.environ["OCI_CLI_CONFIG_FILE"] = "/non/existent/path" -os.environ["API_SERVER_KEY"] = TEST_AUTH_TOKEN -os.environ["API_SERVER_URL"] = "http://localhost" -os.environ["API_SERVER_PORT"] = str(TEST_SERVER_PORT) + +################################################# +# Environment Setup (Session-Scoped) +################################################# +@pytest.fixture(scope="session") +def client_test_env(): + """Session-scoped fixture to set up environment for client integration tests. + + This fixture: + 1. Saves the original environment state + 2. Clears all test-related environment variables + 3. Sets the required variables for client tests + 4. Restores the original state when the session ends + + The `app_server` fixture depends on this to ensure environment is configured + before the subprocess server is started. 
+ """ + # Save original environment state + original_env = {var: os.environ.get(var) for var in ALL_TEST_ENV_VARS} + + # Also capture dynamic OCI_ vars + dynamic_oci_vars = [v for v in os.environ if v.startswith("OCI_") and v not in ALL_TEST_ENV_VARS] + for var in dynamic_oci_vars: + original_env[var] = os.environ.get(var) + + # Clear all test-related vars + for var in ALL_TEST_ENV_VARS: + os.environ.pop(var, None) + for var in dynamic_oci_vars: + os.environ.pop(var, None) + + # Set required environment variables for client tests + os.environ["CONFIG_FILE"] = "/non/existent/path/config.json" + os.environ["OCI_CLI_CONFIG_FILE"] = "/non/existent/path" + os.environ["API_SERVER_KEY"] = TEST_AUTH_TOKEN + os.environ["API_SERVER_URL"] = "http://localhost" + os.environ["API_SERVER_PORT"] = str(TEST_SERVER_PORT) + + yield + + # Restore original environment state + for var, value in original_env.items(): + if value is not None: + os.environ[var] = value + elif var in os.environ: + del os.environ[var] ################################################# @@ -107,8 +115,12 @@ def _auth_headers(): @pytest.fixture(scope="session") -def app_server(request): - """Start the FastAPI server for Streamlit and wait for it to be ready.""" +def app_server(request, client_test_env): + """Start the FastAPI server for Streamlit and wait for it to be ready. + + Depends on client_test_env to ensure environment is properly configured. + """ + _ = client_test_env # Ensure env is set up first def is_port_in_use(port): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -121,14 +133,9 @@ def is_port_in_use(port): if config_file: cmd.extend(["-c", config_file]) - # Create environment with explicit port setting - # This ensures the server uses the correct port even when other conftest files - # may have modified os.environ during test collection + # Create environment with explicit settings for subprocess + # Copy current environment (which has been set up by client_test_env) env = os.environ.copy() - env["API_SERVER_PORT"] = str(TEST_SERVER_PORT) - env["API_SERVER_KEY"] = TEST_AUTH_TOKEN - env["CONFIG_FILE"] = "/non/existent/path/config.json" - env["OCI_CLI_CONFIG_FILE"] = "/non/existent/path" server_process = subprocess.Popen(cmd, cwd="src", env=env) # pylint: disable=consider-using-with diff --git a/test/integration/client/content/__init__.py b/tests/integration/client/content/__init__.py similarity index 100% rename from test/integration/client/content/__init__.py rename to tests/integration/client/content/__init__.py diff --git a/test/integration/client/content/config/__init__.py b/tests/integration/client/content/config/__init__.py similarity index 100% rename from test/integration/client/content/config/__init__.py rename to tests/integration/client/content/config/__init__.py diff --git a/test/integration/client/content/config/tabs/__init__.py b/tests/integration/client/content/config/tabs/__init__.py similarity index 100% rename from test/integration/client/content/config/tabs/__init__.py rename to tests/integration/client/content/config/tabs/__init__.py diff --git a/test/integration/client/content/config/tabs/test_databases.py b/tests/integration/client/content/config/tabs/test_databases.py similarity index 99% rename from test/integration/client/content/config/tabs/test_databases.py rename to tests/integration/client/content/config/tabs/test_databases.py index 8b5e937b..fed67ed9 100644 --- a/test/integration/client/content/config/tabs/test_databases.py +++ 
b/tests/integration/client/content/config/tabs/test_databases.py @@ -5,7 +5,7 @@ """ # spell-checker: disable -from test.db_fixtures import TEST_DB_CONFIG +from tests.db_fixtures import TEST_DB_CONFIG import pytest diff --git a/test/integration/client/content/config/tabs/test_mcp.py b/tests/integration/client/content/config/tabs/test_mcp.py similarity index 100% rename from test/integration/client/content/config/tabs/test_mcp.py rename to tests/integration/client/content/config/tabs/test_mcp.py diff --git a/test/integration/client/content/config/tabs/test_models.py b/tests/integration/client/content/config/tabs/test_models.py similarity index 100% rename from test/integration/client/content/config/tabs/test_models.py rename to tests/integration/client/content/config/tabs/test_models.py diff --git a/test/integration/client/content/config/tabs/test_oci.py b/tests/integration/client/content/config/tabs/test_oci.py similarity index 100% rename from test/integration/client/content/config/tabs/test_oci.py rename to tests/integration/client/content/config/tabs/test_oci.py diff --git a/test/integration/client/content/config/tabs/test_settings.py b/tests/integration/client/content/config/tabs/test_settings.py similarity index 100% rename from test/integration/client/content/config/tabs/test_settings.py rename to tests/integration/client/content/config/tabs/test_settings.py diff --git a/test/integration/client/content/config/test_config.py b/tests/integration/client/content/config/test_config.py similarity index 100% rename from test/integration/client/content/config/test_config.py rename to tests/integration/client/content/config/test_config.py diff --git a/test/integration/client/content/test_api_server.py b/tests/integration/client/content/test_api_server.py similarity index 100% rename from test/integration/client/content/test_api_server.py rename to tests/integration/client/content/test_api_server.py diff --git a/test/integration/client/content/test_chatbot.py b/tests/integration/client/content/test_chatbot.py similarity index 100% rename from test/integration/client/content/test_chatbot.py rename to tests/integration/client/content/test_chatbot.py diff --git a/test/integration/client/content/test_testbed.py b/tests/integration/client/content/test_testbed.py similarity index 100% rename from test/integration/client/content/test_testbed.py rename to tests/integration/client/content/test_testbed.py diff --git a/test/integration/client/content/tools/__init__.py b/tests/integration/client/content/tools/__init__.py similarity index 100% rename from test/integration/client/content/tools/__init__.py rename to tests/integration/client/content/tools/__init__.py diff --git a/test/integration/client/content/tools/tabs/__init__.py b/tests/integration/client/content/tools/tabs/__init__.py similarity index 100% rename from test/integration/client/content/tools/tabs/__init__.py rename to tests/integration/client/content/tools/tabs/__init__.py diff --git a/test/integration/client/content/tools/tabs/test_prompt_eng.py b/tests/integration/client/content/tools/tabs/test_prompt_eng.py similarity index 100% rename from test/integration/client/content/tools/tabs/test_prompt_eng.py rename to tests/integration/client/content/tools/tabs/test_prompt_eng.py diff --git a/test/integration/client/content/tools/tabs/test_split_embed.py b/tests/integration/client/content/tools/tabs/test_split_embed.py similarity index 100% rename from test/integration/client/content/tools/tabs/test_split_embed.py rename to 
tests/integration/client/content/tools/tabs/test_split_embed.py diff --git a/test/integration/client/content/tools/test_tools.py b/tests/integration/client/content/tools/test_tools.py similarity index 100% rename from test/integration/client/content/tools/test_tools.py rename to tests/integration/client/content/tools/test_tools.py diff --git a/test/integration/client/utils/__init__.py b/tests/integration/client/utils/__init__.py similarity index 100% rename from test/integration/client/utils/__init__.py rename to tests/integration/client/utils/__init__.py diff --git a/test/integration/client/utils/test_st_footer.py b/tests/integration/client/utils/test_st_footer.py similarity index 100% rename from test/integration/client/utils/test_st_footer.py rename to tests/integration/client/utils/test_st_footer.py diff --git a/test/integration/client/utils/test_vs_options.py b/tests/integration/client/utils/test_vs_options.py similarity index 100% rename from test/integration/client/utils/test_vs_options.py rename to tests/integration/client/utils/test_vs_options.py diff --git a/test/integration/common/__init__.py b/tests/integration/common/__init__.py similarity index 100% rename from test/integration/common/__init__.py rename to tests/integration/common/__init__.py diff --git a/test/integration/common/test_functions.py b/tests/integration/common/test_functions.py similarity index 92% rename from test/integration/common/test_functions.py rename to tests/integration/common/test_functions.py index 04f30c67..ae6a0c7c 100644 --- a/test/integration/common/test_functions.py +++ b/tests/integration/common/test_functions.py @@ -11,7 +11,7 @@ import os import tempfile -from test.conftest import TEST_CONFIG +from tests.db_fixtures import TEST_DB_CONFIG import pytest @@ -126,7 +126,7 @@ def test_is_sql_accessible_with_real_database(self, db_container): """is_sql_accessible should return True for valid database and query.""" # pylint: disable=unused-argument # Connection string format: username/password@dsn - db_conn = f"{TEST_CONFIG['db_username']}/{TEST_CONFIG['db_password']}@{TEST_CONFIG['db_dsn']}" + db_conn = f"{TEST_DB_CONFIG['db_username']}/{TEST_DB_CONFIG['db_password']}@{TEST_DB_CONFIG['db_dsn']}" # Must use VARCHAR2 - the function checks column type is VARCHAR, not CHAR query = "SELECT CAST('test' AS VARCHAR2(10)) FROM dual" @@ -139,7 +139,7 @@ def test_is_sql_accessible_with_real_database(self, db_container): def test_is_sql_accessible_invalid_credentials(self, db_container): """is_sql_accessible should return False for invalid credentials.""" # pylint: disable=unused-argument - db_conn = f"INVALID_USER/INVALID_PASSWORD@{TEST_CONFIG['db_dsn']}" + db_conn = f"INVALID_USER/INVALID_PASSWORD@{TEST_DB_CONFIG['db_dsn']}" query = "SELECT 'test' FROM dual" result, msg = functions.is_sql_accessible(db_conn, query) @@ -151,7 +151,7 @@ def test_is_sql_accessible_invalid_credentials(self, db_container): def test_is_sql_accessible_wrong_column_count(self, db_container): """is_sql_accessible should return False when query returns multiple columns.""" # pylint: disable=unused-argument - db_conn = f"{TEST_CONFIG['db_username']}/{TEST_CONFIG['db_password']}@{TEST_CONFIG['db_dsn']}" + db_conn = f"{TEST_DB_CONFIG['db_username']}/{TEST_DB_CONFIG['db_password']}@{TEST_DB_CONFIG['db_dsn']}" query = "SELECT 'a', 'b' FROM dual" # Two columns - should fail result, msg = functions.is_sql_accessible(db_conn, query) @@ -163,7 +163,7 @@ def test_is_sql_accessible_wrong_column_count(self, db_container): def 
test_run_sql_query_with_real_database(self, db_container):
         """run_sql_query should execute SQL and save results to CSV."""
         # pylint: disable=unused-argument
-        db_conn = f"{TEST_CONFIG['db_username']}/{TEST_CONFIG['db_password']}@{TEST_CONFIG['db_dsn']}"
+        db_conn = f"{TEST_DB_CONFIG['db_username']}/{TEST_DB_CONFIG['db_password']}@{TEST_DB_CONFIG['db_dsn']}"
         query = "SELECT 'value1' AS col1, 'value2' AS col2 FROM dual"
 
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -184,7 +184,7 @@ def test_run_sql_query_with_real_database(self, db_container):
     def test_run_sql_query_invalid_connection(self, db_container):
         """run_sql_query should return falsy value for invalid connection."""
         # pylint: disable=unused-argument
-        db_conn = f"INVALID_USER/INVALID_PASSWORD@{TEST_CONFIG['db_dsn']}"
+        db_conn = f"INVALID_USER/INVALID_PASSWORD@{TEST_DB_CONFIG['db_dsn']}"
         query = "SELECT 'test' FROM dual"
 
         with tempfile.TemporaryDirectory() as tmpdir:
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
new file mode 100644
index 00000000..be912832
--- /dev/null
+++ b/tests/integration/conftest.py
@@ -0,0 +1,23 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+
+Pytest configuration for integration tests.
+
+This conftest automatically marks all tests in the tests/integration/ directory
+with the 'integration' marker, enabling selective test execution:
+
+    pytest -m "integration"              # Run only integration tests
+    pytest -m "not integration"          # Skip integration tests
+    pytest -m "integration and not db"   # Integration tests without DB
+"""
+
+import pytest
+
+
+def pytest_collection_modifyitems(config, items):
+    """Automatically add 'integration' marker to all tests in this directory."""
+    for item in items:
+        # Check if the test is under tests/integration/
+        if "/tests/integration/" in str(item.fspath):
+            item.add_marker(pytest.mark.integration)
diff --git a/test/integration/server/__init__.py b/tests/integration/server/__init__.py
similarity index 100%
rename from test/integration/server/__init__.py
rename to tests/integration/server/__init__.py
diff --git a/test/integration/server/api/__init__.py b/tests/integration/server/api/__init__.py
similarity index 100%
rename from test/integration/server/api/__init__.py
rename to tests/integration/server/api/__init__.py
diff --git a/test/integration/server/api/conftest.py b/tests/integration/server/api/conftest.py
similarity index 71%
rename from test/integration/server/api/conftest.py
rename to tests/integration/server/api/conftest.py
index a433fdc7..b0d1eddc 100644
--- a/test/integration/server/api/conftest.py
+++ b/tests/integration/server/api/conftest.py
@@ -7,23 +7,27 @@
 Integration tests use a real FastAPI TestClient with the actual application,
 testing the full request/response cycle through the API layer.
 
-Note: db_container fixture is inherited from test/conftest.py - do not import here.
+Note: Shared fixtures (make_database, make_model, db_container, db_connection, etc.)
+are automatically available via pytest_plugins in tests/conftest.py.
+
+Environment Setup:
+    Environment variables are managed via the session-scoped `server_test_env` fixture,
+    which the `app` fixture depends on. This ensures proper isolation and explicit
+    dependency ordering.
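+
+Marker selection (examples; the markers are defined in pytest.ini):
+    pytest tests/integration/server/api -m "integration and not db"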
""" -# pylint: disable=redefined-outer-name unused-import -# Pytest fixtures use parameter injection where fixture names match parameters +# pylint: disable=redefined-outer-name import os import asyncio from typing import Generator -# Re-export shared fixtures for pytest discovery (before third-party imports per pylint) -from test.db_fixtures import TEST_DB_CONFIG -from test.shared_fixtures import ( - make_database, - make_model, +# Import constants needed by fixtures and test configuration in this file +from tests.db_fixtures import TEST_DB_CONFIG +from tests.shared_fixtures import ( DEFAULT_LL_MODEL_CONFIG, TEST_AUTH_TOKEN, + ALL_TEST_ENV_VARS, ) import numpy as np @@ -33,15 +37,6 @@ from server.bootstrap.bootstrap import DATABASE_OBJECTS, MODEL_OBJECTS, SETTINGS_OBJECTS - -# Clear environment variables that could interfere with tests -# This must happen before importing application modules -API_VARS = ["API_SERVER_KEY", "API_SERVER_URL", "API_SERVER_PORT"] -DB_VARS = ["DB_USERNAME", "DB_PASSWORD", "DB_DSN", "DB_WALLET_PASSWORD", "TNS_ADMIN"] -MODEL_VARS = ["ON_PREM_OLLAMA_URL", "ON_PREM_HF_URL", "OPENAI_API_KEY", "PPLX_API_KEY", "COHERE_API_KEY"] -for env_var in [*API_VARS, *DB_VARS, *MODEL_VARS, *[var for var in os.environ if var.startswith("OCI_")]]: - os.environ.pop(env_var, None) - # Test configuration - extends shared DB config with integration-specific settings TEST_CONFIG = { "client": "integration_test", @@ -49,10 +44,50 @@ **TEST_DB_CONFIG, } -# Set environment variables for test server -os.environ["CONFIG_FILE"] = "/non/existent/path/config.json" # Use empty config -os.environ["OCI_CLI_CONFIG_FILE"] = "/non/existent/path" # Prevent OCI config pickup -os.environ["API_SERVER_KEY"] = TEST_CONFIG["auth_token"] + +################################################# +# Environment Setup (Session-Scoped) +################################################# +@pytest.fixture(scope="session") +def server_test_env(): + """Session-scoped fixture to set up environment for server integration tests. + + This fixture: + 1. Saves the original environment state + 2. Clears all test-related environment variables + 3. Sets the required variables for the test server + 4. Restores the original state when the session ends + + The `app` fixture depends on this to ensure environment is configured + before the FastAPI application is created. 
+ """ + # Save original environment state + original_env = {var: os.environ.get(var) for var in ALL_TEST_ENV_VARS} + + # Also capture dynamic OCI_ vars + dynamic_oci_vars = [v for v in os.environ if v.startswith("OCI_") and v not in ALL_TEST_ENV_VARS] + for var in dynamic_oci_vars: + original_env[var] = os.environ.get(var) + + # Clear all test-related vars + for var in ALL_TEST_ENV_VARS: + os.environ.pop(var, None) + for var in dynamic_oci_vars: + os.environ.pop(var, None) + + # Set required environment variables for test server + os.environ["CONFIG_FILE"] = "/non/existent/path/config.json" # Use empty config + os.environ["OCI_CLI_CONFIG_FILE"] = "/non/existent/path" # Prevent OCI config pickup + os.environ["API_SERVER_KEY"] = TEST_CONFIG["auth_token"] + + yield + + # Restore original environment state + for var, value in original_env.items(): + if value is not None: + os.environ[var] = value + elif var in os.environ: + del os.environ[var] ################################################# @@ -86,16 +121,17 @@ def test_client_auth_headers(test_client_settings): # FastAPI Test Client ################################################# @pytest.fixture(scope="session") -def app(): +def app(server_test_env): """Create the FastAPI application for testing. This fixture creates the actual FastAPI app using the same factory function as the production server (launch_server.create_app). - Import is done inside the fixture to ensure environment variables - are set before any application modules are loaded. + Depends on server_test_env to ensure environment variables are + configured before any application modules are loaded. """ # pylint: disable=import-outside-toplevel + _ = server_test_env # Ensure env is set up first from launch_server import create_app return asyncio.run(create_app()) diff --git a/test/integration/server/api/v1/__init__.py b/tests/integration/server/api/v1/__init__.py similarity index 100% rename from test/integration/server/api/v1/__init__.py rename to tests/integration/server/api/v1/__init__.py diff --git a/test/integration/server/api/v1/test_chat.py b/tests/integration/server/api/v1/test_chat.py similarity index 100% rename from test/integration/server/api/v1/test_chat.py rename to tests/integration/server/api/v1/test_chat.py diff --git a/test/integration/server/api/v1/test_databases.py b/tests/integration/server/api/v1/test_databases.py similarity index 99% rename from test/integration/server/api/v1/test_databases.py rename to tests/integration/server/api/v1/test_databases.py index a089df39..355148e9 100644 --- a/test/integration/server/api/v1/test_databases.py +++ b/tests/integration/server/api/v1/test_databases.py @@ -8,7 +8,7 @@ These endpoints require authentication. 
""" -from test.db_fixtures import TEST_DB_CONFIG +from tests.db_fixtures import TEST_DB_CONFIG class TestAuthentication: diff --git a/test/integration/server/api/v1/test_embed.py b/tests/integration/server/api/v1/test_embed.py similarity index 100% rename from test/integration/server/api/v1/test_embed.py rename to tests/integration/server/api/v1/test_embed.py diff --git a/test/integration/server/api/v1/test_mcp_prompts.py b/tests/integration/server/api/v1/test_mcp_prompts.py similarity index 100% rename from test/integration/server/api/v1/test_mcp_prompts.py rename to tests/integration/server/api/v1/test_mcp_prompts.py diff --git a/test/integration/server/api/v1/test_models.py b/tests/integration/server/api/v1/test_models.py similarity index 100% rename from test/integration/server/api/v1/test_models.py rename to tests/integration/server/api/v1/test_models.py diff --git a/test/integration/server/api/v1/test_oci.py b/tests/integration/server/api/v1/test_oci.py similarity index 100% rename from test/integration/server/api/v1/test_oci.py rename to tests/integration/server/api/v1/test_oci.py diff --git a/test/integration/server/api/v1/test_probes.py b/tests/integration/server/api/v1/test_probes.py similarity index 100% rename from test/integration/server/api/v1/test_probes.py rename to tests/integration/server/api/v1/test_probes.py diff --git a/test/integration/server/api/v1/test_settings.py b/tests/integration/server/api/v1/test_settings.py similarity index 100% rename from test/integration/server/api/v1/test_settings.py rename to tests/integration/server/api/v1/test_settings.py diff --git a/test/integration/server/api/v1/test_testbed.py b/tests/integration/server/api/v1/test_testbed.py similarity index 100% rename from test/integration/server/api/v1/test_testbed.py rename to tests/integration/server/api/v1/test_testbed.py diff --git a/test/integration/server/bootstrap/__init__.py b/tests/integration/server/bootstrap/__init__.py similarity index 100% rename from test/integration/server/bootstrap/__init__.py rename to tests/integration/server/bootstrap/__init__.py diff --git a/test/integration/server/bootstrap/conftest.py b/tests/integration/server/bootstrap/conftest.py similarity index 87% rename from test/integration/server/bootstrap/conftest.py rename to tests/integration/server/bootstrap/conftest.py index 4049fd39..2dfbf30c 100644 --- a/test/integration/server/bootstrap/conftest.py +++ b/tests/integration/server/bootstrap/conftest.py @@ -7,19 +7,19 @@ Integration tests for bootstrap test the actual bootstrap process with real file I/O, environment variables, and configuration loading. These tests verify end-to-end behavior of the bootstrap system. + +Note: Shared fixtures (reset_config_store, clean_env, make_database, make_model, etc.) +are automatically available via pytest_plugins in test/conftest.py. """ -# pylint: disable=redefined-outer-name unused-import +# pylint: disable=redefined-outer-name import json import tempfile from pathlib import Path -# Re-export shared fixtures for pytest discovery -from test.shared_fixtures import ( - reset_config_store, - clean_env, - BOOTSTRAP_ENV_VARS, +# Import constants needed by fixtures in this file +from tests.shared_fixtures import ( DEFAULT_LL_MODEL_CONFIG, TEST_INTEGRATION_DB_USER, TEST_INTEGRATION_DB_PASSWORD, @@ -29,8 +29,15 @@ import pytest -# Alias for backwards compatibility -clean_bootstrap_env = clean_env + +@pytest.fixture +def clean_bootstrap_env(clean_env): + """Alias for clean_env fixture for backwards compatibility. 
+
+    This fixture name is used in existing tests. It delegates to the
+    shared clean_env fixture loaded via pytest_plugins.
+    """
+    yield
 
 
 @pytest.fixture
diff --git a/test/integration/server/bootstrap/test_bootstrap_configfile.py b/tests/integration/server/bootstrap/test_bootstrap_configfile.py
similarity index 100%
rename from test/integration/server/bootstrap/test_bootstrap_configfile.py
rename to tests/integration/server/bootstrap/test_bootstrap_configfile.py
diff --git a/test/integration/server/bootstrap/test_bootstrap_databases.py b/tests/integration/server/bootstrap/test_bootstrap_databases.py
similarity index 99%
rename from test/integration/server/bootstrap/test_bootstrap_databases.py
rename to tests/integration/server/bootstrap/test_bootstrap_databases.py
index 70336db5..3a22b9a3 100644
--- a/test/integration/server/bootstrap/test_bootstrap_databases.py
+++ b/tests/integration/server/bootstrap/test_bootstrap_databases.py
@@ -12,7 +12,7 @@
 
 import os
 
-from test.shared_fixtures import (
+from tests.shared_fixtures import (
     assert_database_list_valid,
     assert_has_default_database,
     get_database_by_name,
diff --git a/test/integration/server/bootstrap/test_bootstrap_models.py b/tests/integration/server/bootstrap/test_bootstrap_models.py
similarity index 99%
rename from test/integration/server/bootstrap/test_bootstrap_models.py
rename to tests/integration/server/bootstrap/test_bootstrap_models.py
index b39042ec..ff094f31 100644
--- a/test/integration/server/bootstrap/test_bootstrap_models.py
+++ b/tests/integration/server/bootstrap/test_bootstrap_models.py
@@ -13,7 +13,7 @@
 import os
 from unittest.mock import patch
 
-from test.shared_fixtures import assert_model_list_valid, get_model_by_id
+from tests.shared_fixtures import assert_model_list_valid, get_model_by_id
 
 import pytest
 
diff --git a/test/integration/server/bootstrap/test_bootstrap_oci.py b/tests/integration/server/bootstrap/test_bootstrap_oci.py
similarity index 100%
rename from test/integration/server/bootstrap/test_bootstrap_oci.py
rename to tests/integration/server/bootstrap/test_bootstrap_oci.py
diff --git a/test/integration/server/bootstrap/test_bootstrap_settings.py b/tests/integration/server/bootstrap/test_bootstrap_settings.py
similarity index 100%
rename from test/integration/server/bootstrap/test_bootstrap_settings.py
rename to tests/integration/server/bootstrap/test_bootstrap_settings.py
diff --git a/test/opentofu/OMRMetaSchema.yaml b/tests/opentofu/OMRMetaSchema.yaml
similarity index 100%
rename from test/opentofu/OMRMetaSchema.yaml
rename to tests/opentofu/OMRMetaSchema.yaml
diff --git a/test/opentofu/validate_omr_schema.py b/tests/opentofu/validate_omr_schema.py
similarity index 100%
rename from test/opentofu/validate_omr_schema.py
rename to tests/opentofu/validate_omr_schema.py
diff --git a/test/shared_fixtures.py b/tests/shared_fixtures.py
similarity index 63%
rename from test/shared_fixtures.py
rename to tests/shared_fixtures.py
index cc4e634b..55170c24 100644
--- a/test/shared_fixtures.py
+++ b/tests/shared_fixtures.py
@@ -4,8 +4,33 @@
 
 Shared pytest fixtures for unit and integration tests.
 
-This module contains common fixture factories and utilities that are shared
-across multiple test conftest files to avoid code duplication.
+This module is loaded via pytest_plugins in tests/conftest.py, making all
+fixtures automatically available to all tests without explicit imports.
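+
+For example, a test can request a factory with no import at all (a sketch;
+`make_database` is one of the plugin-loaded fixtures listed below):
+
+    def test_my_db(make_database):
+        db = make_database(name="MY_DB")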
+ +FIXTURES (auto-loaded via pytest_plugins): + - make_database: Factory for Database objects + - make_model: Factory for Model objects + - make_oci_config: Factory for OracleCloudSettings objects + - make_ll_settings: Factory for LargeLanguageSettings objects + - make_settings: Factory for Settings objects + - make_configuration: Factory for Configuration objects + - temp_config_file: Creates temporary JSON config files + - reset_config_store: Resets ConfigStore singleton state + - clean_env: Clears bootstrap-related environment variables + - sample_vector_store_data: Sample vector store configuration + - sample_vector_store_data_alt: Alternative vector store configuration + - sample_vector_stores_list: List of sample vector stores + +CONSTANTS (require explicit import in test files): + - TEST_DB_USER, TEST_DB_PASSWORD, TEST_DB_DSN, TEST_DB_WALLET_PASSWORD + - TEST_API_KEY, TEST_API_KEY_ALT, TEST_AUTH_TOKEN + - TEST_INTEGRATION_DB_USER, TEST_INTEGRATION_DB_PASSWORD, TEST_INTEGRATION_DB_DSN + - DEFAULT_LL_MODEL_CONFIG, BOOTSTRAP_ENV_VARS + - SAMPLE_VECTOR_STORE_DATA, SAMPLE_VECTOR_STORE_DATA_ALT + +HELPER FUNCTIONS (require explicit import in test files): + - assert_database_list_valid, assert_has_default_database, get_database_by_name + - assert_model_list_valid, get_model_by_id """ # pylint: disable=redefined-outer-name @@ -91,6 +116,22 @@ "OCI_GENAI_SERVICE_ENDPOINT", ] +# API server environment variables +API_SERVER_ENV_VARS = [ + "API_SERVER_KEY", + "API_SERVER_URL", + "API_SERVER_PORT", +] + +# Config file environment variables +CONFIG_ENV_VARS = [ + "CONFIG_FILE", + "OCI_CLI_CONFIG_FILE", +] + +# All test-relevant environment variables (union of all categories) +ALL_TEST_ENV_VARS = list(set(BOOTSTRAP_ENV_VARS + API_SERVER_ENV_VARS + CONFIG_ENV_VARS)) + ################################################# # Schema Factory Fixtures @@ -338,21 +379,140 @@ def get_model_by_id(result, model_id): ################################################# +def _get_dynamic_oci_vars() -> list[str]: + """Get list of OCI_ prefixed environment variables currently set. + + Returns all environment variables starting with OCI_ that aren't + in our static list (catches user-specific OCI vars). + """ + static_oci_vars = {v for v in BOOTSTRAP_ENV_VARS if v.startswith("OCI_")} + return [v for v in os.environ if v.startswith("OCI_") and v not in static_oci_vars] + + @pytest.fixture -def clean_env(): - """Fixture to temporarily clear relevant environment variables.""" - original_values = {} +def clean_env(monkeypatch): + """Fixture to clear bootstrap-related environment variables using monkeypatch. + + Uses pytest's monkeypatch for proper isolation - changes are automatically + reverted after the test completes, even if the test fails. + + This fixture clears: + - Database variables (DB_USERNAME, DB_PASSWORD, etc.) + - Model API keys (OPENAI_API_KEY, COHERE_API_KEY, etc.) + - OCI variables (all OCI_* prefixed vars) + + Usage: + def test_bootstrap_without_env(clean_env): + # Environment is clean, no DB/API/OCI vars set + result = bootstrap.main() + assert result uses defaults + """ + # Clear all known bootstrap vars for var in BOOTSTRAP_ENV_VARS: - original_values[var] = os.environ.pop(var, None) + monkeypatch.delenv(var, raising=False) + + # Clear any dynamic OCI_ vars not in our static list + for var in _get_dynamic_oci_vars(): + monkeypatch.delenv(var, raising=False) + + yield + + +@pytest.fixture +def clean_all_env(monkeypatch): + """Fixture to clear ALL test-related environment variables. 
+ + More aggressive than clean_env - also clears API server and config vars. + Use this when you need complete environment isolation. + + Usage: + def test_with_clean_slate(clean_all_env): + # No test-related env vars are set + pass + """ + for var in ALL_TEST_ENV_VARS: + monkeypatch.delenv(var, raising=False) + + # Clear any dynamic OCI_ vars + for var in _get_dynamic_oci_vars(): + monkeypatch.delenv(var, raising=False) yield - # Restore original values - for var, value in original_values.items(): - if value is not None: - os.environ[var] = value - elif var in os.environ: - del os.environ[var] + +@pytest.fixture +def isolated_env(monkeypatch): + """Fixture providing isolated environment with test defaults. + + Clears all test-related vars and sets safe defaults for test execution. + Use this when tests need a known, controlled environment state. + + Sets: + - CONFIG_FILE: /non/existent/path/config.json (forces empty config) + - OCI_CLI_CONFIG_FILE: /non/existent/path (prevents OCI config pickup) + + Usage: + def test_with_defaults(isolated_env): + # Environment has safe test defaults + pass + """ + # Clear all test-related vars first + for var in ALL_TEST_ENV_VARS: + monkeypatch.delenv(var, raising=False) + + # Clear dynamic OCI vars + for var in _get_dynamic_oci_vars(): + monkeypatch.delenv(var, raising=False) + + # Set safe test defaults + monkeypatch.setenv("CONFIG_FILE", "/non/existent/path/config.json") + monkeypatch.setenv("OCI_CLI_CONFIG_FILE", "/non/existent/path") + + yield monkeypatch # Yield monkeypatch so tests can add more vars if needed + + +def setup_test_env_vars( + monkeypatch, + auth_token: str = None, + server_url: str = "http://localhost", + server_port: int = 8000, + config_file: str = "/non/existent/path/config.json", +) -> None: + """Helper function to set up common test environment variables. + + This is a utility function (not a fixture) that can be called from + fixtures or tests to set up the environment consistently. 
+
+    Args:
+        monkeypatch: pytest monkeypatch fixture
+        auth_token: API server authentication token
+        server_url: API server URL (default: http://localhost)
+        server_port: API server port (default: 8000)
+        config_file: Path to config file (default: non-existent for empty config)
+
+    Usage:
+        @pytest.fixture
+        def my_env(monkeypatch):
+            setup_test_env_vars(monkeypatch, auth_token="my-token", server_port=8015)
+            yield
+    """
+    # Clear existing vars
+    for var in ALL_TEST_ENV_VARS:
+        monkeypatch.delenv(var, raising=False)
+
+    # Clear dynamic OCI vars
+    for var in _get_dynamic_oci_vars():
+        monkeypatch.delenv(var, raising=False)
+
+    # Set config vars
+    monkeypatch.setenv("CONFIG_FILE", config_file)
+    monkeypatch.setenv("OCI_CLI_CONFIG_FILE", "/non/existent/path")
+
+    # Set API server vars if token provided
+    if auth_token:
+        monkeypatch.setenv("API_SERVER_KEY", auth_token)
+        monkeypatch.setenv("API_SERVER_URL", server_url)
+        monkeypatch.setenv("API_SERVER_PORT", str(server_port))


 #################################################
diff --git a/test/unit/__init__.py b/tests/unit/__init__.py
similarity index 100%
rename from test/unit/__init__.py
rename to tests/unit/__init__.py
diff --git a/test/unit/client/__init__.py b/tests/unit/client/__init__.py
similarity index 100%
rename from test/unit/client/__init__.py
rename to tests/unit/client/__init__.py
diff --git a/test/unit/client/conftest.py b/tests/unit/client/conftest.py
similarity index 82%
rename from test/unit/client/conftest.py
rename to tests/unit/client/conftest.py
index e4ffecb9..7ccdef56 100644
--- a/test/unit/client/conftest.py
+++ b/tests/unit/client/conftest.py
@@ -1,23 +1,20 @@
-# pylint: disable=import-error,redefined-outer-name,unused-import
+# pylint: disable=import-error,redefined-outer-name
 """
 Copyright (c) 2024, 2025, Oracle and/or its affiliates.
 Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.

 Unit test fixtures for client tests. Unit tests mock dependencies rather than
 requiring a real server, but some fixtures help establish Streamlit session state.
+
+Note: Shared fixtures (sample_vector_store_data, sample_vector_store_data_alt,
+sample_vector_stores_list, make_database, make_model, etc.) are automatically
+available via pytest_plugins in tests/conftest.py.
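+
+A usage sketch (hypothetical test; assumes the sample list is non-empty):
+
+    def test_vector_store_samples(sample_vector_stores_list):
+        assert len(sample_vector_stores_list) >= 1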
""" # spell-checker: disable import os import sys -# Re-export shared vector store fixtures for pytest discovery -from test.shared_fixtures import ( # noqa: F401 - sample_vector_store_data, - sample_vector_store_data_alt, - sample_vector_stores_list, -) - import pytest from streamlit import session_state as state diff --git a/test/unit/client/content/__init__.py b/tests/unit/client/content/__init__.py similarity index 100% rename from test/unit/client/content/__init__.py rename to tests/unit/client/content/__init__.py diff --git a/test/unit/client/content/config/__init__.py b/tests/unit/client/content/config/__init__.py similarity index 100% rename from test/unit/client/content/config/__init__.py rename to tests/unit/client/content/config/__init__.py diff --git a/test/unit/client/content/config/tabs/__init__.py b/tests/unit/client/content/config/tabs/__init__.py similarity index 100% rename from test/unit/client/content/config/tabs/__init__.py rename to tests/unit/client/content/config/tabs/__init__.py diff --git a/test/unit/client/content/config/tabs/test_mcp_unit.py b/tests/unit/client/content/config/tabs/test_mcp_unit.py similarity index 100% rename from test/unit/client/content/config/tabs/test_mcp_unit.py rename to tests/unit/client/content/config/tabs/test_mcp_unit.py diff --git a/test/unit/client/content/config/tabs/test_models_unit.py b/tests/unit/client/content/config/tabs/test_models_unit.py similarity index 100% rename from test/unit/client/content/config/tabs/test_models_unit.py rename to tests/unit/client/content/config/tabs/test_models_unit.py diff --git a/test/unit/client/content/test_chatbot_unit.py b/tests/unit/client/content/test_chatbot_unit.py similarity index 100% rename from test/unit/client/content/test_chatbot_unit.py rename to tests/unit/client/content/test_chatbot_unit.py diff --git a/test/unit/client/content/test_testbed_unit.py b/tests/unit/client/content/test_testbed_unit.py similarity index 100% rename from test/unit/client/content/test_testbed_unit.py rename to tests/unit/client/content/test_testbed_unit.py diff --git a/test/unit/client/content/tools/__init__.py b/tests/unit/client/content/tools/__init__.py similarity index 100% rename from test/unit/client/content/tools/__init__.py rename to tests/unit/client/content/tools/__init__.py diff --git a/test/unit/client/content/tools/tabs/__init__.py b/tests/unit/client/content/tools/tabs/__init__.py similarity index 100% rename from test/unit/client/content/tools/tabs/__init__.py rename to tests/unit/client/content/tools/tabs/__init__.py diff --git a/test/unit/client/content/tools/tabs/test_split_embed_unit.py b/tests/unit/client/content/tools/tabs/test_split_embed_unit.py similarity index 100% rename from test/unit/client/content/tools/tabs/test_split_embed_unit.py rename to tests/unit/client/content/tools/tabs/test_split_embed_unit.py diff --git a/test/unit/client/utils/__init__.py b/tests/unit/client/utils/__init__.py similarity index 100% rename from test/unit/client/utils/__init__.py rename to tests/unit/client/utils/__init__.py diff --git a/test/unit/client/utils/test_client_unit.py b/tests/unit/client/utils/test_client_unit.py similarity index 100% rename from test/unit/client/utils/test_client_unit.py rename to tests/unit/client/utils/test_client_unit.py diff --git a/test/unit/client/utils/test_st_common_unit.py b/tests/unit/client/utils/test_st_common_unit.py similarity index 100% rename from test/unit/client/utils/test_st_common_unit.py rename to tests/unit/client/utils/test_st_common_unit.py diff --git 
a/test/unit/client/utils/test_vs_options_unit.py b/tests/unit/client/utils/test_vs_options_unit.py
similarity index 100%
rename from test/unit/client/utils/test_vs_options_unit.py
rename to tests/unit/client/utils/test_vs_options_unit.py
diff --git a/test/unit/common/__init__.py b/tests/unit/common/__init__.py
similarity index 100%
rename from test/unit/common/__init__.py
rename to tests/unit/common/__init__.py
diff --git a/test/unit/common/test_functions.py b/tests/unit/common/test_functions.py
similarity index 100%
rename from test/unit/common/test_functions.py
rename to tests/unit/common/test_functions.py
diff --git a/test/unit/common/test_help_text.py b/tests/unit/common/test_help_text.py
similarity index 100%
rename from test/unit/common/test_help_text.py
rename to tests/unit/common/test_help_text.py
diff --git a/test/unit/common/test_logging_config.py b/tests/unit/common/test_logging_config.py
similarity index 100%
rename from test/unit/common/test_logging_config.py
rename to tests/unit/common/test_logging_config.py
diff --git a/test/unit/common/test_schema.py b/tests/unit/common/test_schema.py
similarity index 100%
rename from test/unit/common/test_schema.py
rename to tests/unit/common/test_schema.py
diff --git a/test/unit/common/test_version.py b/tests/unit/common/test_version.py
similarity index 100%
rename from test/unit/common/test_version.py
rename to tests/unit/common/test_version.py
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
new file mode 100644
index 00000000..3271e5c3
--- /dev/null
+++ b/tests/unit/conftest.py
@@ -0,0 +1,23 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+
+Pytest configuration for unit tests.
+
+This conftest automatically marks all tests in the tests/unit/ directory
+with the 'unit' marker, enabling selective test execution:
+
+    pytest -m "unit"               # Run only unit tests
+    pytest -m "not unit"           # Skip unit tests
+    pytest -m "unit and not slow"  # Fast unit tests only
+"""
+
+import pytest
+
+
+def pytest_collection_modifyitems(config, items):
+    """Automatically add 'unit' marker to all tests in this directory."""
+    for item in items:
+        # Check if the test is under tests/unit/
+        if "/tests/unit/" in str(item.fspath):
+            item.add_marker(pytest.mark.unit)
diff --git a/test/unit/server/__init__.py b/tests/unit/server/__init__.py
similarity index 100%
rename from test/unit/server/__init__.py
rename to tests/unit/server/__init__.py
diff --git a/test/unit/server/api/__init__.py b/tests/unit/server/api/__init__.py
similarity index 100%
rename from test/unit/server/api/__init__.py
rename to tests/unit/server/api/__init__.py
diff --git a/test/unit/server/api/conftest.py b/tests/unit/server/api/conftest.py
similarity index 89%
rename from test/unit/server/api/conftest.py
rename to tests/unit/server/api/conftest.py
index 858dca33..a640aa6c 100644
--- a/test/unit/server/api/conftest.py
+++ b/tests/unit/server/api/conftest.py
@@ -4,30 +4,23 @@

 Pytest fixtures for server/api unit tests. Provides factory fixtures for
 creating test objects.
+
+Note: Shared fixtures (make_database, make_model, etc.) are automatically
+available via pytest_plugins in tests/conftest.py. Only import constants
+and helper functions that are needed in this file.
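+
+A sketch of the convention (illustrative test name) - constants are imported
+explicitly, while fixtures are injected by pytest at collection time:
+
+    from tests.shared_fixtures import TEST_DB_USER  # constant: explicit import
+
+    def test_database_defaults(make_database):      # fixture: injected by name
+        assert make_database().user == TEST_DB_USER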
""" -# pylint: disable=redefined-outer-name unused-import -# Pytest fixtures use parameter injection where fixture names match parameters +# pylint: disable=redefined-outer-name from unittest.mock import MagicMock, AsyncMock -# Re-export shared fixtures for pytest discovery (before third-party imports per pylint) -from test.shared_fixtures import ( - make_database, - make_model, - make_oci_config, - make_ll_settings, - make_settings, - make_configuration, +# Import constants needed by fixtures in this file +from tests.shared_fixtures import ( TEST_DB_USER, TEST_DB_PASSWORD, TEST_DB_DSN, ) - -# Import TEST_DB_CONFIG for use in this file -# Note: db_container, db_connection, db_transaction fixtures are inherited -# from test/conftest.py - do not re-import here to avoid multiple containers -from test.db_fixtures import TEST_DB_CONFIG +from tests.db_fixtures import TEST_DB_CONFIG import pytest diff --git a/test/unit/server/api/utils/__init__.py b/tests/unit/server/api/utils/__init__.py similarity index 100% rename from test/unit/server/api/utils/__init__.py rename to tests/unit/server/api/utils/__init__.py diff --git a/test/unit/server/api/utils/test_utils_chat.py b/tests/unit/server/api/utils/test_utils_chat.py similarity index 100% rename from test/unit/server/api/utils/test_utils_chat.py rename to tests/unit/server/api/utils/test_utils_chat.py diff --git a/test/unit/server/api/utils/test_utils_databases.py b/tests/unit/server/api/utils/test_utils_databases.py similarity index 98% rename from test/unit/server/api/utils/test_utils_databases.py rename to tests/unit/server/api/utils/test_utils_databases.py index ac9b99c9..c340546d 100644 --- a/test/unit/server/api/utils/test_utils_databases.py +++ b/tests/unit/server/api/utils/test_utils_databases.py @@ -12,8 +12,8 @@ # pylint: disable=too-few-public-methods -from test.conftest import TEST_CONFIG -from test.shared_fixtures import TEST_DB_WALLET_PASSWORD +from tests.db_fixtures import TEST_DB_CONFIG +from tests.shared_fixtures import TEST_DB_WALLET_PASSWORD from unittest.mock import patch, MagicMock import pytest @@ -153,9 +153,9 @@ def test_connect_success_real_db(self, db_container, make_database): """connect should return connection on success (real database).""" # pylint: disable=unused-argument config = make_database( - user=TEST_CONFIG["db_username"], - password=TEST_CONFIG["db_password"], - dsn=TEST_CONFIG["db_dsn"], + user=TEST_DB_CONFIG["db_username"], + password=TEST_DB_CONFIG["db_password"], + dsn=TEST_DB_CONFIG["db_dsn"], ) result = utils_databases.connect(config) @@ -180,7 +180,7 @@ def test_connect_raises_permission_error_invalid_credentials(self, db_container, config = make_database( user="INVALID_USER", password=TEST_DB_WALLET_PASSWORD, # Using a fake password for invalid login test - dsn=TEST_CONFIG["db_dsn"], + dsn=TEST_DB_CONFIG["db_dsn"], ) with pytest.raises(PermissionError): @@ -194,8 +194,8 @@ def test_connect_raises_connection_error_invalid_dsn(self, db_container, make_da """ # pylint: disable=unused-argument config = make_database( - user=TEST_CONFIG["db_username"], - password=TEST_CONFIG["db_password"], + user=TEST_DB_CONFIG["db_username"], + password=TEST_DB_CONFIG["db_password"], dsn="//localhost:1525/NONEXISTENT_SERVICE", ) diff --git a/test/unit/server/api/utils/test_utils_embed.py b/tests/unit/server/api/utils/test_utils_embed.py similarity index 100% rename from test/unit/server/api/utils/test_utils_embed.py rename to tests/unit/server/api/utils/test_utils_embed.py diff --git 
a/test/unit/server/api/utils/test_utils_mcp.py b/tests/unit/server/api/utils/test_utils_mcp.py similarity index 99% rename from test/unit/server/api/utils/test_utils_mcp.py rename to tests/unit/server/api/utils/test_utils_mcp.py index 4fb234c1..34bd6a53 100644 --- a/test/unit/server/api/utils/test_utils_mcp.py +++ b/tests/unit/server/api/utils/test_utils_mcp.py @@ -7,7 +7,7 @@ """ import os -from test.shared_fixtures import TEST_API_KEY, TEST_API_KEY_ALT +from tests.shared_fixtures import TEST_API_KEY, TEST_API_KEY_ALT from unittest.mock import AsyncMock, MagicMock, patch import pytest diff --git a/test/unit/server/api/utils/test_utils_models.py b/tests/unit/server/api/utils/test_utils_models.py similarity index 100% rename from test/unit/server/api/utils/test_utils_models.py rename to tests/unit/server/api/utils/test_utils_models.py diff --git a/test/unit/server/api/utils/test_utils_oci.py b/tests/unit/server/api/utils/test_utils_oci.py similarity index 100% rename from test/unit/server/api/utils/test_utils_oci.py rename to tests/unit/server/api/utils/test_utils_oci.py diff --git a/test/unit/server/api/utils/test_utils_settings.py b/tests/unit/server/api/utils/test_utils_settings.py similarity index 100% rename from test/unit/server/api/utils/test_utils_settings.py rename to tests/unit/server/api/utils/test_utils_settings.py diff --git a/test/unit/server/api/utils/test_utils_testbed.py b/tests/unit/server/api/utils/test_utils_testbed.py similarity index 100% rename from test/unit/server/api/utils/test_utils_testbed.py rename to tests/unit/server/api/utils/test_utils_testbed.py diff --git a/test/unit/server/api/utils/test_utils_testbed_metrics.py b/tests/unit/server/api/utils/test_utils_testbed_metrics.py similarity index 100% rename from test/unit/server/api/utils/test_utils_testbed_metrics.py rename to tests/unit/server/api/utils/test_utils_testbed_metrics.py diff --git a/test/unit/server/api/utils/test_utils_webscrape.py b/tests/unit/server/api/utils/test_utils_webscrape.py similarity index 100% rename from test/unit/server/api/utils/test_utils_webscrape.py rename to tests/unit/server/api/utils/test_utils_webscrape.py diff --git a/test/unit/server/api/v1/__init__.py b/tests/unit/server/api/v1/__init__.py similarity index 100% rename from test/unit/server/api/v1/__init__.py rename to tests/unit/server/api/v1/__init__.py diff --git a/test/unit/server/api/v1/test_v1_chat.py b/tests/unit/server/api/v1/test_v1_chat.py similarity index 100% rename from test/unit/server/api/v1/test_v1_chat.py rename to tests/unit/server/api/v1/test_v1_chat.py diff --git a/test/unit/server/api/v1/test_v1_databases.py b/tests/unit/server/api/v1/test_v1_databases.py similarity index 100% rename from test/unit/server/api/v1/test_v1_databases.py rename to tests/unit/server/api/v1/test_v1_databases.py diff --git a/test/unit/server/api/v1/test_v1_embed.py b/tests/unit/server/api/v1/test_v1_embed.py similarity index 100% rename from test/unit/server/api/v1/test_v1_embed.py rename to tests/unit/server/api/v1/test_v1_embed.py diff --git a/test/unit/server/api/v1/test_v1_mcp.py b/tests/unit/server/api/v1/test_v1_mcp.py similarity index 99% rename from test/unit/server/api/v1/test_v1_mcp.py rename to tests/unit/server/api/v1/test_v1_mcp.py index 516d76cc..7e55b6ab 100644 --- a/test/unit/server/api/v1/test_v1_mcp.py +++ b/tests/unit/server/api/v1/test_v1_mcp.py @@ -8,7 +8,7 @@ # pylint: disable=too-few-public-methods -from test.shared_fixtures import TEST_API_KEY +from tests.shared_fixtures import TEST_API_KEY from 
unittest.mock import AsyncMock, MagicMock, patch

 import pytest
diff --git a/test/unit/server/api/v1/test_v1_mcp_prompts.py b/tests/unit/server/api/v1/test_v1_mcp_prompts.py
similarity index 100%
rename from test/unit/server/api/v1/test_v1_mcp_prompts.py
rename to tests/unit/server/api/v1/test_v1_mcp_prompts.py
diff --git a/test/unit/server/api/v1/test_v1_models.py b/tests/unit/server/api/v1/test_v1_models.py
similarity index 100%
rename from test/unit/server/api/v1/test_v1_models.py
rename to tests/unit/server/api/v1/test_v1_models.py
diff --git a/test/unit/server/api/v1/test_v1_oci.py b/tests/unit/server/api/v1/test_v1_oci.py
similarity index 100%
rename from test/unit/server/api/v1/test_v1_oci.py
rename to tests/unit/server/api/v1/test_v1_oci.py
diff --git a/test/unit/server/api/v1/test_v1_probes.py b/tests/unit/server/api/v1/test_v1_probes.py
similarity index 100%
rename from test/unit/server/api/v1/test_v1_probes.py
rename to tests/unit/server/api/v1/test_v1_probes.py
diff --git a/test/unit/server/api/v1/test_v1_settings.py b/tests/unit/server/api/v1/test_v1_settings.py
similarity index 100%
rename from test/unit/server/api/v1/test_v1_settings.py
rename to tests/unit/server/api/v1/test_v1_settings.py
diff --git a/test/unit/server/api/v1/test_v1_testbed.py b/tests/unit/server/api/v1/test_v1_testbed.py
similarity index 100%
rename from test/unit/server/api/v1/test_v1_testbed.py
rename to tests/unit/server/api/v1/test_v1_testbed.py
diff --git a/test/unit/server/bootstrap/__init__.py b/tests/unit/server/bootstrap/__init__.py
similarity index 100%
rename from test/unit/server/bootstrap/__init__.py
rename to tests/unit/server/bootstrap/__init__.py
diff --git a/test/unit/server/bootstrap/conftest.py b/tests/unit/server/bootstrap/conftest.py
similarity index 74%
rename from test/unit/server/bootstrap/conftest.py
rename to tests/unit/server/bootstrap/conftest.py
index 9bc23743..b7c4a9a1 100644
--- a/test/unit/server/bootstrap/conftest.py
+++ b/tests/unit/server/bootstrap/conftest.py
@@ -4,26 +4,15 @@

 Pytest fixtures for server/bootstrap unit tests.

-Re-exports shared fixtures from test.shared_fixtures and adds unit-test specific fixtures.
+Note: Shared fixtures (make_database, make_model, make_oci_config, make_ll_settings,
+make_settings, make_configuration, temp_config_file, reset_config_store, clean_env)
+are automatically available via pytest_plugins in tests/conftest.py.
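+
+A usage sketch (hypothetical test) - the factory fixture arrives by name
+injection, no import needed:
+
+    def test_model_defaults(make_model):
+        assert make_model().provider == "openai"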
""" -# pylint: disable=redefined-outer-name unused-import +# pylint: disable=redefined-outer-name from unittest.mock import MagicMock, patch -# Re-export shared fixtures for pytest discovery -from test.shared_fixtures import ( - make_database, - make_model, - make_oci_config, - make_ll_settings, - make_settings, - make_configuration, - temp_config_file, - reset_config_store, - clean_env, -) - import pytest diff --git a/test/unit/server/bootstrap/test_bootstrap_bootstrap.py b/tests/unit/server/bootstrap/test_bootstrap_bootstrap.py similarity index 100% rename from test/unit/server/bootstrap/test_bootstrap_bootstrap.py rename to tests/unit/server/bootstrap/test_bootstrap_bootstrap.py diff --git a/test/unit/server/bootstrap/test_bootstrap_configfile.py b/tests/unit/server/bootstrap/test_bootstrap_configfile.py similarity index 100% rename from test/unit/server/bootstrap/test_bootstrap_configfile.py rename to tests/unit/server/bootstrap/test_bootstrap_configfile.py diff --git a/test/unit/server/bootstrap/test_bootstrap_databases.py b/tests/unit/server/bootstrap/test_bootstrap_databases.py similarity index 99% rename from test/unit/server/bootstrap/test_bootstrap_databases.py rename to tests/unit/server/bootstrap/test_bootstrap_databases.py index 3a658689..195a9e75 100644 --- a/test/unit/server/bootstrap/test_bootstrap_databases.py +++ b/tests/unit/server/bootstrap/test_bootstrap_databases.py @@ -10,7 +10,7 @@ import os -from test.shared_fixtures import ( +from tests.shared_fixtures import ( assert_database_list_valid, assert_has_default_database, get_database_by_name, diff --git a/test/unit/server/bootstrap/test_bootstrap_models.py b/tests/unit/server/bootstrap/test_bootstrap_models.py similarity index 99% rename from test/unit/server/bootstrap/test_bootstrap_models.py rename to tests/unit/server/bootstrap/test_bootstrap_models.py index 4950b09c..6635a27b 100644 --- a/test/unit/server/bootstrap/test_bootstrap_models.py +++ b/tests/unit/server/bootstrap/test_bootstrap_models.py @@ -11,7 +11,7 @@ import os from unittest.mock import patch -from test.shared_fixtures import assert_model_list_valid, get_model_by_id, TEST_API_KEY +from tests.shared_fixtures import assert_model_list_valid, get_model_by_id, TEST_API_KEY import pytest diff --git a/test/unit/server/bootstrap/test_bootstrap_oci.py b/tests/unit/server/bootstrap/test_bootstrap_oci.py similarity index 100% rename from test/unit/server/bootstrap/test_bootstrap_oci.py rename to tests/unit/server/bootstrap/test_bootstrap_oci.py diff --git a/test/unit/server/bootstrap/test_bootstrap_settings.py b/tests/unit/server/bootstrap/test_bootstrap_settings.py similarity index 100% rename from test/unit/server/bootstrap/test_bootstrap_settings.py rename to tests/unit/server/bootstrap/test_bootstrap_settings.py From 7ed1d98964c8826bdab82a48c74d6a71cfd5dbcd Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Sun, 30 Nov 2025 15:12:42 +0000 Subject: [PATCH 11/20] Remove patches --- pyproject.toml | 4 +- src/common/schema.py | 20 +-- src/launch_server.py | 3 - src/server/api/utils/testbed.py | 20 +-- src/server/api/v1/testbed.py | 34 ++--- src/server/patches/__init__.py | 0 src/server/patches/litellm_patch.py | 40 ------ .../patches/litellm_patch_oci_streaming.py | 132 ------------------ src/server/patches/litellm_patch_transform.py | 80 ----------- tests/db_fixtures.py | 6 +- tests/integration/client/conftest.py | 58 +++----- .../content/config/tabs/test_databases.py | 3 +- .../client/content/config/tabs/test_models.py | 2 +- 
.../client/content/config/test_config.py | 3 +- .../client/content/test_chatbot.py | 2 +- .../client/content/test_testbed.py | 2 +- .../content/tools/tabs/test_split_embed.py | 6 +- .../client/content/tools/test_tools.py | 2 +- tests/integration/common/test_functions.py | 3 +- tests/integration/conftest.py | 2 +- tests/integration/server/api/conftest.py | 59 +++----- .../integration/server/api/v1/test_testbed.py | 4 +- .../integration/server/bootstrap/conftest.py | 7 +- .../bootstrap/test_bootstrap_databases.py | 7 +- .../server/bootstrap/test_bootstrap_models.py | 3 +- tests/shared_fixtures.py | 78 +++++++++++ tests/unit/common/test_schema.py | 26 ++-- tests/unit/conftest.py | 2 +- tests/unit/server/api/conftest.py | 14 +- .../server/api/utils/test_utils_databases.py | 4 +- tests/unit/server/api/utils/test_utils_mcp.py | 2 +- .../server/api/utils/test_utils_webscrape.py | 2 +- tests/unit/server/api/v1/test_v1_embed.py | 2 +- tests/unit/server/api/v1/test_v1_mcp.py | 2 +- tests/unit/server/api/v1/test_v1_testbed.py | 12 +- .../bootstrap/test_bootstrap_databases.py | 7 +- .../server/bootstrap/test_bootstrap_models.py | 3 +- 37 files changed, 218 insertions(+), 438 deletions(-) delete mode 100644 src/server/patches/__init__.py delete mode 100644 src/server/patches/litellm_patch.py delete mode 100644 src/server/patches/litellm_patch_oci_streaming.py delete mode 100644 src/server/patches/litellm_patch_transform.py diff --git a/pyproject.toml b/pyproject.toml index 922ffede..2e4f5c33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ server = [ "langchain-aimlapi==0.1.0", "langchain-cohere==0.4.6", "langchain-community==0.3.31", - "langchain-fireworks==0.3.0", + #"langchain-fireworks==0.3.0", "langchain-google-genai==2.1.12", "langchain-ibm==0.3.20", "langchain-mcp-adapters==0.1.13", @@ -43,7 +43,7 @@ server = [ "langchain-openai==0.3.35", "langchain-together==0.3.1", "langgraph==1.0.1", - "litellm==1.80.0", + "litellm==1.80.7", "llama-index==0.14.8", "lxml==6.0.2", "matplotlib==3.10.7", diff --git a/src/common/schema.py b/src/common/schema.py index 057487a0..8d761c09 100644 --- a/src/common/schema.py +++ b/src/common/schema.py @@ -346,18 +346,18 @@ class ChatRequest(LanguageModelParameters): ##################################################### # Testbed ##################################################### -class TestSets(BaseModel): - """TestSets""" +class QASets(BaseModel): + """QA Sets - Collection of Q&A test sets for testbed evaluation""" tid: str = Field(description="Test ID") - name: str = Field(description="Name of TestSet") - created: str = Field(description="Date TestSet Loaded") + name: str = Field(description="Name of QA Set") + created: str = Field(description="Date QA Set Loaded") -class TestSetQA(BaseModel): - """TestSet Q&A""" +class QASetData(BaseModel): + """QA Set Data - Question/Answer pairs for testbed evaluation""" - qa_data: list = Field(description="TestSet Q&A Data") + qa_data: list = Field(description="QA Set Data") class Evaluation(BaseModel): @@ -390,6 +390,6 @@ class EvaluationReport(Evaluation): ModelEnabledType = ModelAccess.__annotations__["enabled"] OCIProfileType = OracleCloudSettings.__annotations__["auth_profile"] OCIResourceOCID = OracleResource.__annotations__["ocid"] -TestSetsIdType = TestSets.__annotations__["tid"] -TestSetsNameType = TestSets.__annotations__["name"] -TestSetDateType = TestSets.__annotations__["created"] +QASetsIdType = QASets.__annotations__["tid"] +QASetsNameType = QASets.__annotations__["name"] +QASetsDateType = 
QASets.__annotations__["created"] diff --git a/src/launch_server.py b/src/launch_server.py index 5f2c2453..5bb0da31 100644 --- a/src/launch_server.py +++ b/src/launch_server.py @@ -5,9 +5,6 @@ # spell-checker:ignore configfile fastmcp noauth getpid procs litellm giskard ollama # spell-checker:ignore dotenv apiserver laddr -# Patch litellm for Giskard/Ollama issue -import server.patches.litellm_patch # pylint: disable=unused-import, wrong-import-order - # Set OS Environment before importing other modules # Set OS Environment (Don't move their position to reflect on imports) # pylint: disable=wrong-import-position diff --git a/src/server/api/utils/testbed.py b/src/server/api/utils/testbed.py index 21a790d3..dc3d252e 100644 --- a/src/server/api/utils/testbed.py +++ b/src/server/api/utils/testbed.py @@ -95,14 +95,14 @@ def get_testsets(db_conn: Connection) -> list: sql = "SELECT tid, name, to_char(created) FROM oai_testsets ORDER BY created" results = utils_databases.execute_sql(db_conn, sql) try: - testsets = [schema.TestSets(tid=tid.hex(), name=name, created=created) for tid, name, created in results] + testsets = [schema.QASets(tid=tid.hex(), name=name, created=created) for tid, name, created in results] except TypeError: create_testset_objects(db_conn) return testsets -def get_testset_qa(db_conn: Connection, tid: schema.TestSetsIdType) -> schema.TestSetQA: +def get_testset_qa(db_conn: Connection, tid: schema.QASetsIdType) -> schema.QASetData: """Get list of TestSet Q&A""" logger.info("Getting TestSet Q&A for TID: %s", tid) binds = {"tid": tid} @@ -110,10 +110,10 @@ def get_testset_qa(db_conn: Connection, tid: schema.TestSetsIdType) -> schema.Te results = utils_databases.execute_sql(db_conn, sql, binds) qa_data = [qa_data[0] for qa_data in results] - return schema.TestSetQA(qa_data=qa_data) + return schema.QASetData(qa_data=qa_data) -def get_evaluations(db_conn: Connection, tid: schema.TestSetsIdType) -> list[schema.Evaluation]: +def get_evaluations(db_conn: Connection, tid: schema.QASetsIdType) -> list[schema.Evaluation]: """Get list of Evaluations for a TID""" logger.info("Getting Evaluations for: %s", tid) evaluations = [] @@ -133,7 +133,7 @@ def get_evaluations(db_conn: Connection, tid: schema.TestSetsIdType) -> list[sch def delete_qa( db_conn: Connection, - tid: schema.TestSetsIdType, + tid: schema.QASetsIdType, ) -> None: """Delete Q&A""" binds = {"tid": tid} @@ -144,11 +144,11 @@ def delete_qa( def upsert_qa( db_conn: Connection, - name: schema.TestSetsNameType, - created: schema.TestSetDateType, + name: schema.QASetsNameType, + created: schema.QASetsDateType, json_data: json, - tid: schema.TestSetsIdType = None, -) -> schema.TestSetsIdType: + tid: schema.QASetsIdType = None, +) -> schema.QASetsIdType: """Upsert Q&A""" logger.info("Upsert TestSet: %s - %s", name, created) parsed_data = json.loads(json_data) @@ -270,7 +270,7 @@ def build_knowledge_base( return testset -def process_report(db_conn: Connection, eid: schema.TestSetsIdType) -> schema.EvaluationReport: +def process_report(db_conn: Connection, eid: schema.QASetsIdType) -> schema.EvaluationReport: """Process an evaluate report""" # Main diff --git a/src/server/api/v1/testbed.py b/src/server/api/v1/testbed.py index 6705b149..035f3a0b 100644 --- a/src/server/api/v1/testbed.py +++ b/src/server/api/v1/testbed.py @@ -40,11 +40,11 @@ @auth.get( "/testsets", description="Get Stored TestSets.", - response_model=list[schema.TestSets], + response_model=list[schema.QASets], ) async def testbed_testsets( client: schema.ClientIdType = 
Header(default="server"), -) -> list[schema.TestSets]: +) -> list[schema.QASets]: """Get a list of stored TestSets, create TestSet objects if they don't exist""" testsets = utils_testbed.get_testsets(db_conn=utils_databases.get_client_database(client).connection) return testsets @@ -56,7 +56,7 @@ async def testbed_testsets( response_model=list[schema.Evaluation], ) async def testbed_evaluations( - tid: schema.TestSetsIdType, + tid: schema.QASetsIdType, client: schema.ClientIdType = Header(default="server"), ) -> list[schema.Evaluation]: """Get Evaluations""" @@ -72,7 +72,7 @@ async def testbed_evaluations( response_model=schema.EvaluationReport, ) async def testbed_evaluation( - eid: schema.TestSetsIdType, + eid: schema.QASetsIdType, client: schema.ClientIdType = Header(default="server"), ) -> schema.EvaluationReport: """Get Evaluations""" @@ -84,13 +84,13 @@ async def testbed_evaluation( @auth.get( "/testset_qa", - description="Get Stored schema.TestSets Q&A.", - response_model=schema.TestSetQA, + description="Get Stored Testbed Q&A.", + response_model=schema.QASetData, ) async def testbed_testset_qa( - tid: schema.TestSetsIdType, + tid: schema.QASetsIdType, client: schema.ClientIdType = Header(default="server"), -) -> schema.TestSetQA: +) -> schema.QASetData: """Get TestSet Q&A""" return utils_testbed.get_testset_qa( db_conn=utils_databases.get_client_database(client).connection, tid=tid.upper() @@ -102,7 +102,7 @@ async def testbed_testset_qa( description="Delete a TestSet", ) async def testbed_delete_testset( - tid: Optional[schema.TestSetsIdType] = None, + tid: Optional[schema.QASetsIdType] = None, client: schema.ClientIdType = Header(default="server"), ) -> JSONResponse: """Delete TestSet""" @@ -113,14 +113,14 @@ async def testbed_delete_testset( @auth.post( "/testset_load", description="Upsert TestSets.", - response_model=schema.TestSetQA, + response_model=schema.QASetData, ) async def testbed_upsert_testsets( files: list[UploadFile], - name: schema.TestSetsNameType, - tid: Optional[schema.TestSetsIdType] = None, + name: schema.QASetsNameType, + tid: Optional[schema.QASetsIdType] = None, client: schema.ClientIdType = Header(default="server"), -) -> schema.TestSetQA: +) -> schema.QASetData: """Update stored TestSet data""" created = datetime.now().isoformat() db_conn = utils_databases.get_client_database(client).connection @@ -194,16 +194,16 @@ def _handle_testset_error(ex: Exception, temp_directory, ll_model: str): @auth.post( "/testset_generate", description="Generate Q&A Test Set.", - response_model=schema.TestSetQA, + response_model=schema.QASetData, ) async def testbed_generate_qa( files: list[UploadFile], - name: schema.TestSetsNameType, + name: schema.QASetsNameType, ll_model: str, embed_model: str, questions: int = 2, client: schema.ClientIdType = Header(default="server"), -) -> schema.TestSetQA: +) -> schema.QASetData: """Retrieve contents from a local file uploaded and generate Q&A""" # Get the Model Configuration try: @@ -249,7 +249,7 @@ async def _collect_testbed_answers(loaded_testset: QATestset, client: str) -> li response_model=schema.EvaluationReport, ) async def testbed_evaluate( - tid: schema.TestSetsIdType, + tid: schema.QASetsIdType, judge: str, client: schema.ClientIdType = Header(default="server"), ) -> schema.EvaluationReport: diff --git a/src/server/patches/__init__.py b/src/server/patches/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/server/patches/litellm_patch.py b/src/server/patches/litellm_patch.py deleted file mode 100644 
index d8736e9e..00000000 --- a/src/server/patches/litellm_patch.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -LiteLLM Patch Orchestrator -========================== -This module serves as the entry point for all litellm patches. -It imports and applies patches from specialized modules: - -- litellm_patch_transform: Ollama transform_response patch for non-streaming responses -- litellm_patch_oci_auth: OCI authentication patches (instance principals, request signing) -- litellm_patch_oci_streaming: OCI streaming patches (tool call field fixes) - -All patches use guard checks to prevent double-patching. -""" -# spell-checker:ignore litellm - -from common import logging_config - -logger = logging_config.logging.getLogger("patches.litellm_patch") - -logger.info("Loading litellm patches...") - -# Import patch modules - they apply patches on import -# pylint: disable=unused-import -try: - from . import litellm_patch_transform - - logger.info("✓ Ollama transform_response patch loaded") -except Exception as e: - logger.error("✗ Failed to load Ollama transform patch: %s", e) - -try: - from . import litellm_patch_oci_streaming - - logger.info("✓ OCI streaming patches loaded (handle_generic_stream_chunk)") -except Exception as e: - logger.error("✗ Failed to load OCI streaming patches: %s", e) - -logger.info("All litellm patches loaded successfully") diff --git a/src/server/patches/litellm_patch_oci_streaming.py b/src/server/patches/litellm_patch_oci_streaming.py deleted file mode 100644 index bf3db3eb..00000000 --- a/src/server/patches/litellm_patch_oci_streaming.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -OCI Streaming Patches -===================== -Patches for OCI GenAI service streaming responses with tool calls. - -Issue: OCI API returns tool calls without 'arguments' field, causing Pydantic validation error -Error: ValidationError: 1 validation error for OCIStreamChunk message.toolCalls.0.arguments Field required - -This happens when OCI models (e.g., meta.llama-3.1-405b-instruct) attempt tool calling but return -incomplete tool call structures missing the required 'arguments' field during streaming. - -This module patches OCIStreamWrapper._handle_generic_stream_chunk to add missing required fields -with empty defaults before Pydantic validation. 
-""" -# spell-checker:ignore litellm giskard ollama llms -# pylint: disable=unused-argument,protected-access - -from common import logging_config - -logger = logging_config.logging.getLogger("patches.litellm_patch_oci_streaming") - -# Patch OCI _handle_generic_stream_chunk to add missing 'arguments' field in tool calls -try: - from litellm.llms.oci.chat.transformation import OCIStreamWrapper - - original_handle_generic_stream_chunk = getattr(OCIStreamWrapper, "_handle_generic_stream_chunk", None) -except ImportError: - original_handle_generic_stream_chunk = None - -if original_handle_generic_stream_chunk and not getattr( - original_handle_generic_stream_chunk, "_is_custom_patch", False -): - from litellm.llms.oci.chat.transformation import ( - OCIStreamChunk, - OCITextContentPart, - OCIImageContentPart, - adapt_tools_to_openai_standard, - ) - from litellm.types.utils import ModelResponseStream, StreamingChoices, Delta - - def _fix_missing_tool_call_fields(tool_call: dict) -> list: - """Add missing required fields to tool call and return list of missing fields""" - missing_fields = [] - if "arguments" not in tool_call: - tool_call["arguments"] = "" - missing_fields.append("arguments") - if "id" not in tool_call: - tool_call["id"] = "" - missing_fields.append("id") - if "name" not in tool_call: - tool_call["name"] = "" - missing_fields.append("name") - return missing_fields - - def _patch_tool_calls(dict_chunk: dict) -> None: - """Fix missing required fields in tool calls before Pydantic validation""" - if dict_chunk.get("message") and dict_chunk["message"].get("toolCalls"): - for tool_call in dict_chunk["message"]["toolCalls"]: - missing_fields = _fix_missing_tool_call_fields(tool_call) - if missing_fields: - logger.debug( - "OCI tool call streaming chunk missing fields: %s (Type: %s) - adding empty defaults", - missing_fields, - tool_call.get("type", "unknown"), - ) - - def _extract_text_content(typed_chunk: OCIStreamChunk) -> str: - """Extract text content from chunk message""" - text = "" - if typed_chunk.message and typed_chunk.message.content: - for item in typed_chunk.message.content: - if isinstance(item, OCITextContentPart): - text += item.text - elif isinstance(item, OCIImageContentPart): - raise ValueError("OCI does not support image content in streaming responses") - else: - raise ValueError(f"Unsupported content type in OCI response: {item.type}") - return text - - def custom_handle_generic_stream_chunk(self, dict_chunk: dict): - """ - Custom handler to fix missing 'arguments' field in OCI tool calls. - - OCI API sometimes returns tool calls with structure: - {'type': 'FUNCTION', 'id': '...', 'name': 'tool_name'} - - But OCIStreamChunk Pydantic model requires 'arguments' field in tool calls. - This patch adds an empty arguments dict if missing. 
- """ - # Fix missing required fields in tool calls before Pydantic validation - # OCI streams tool calls progressively, so early chunks may be missing required fields - _patch_tool_calls(dict_chunk) - - # Now proceed with original validation and processing - try: - typed_chunk = OCIStreamChunk(**dict_chunk) - except TypeError as e: - raise ValueError(f"Chunk cannot be casted to OCIStreamChunk: {str(e)}") from e - - if typed_chunk.index is None: - typed_chunk.index = 0 - - text = _extract_text_content(typed_chunk) - - tool_calls = None - if typed_chunk.message and typed_chunk.message.toolCalls: - tool_calls = adapt_tools_to_openai_standard(typed_chunk.message.toolCalls) - - return ModelResponseStream( - choices=[ - StreamingChoices( - index=typed_chunk.index if typed_chunk.index else 0, - delta=Delta( - content=text, - tool_calls=[tool.model_dump() for tool in tool_calls] if tool_calls else None, - provider_specific_fields=None, - thinking_blocks=None, - reasoning_content=None, - ), - finish_reason=typed_chunk.finishReason, - ) - ] - ) - - # Mark it to avoid double patching - custom_handle_generic_stream_chunk._is_custom_patch = True - - # Patch it - OCIStreamWrapper._handle_generic_stream_chunk = custom_handle_generic_stream_chunk diff --git a/src/server/patches/litellm_patch_transform.py b/src/server/patches/litellm_patch_transform.py deleted file mode 100644 index 2bba1f26..00000000 --- a/src/server/patches/litellm_patch_transform.py +++ /dev/null @@ -1,80 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" -# spell-checker:ignore litellm giskard ollama llms -# pylint: disable=unused-argument,protected-access - -from typing import TYPE_CHECKING, List, Any -import time -import litellm -from litellm.llms.ollama.completion.transformation import OllamaConfig -from litellm.types.llms.openai import AllMessageValues -from litellm.types.utils import ModelResponse -from httpx._models import Response - -from common import logging_config - -logger = logging_config.logging.getLogger("patches.litellm_patch_transform") - -# Only patch if not already patched -if not getattr(OllamaConfig.transform_response, "_is_custom_patch", False): - if TYPE_CHECKING: - from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj - - LiteLLMLoggingObj = _LiteLLMLoggingObj - else: - LiteLLMLoggingObj = Any - - def custom_transform_response( - self, - model: str, - raw_response: Response, - model_response: ModelResponse, - logging_obj: LiteLLMLoggingObj, - request_data: dict, - messages: List[AllMessageValues], - optional_params: dict, - litellm_params: dict, - encoding: str, - **kwargs, - ): - """ - Custom transform response from - .venv/lib/python3.11/site-packages/litellm/llms/ollama/completion/transformation.py - - Additional kwargs: - api_key: Optional[str] - API key for authentication - json_mode: Optional[bool] - JSON mode flag - """ - logger.info("Custom transform_response is running") - response_json = raw_response.json() - - model_response.choices[0].finish_reason = "stop" - model_response.choices[0].message.content = response_json["response"] - - _prompt = request_data.get("prompt", "") - prompt_tokens = response_json.get( - "prompt_eval_count", - len(encoding.encode(_prompt, disallowed_special=())), - ) - completion_tokens = response_json.get("eval_count", len(response_json.get("message", {}).get("content", ""))) - - setattr( - model_response, - "usage", - 
litellm.Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - model_response.created = int(time.time()) - model_response.model = "ollama/" + model - return model_response - - # Mark it to avoid double patching - custom_transform_response._is_custom_patch = True - - # Patch it - OllamaConfig.transform_response = custom_transform_response diff --git a/tests/db_fixtures.py b/tests/db_fixtures.py index 293e5d80..4af4ce84 100644 --- a/tests/db_fixtures.py +++ b/tests/db_fixtures.py @@ -79,7 +79,7 @@ } -def pytest_collection_modifyitems(config, items): +def pytest_collection_modifyitems(items): """Automatically mark tests using database fixtures with 'db' and 'slow' markers. This hook inspects each test's fixture requirements and adds markers @@ -144,7 +144,7 @@ def wait_for_container_ready( @contextmanager -def temp_sql_setup(temp_dir_path: str = "test/db_startup_temp"): +def temp_sql_setup(temp_dir_path: str = "tests/db_startup_temp"): """Context manager for temporary SQL setup files. Creates a temporary directory with SQL initialization scripts @@ -182,7 +182,7 @@ def temp_sql_setup(temp_dir_path: str = "test/db_startup_temp"): shutil.rmtree(temp_dir) -def create_db_container(temp_dir_name: str = "test/db_startup_temp") -> Generator[Container, None, None]: +def create_db_container(temp_dir_name: str = "tests/db_startup_temp") -> Generator[Container, None, None]: """Create and manage an Oracle database container for testing. This generator function handles the full lifecycle of a Docker-based diff --git a/tests/integration/client/conftest.py b/tests/integration/client/conftest.py index e3780edd..f10c89c8 100644 --- a/tests/integration/client/conftest.py +++ b/tests/integration/client/conftest.py @@ -23,26 +23,28 @@ import socket import subprocess from contextlib import contextmanager - -# Import constants needed by fixtures and helper functions in this file -from tests.shared_fixtures import TEST_AUTH_TOKEN, ALL_TEST_ENV_VARS -from tests.db_fixtures import TEST_DB_CONFIG +from functools import lru_cache import pytest import requests -# Lazy import to avoid circular imports - stored in module-level variable -_app_test_class = None +# Import constants and helpers needed by fixtures in this file +from tests.db_fixtures import TEST_DB_CONFIG +from tests.shared_fixtures import ( + TEST_AUTH_TOKEN, + make_auth_headers, + save_env_state, + clear_env_state, + restore_env_state, +) +@lru_cache(maxsize=1) def get_app_test(): """Lazy import of Streamlit's AppTest.""" - global _app_test_class # pylint: disable=global-statement - if _app_test_class is None: - from streamlit.testing.v1 import AppTest # pylint: disable=import-outside-toplevel + from streamlit.testing.v1 import AppTest # pylint: disable=import-outside-toplevel - _app_test_class = AppTest - return _app_test_class + return AppTest ################################################# @@ -68,19 +70,8 @@ def client_test_env(): The `app_server` fixture depends on this to ensure environment is configured before the subprocess server is started. 
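+
+    A usage sketch (hypothetical test) - the fixture is requested purely for
+    its side effect on os.environ:
+
+        def test_client_env(client_test_env):
+            assert os.environ["CONFIG_FILE"].endswith("config.json")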
""" - # Save original environment state - original_env = {var: os.environ.get(var) for var in ALL_TEST_ENV_VARS} - - # Also capture dynamic OCI_ vars - dynamic_oci_vars = [v for v in os.environ if v.startswith("OCI_") and v not in ALL_TEST_ENV_VARS] - for var in dynamic_oci_vars: - original_env[var] = os.environ.get(var) - - # Clear all test-related vars - for var in ALL_TEST_ENV_VARS: - os.environ.pop(var, None) - for var in dynamic_oci_vars: - os.environ.pop(var, None) + original_env = save_env_state() + clear_env_state(original_env) # Set required environment variables for client tests os.environ["CONFIG_FILE"] = "/non/existent/path/config.json" @@ -91,12 +82,7 @@ def client_test_env(): yield - # Restore original environment state - for var, value in original_env.items(): - if value is not None: - os.environ[var] = value - elif var in os.environ: - del os.environ[var] + restore_env_state(original_env) ################################################# @@ -107,11 +93,7 @@ def client_test_env(): @pytest.fixture(name="auth_headers") def _auth_headers(): """Return common header configurations for testing.""" - return { - "no_auth": {}, - "invalid_auth": {"Authorization": "Bearer invalid-token", "client": TEST_CLIENT}, - "valid_auth": {"Authorization": f"Bearer {TEST_AUTH_TOKEN}", "client": TEST_CLIENT}, - } + return make_auth_headers(TEST_AUTH_TOKEN, TEST_CLIENT) @pytest.fixture(scope="session") @@ -140,12 +122,12 @@ def is_port_in_use(port): server_process = subprocess.Popen(cmd, cwd="src", env=env) # pylint: disable=consider-using-with try: - # Wait for server to be ready (up to 30 seconds) - max_wait = 30 + # Wait for server to be ready (up to 120 seconds) + max_wait = 120 start_time = time.time() while not is_port_in_use(TEST_SERVER_PORT): if time.time() - start_time > max_wait: - raise TimeoutError("Server failed to start within 30 seconds") + raise TimeoutError("Server failed to start within 120 seconds") time.sleep(0.5) yield server_process diff --git a/tests/integration/client/content/config/tabs/test_databases.py b/tests/integration/client/content/config/tabs/test_databases.py index fed67ed9..df24fc78 100644 --- a/tests/integration/client/content/config/tabs/test_databases.py +++ b/tests/integration/client/content/config/tabs/test_databases.py @@ -5,9 +5,10 @@ """ # spell-checker: disable -from tests.db_fixtures import TEST_DB_CONFIG import pytest +from tests.db_fixtures import TEST_DB_CONFIG + ############################################################################# # Test Streamlit UI diff --git a/tests/integration/client/content/config/tabs/test_models.py b/tests/integration/client/content/config/tabs/test_models.py index daec52d1..17d5199f 100644 --- a/tests/integration/client/content/config/tabs/test_models.py +++ b/tests/integration/client/content/config/tabs/test_models.py @@ -8,7 +8,7 @@ import os from unittest.mock import MagicMock, patch -from test.integration.client.conftest import temporary_sys_path +from tests.integration.client.conftest import temporary_sys_path # Streamlit File ST_FILE = "../src/client/content/config/tabs/models.py" diff --git a/tests/integration/client/content/config/test_config.py b/tests/integration/client/content/config/test_config.py index 6b997f61..0f6e684f 100644 --- a/tests/integration/client/content/config/test_config.py +++ b/tests/integration/client/content/config/test_config.py @@ -5,9 +5,10 @@ """ # spell-checker: disable -from test.integration.client.conftest import create_tabs_mock, run_streamlit_test import streamlit as st +from 
tests.integration.client.conftest import create_tabs_mock, run_streamlit_test + ############################################################################# # Test Streamlit UI diff --git a/tests/integration/client/content/test_chatbot.py b/tests/integration/client/content/test_chatbot.py index c7c46ab9..4216349a 100644 --- a/tests/integration/client/content/test_chatbot.py +++ b/tests/integration/client/content/test_chatbot.py @@ -5,7 +5,7 @@ """ # spell-checker: disable -from test.integration.client.conftest import enable_test_models, run_page_with_models_enabled +from tests.integration.client.conftest import enable_test_models, run_page_with_models_enabled ############################################################################# diff --git a/tests/integration/client/content/test_testbed.py b/tests/integration/client/content/test_testbed.py index eceb11ca..f3ccb350 100644 --- a/tests/integration/client/content/test_testbed.py +++ b/tests/integration/client/content/test_testbed.py @@ -5,7 +5,7 @@ """ # spell-checker: disable -from test.integration.client.conftest import run_page_with_models_enabled +from tests.integration.client.conftest import run_page_with_models_enabled ############################################################################# diff --git a/tests/integration/client/content/tools/tabs/test_split_embed.py b/tests/integration/client/content/tools/tabs/test_split_embed.py index f438ea9c..341b5d8b 100644 --- a/tests/integration/client/content/tools/tabs/test_split_embed.py +++ b/tests/integration/client/content/tools/tabs/test_split_embed.py @@ -1,14 +1,16 @@ -# pylint: disable=protected-access,import-error,import-outside-toplevel """ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
""" +# pylint: disable=protected-access import-error import-outside-toplevel # spell-checker: disable from unittest.mock import patch -from test.integration.client.conftest import enable_test_embed_models + import pandas as pd +from tests.integration.client.conftest import enable_test_embed_models + ############################################################################# # Test Helpers diff --git a/tests/integration/client/content/tools/test_tools.py b/tests/integration/client/content/tools/test_tools.py index 76cc73c2..aef1824e 100644 --- a/tests/integration/client/content/tools/test_tools.py +++ b/tests/integration/client/content/tools/test_tools.py @@ -5,7 +5,7 @@ """ # spell-checker: disable -from test.integration.client.conftest import create_tabs_mock, run_streamlit_test +from tests.integration.client.conftest import create_tabs_mock, run_streamlit_test ############################################################################# diff --git a/tests/integration/common/test_functions.py b/tests/integration/common/test_functions.py index ae6a0c7c..02be6fae 100644 --- a/tests/integration/common/test_functions.py +++ b/tests/integration/common/test_functions.py @@ -11,11 +11,10 @@ import os import tempfile -from tests.db_fixtures import TEST_DB_CONFIG - import pytest from common import functions +from tests.db_fixtures import TEST_DB_CONFIG class TestIsUrlAccessibleIntegration: diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index be912832..e622b4e9 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -15,7 +15,7 @@ import pytest -def pytest_collection_modifyitems(config, items): +def pytest_collection_modifyitems(items): """Automatically add 'integration' marker to all tests in this directory.""" for item in items: # Check if the test is under test/integration/ diff --git a/tests/integration/server/api/conftest.py b/tests/integration/server/api/conftest.py index b0d1eddc..4b8af5e3 100644 --- a/tests/integration/server/api/conftest.py +++ b/tests/integration/server/api/conftest.py @@ -18,24 +18,25 @@ # pylint: disable=redefined-outer-name -import os import asyncio +import os from typing import Generator -# Import constants needed by fixtures and test configuration in this file -from tests.db_fixtures import TEST_DB_CONFIG -from tests.shared_fixtures import ( - DEFAULT_LL_MODEL_CONFIG, - TEST_AUTH_TOKEN, - ALL_TEST_ENV_VARS, -) - import numpy as np - import pytest from fastapi.testclient import TestClient from server.bootstrap.bootstrap import DATABASE_OBJECTS, MODEL_OBJECTS, SETTINGS_OBJECTS +# Import constants and helpers needed by fixtures in this file +from tests.db_fixtures import TEST_DB_CONFIG +from tests.shared_fixtures import ( + DEFAULT_LL_MODEL_CONFIG, + TEST_AUTH_TOKEN, + make_auth_headers, + save_env_state, + clear_env_state, + restore_env_state, +) # Test configuration - extends shared DB config with integration-specific settings TEST_CONFIG = { @@ -61,33 +62,17 @@ def server_test_env(): The `app` fixture depends on this to ensure environment is configured before the FastAPI application is created. 
""" - # Save original environment state - original_env = {var: os.environ.get(var) for var in ALL_TEST_ENV_VARS} - - # Also capture dynamic OCI_ vars - dynamic_oci_vars = [v for v in os.environ if v.startswith("OCI_") and v not in ALL_TEST_ENV_VARS] - for var in dynamic_oci_vars: - original_env[var] = os.environ.get(var) - - # Clear all test-related vars - for var in ALL_TEST_ENV_VARS: - os.environ.pop(var, None) - for var in dynamic_oci_vars: - os.environ.pop(var, None) + original_env = save_env_state() + clear_env_state(original_env) # Set required environment variables for test server - os.environ["CONFIG_FILE"] = "/non/existent/path/config.json" # Use empty config - os.environ["OCI_CLI_CONFIG_FILE"] = "/non/existent/path" # Prevent OCI config pickup + os.environ["CONFIG_FILE"] = "/non/existent/path/config.json" + os.environ["OCI_CLI_CONFIG_FILE"] = "/non/existent/path" os.environ["API_SERVER_KEY"] = TEST_CONFIG["auth_token"] yield - # Restore original environment state - for var, value in original_env.items(): - if value is not None: - os.environ[var] = value - elif var in os.environ: - del os.environ[var] + restore_env_state(original_env) ################################################# @@ -96,11 +81,7 @@ def server_test_env(): @pytest.fixture def auth_headers(): """Return common header configurations for testing.""" - return { - "no_auth": {}, - "invalid_auth": {"Authorization": "Bearer invalid-token", "client": TEST_CONFIG["client"]}, - "valid_auth": {"Authorization": f"Bearer {TEST_CONFIG['auth_token']}", "client": TEST_CONFIG["client"]}, - } + return make_auth_headers(TEST_CONFIG["auth_token"], TEST_CONFIG["client"]) @pytest.fixture @@ -110,11 +91,7 @@ def test_client_auth_headers(test_client_settings): Use this fixture for endpoints that look up client settings via the client header. It ensures the test_client exists in SETTINGS_OBJECTS before returning headers. 
""" - return { - "no_auth": {}, - "invalid_auth": {"Authorization": "Bearer invalid-token", "client": test_client_settings}, - "valid_auth": {"Authorization": f"Bearer {TEST_CONFIG['auth_token']}", "client": test_client_settings}, - } + return make_auth_headers(TEST_CONFIG["auth_token"], test_client_settings) ################################################# diff --git a/tests/integration/server/api/v1/test_testbed.py b/tests/integration/server/api/v1/test_testbed.py index bddaeaf3..04a825e5 100644 --- a/tests/integration/server/api/v1/test_testbed.py +++ b/tests/integration/server/api/v1/test_testbed.py @@ -14,7 +14,7 @@ import pytest -from common.schema import TestSetQA as QATestSet, Evaluation, EvaluationReport +from common.schema import QASetData, Evaluation, EvaluationReport class TestAuthentication: @@ -297,7 +297,7 @@ def test_testbed_generate_qa_mocked(self, client, test_client_auth_headers): # This is a complex operation that requires a model to generate Q&A, so we'll mock this part with patch.object(client, "post") as mock_post: # Configure the mock to return a successful response - mock_qa_data = QATestSet( + mock_qa_data = QASetData( qa_data=[ {"question": "Generated Q1?", "answer": "Generated A1"}, {"question": "Generated Q2?", "answer": "Generated A2"}, diff --git a/tests/integration/server/bootstrap/conftest.py b/tests/integration/server/bootstrap/conftest.py index 2dfbf30c..c033a29e 100644 --- a/tests/integration/server/bootstrap/conftest.py +++ b/tests/integration/server/bootstrap/conftest.py @@ -18,6 +18,8 @@ import tempfile from pathlib import Path +import pytest + # Import constants needed by fixtures in this file from tests.shared_fixtures import ( DEFAULT_LL_MODEL_CONFIG, @@ -27,11 +29,10 @@ TEST_API_KEY_ALT, ) -import pytest - @pytest.fixture -def clean_bootstrap_env(clean_env): +@pytest.mark.usefixtures("clean_env") +def clean_bootstrap_env(): """Alias for clean_env fixture for backwards compatibility. This fixture name is used in existing tests. 
     It delegates to the
diff --git a/tests/integration/server/bootstrap/test_bootstrap_databases.py b/tests/integration/server/bootstrap/test_bootstrap_databases.py
index 3a22b9a3..8872e23f 100644
--- a/tests/integration/server/bootstrap/test_bootstrap_databases.py
+++ b/tests/integration/server/bootstrap/test_bootstrap_databases.py
@@ -12,16 +12,15 @@

 import os

+import pytest
+
+from server.bootstrap import databases as databases_module
 from tests.shared_fixtures import (
     assert_database_list_valid,
     assert_has_default_database,
     get_database_by_name,
 )

-import pytest
-
-from server.bootstrap import databases as databases_module
-

 @pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env")
 class TestDatabasesBootstrapWithConfig:
diff --git a/tests/integration/server/bootstrap/test_bootstrap_models.py b/tests/integration/server/bootstrap/test_bootstrap_models.py
index ff094f31..c1a69d23 100644
--- a/tests/integration/server/bootstrap/test_bootstrap_models.py
+++ b/tests/integration/server/bootstrap/test_bootstrap_models.py
@@ -13,11 +13,10 @@
 import os
 from unittest.mock import patch

-from tests.shared_fixtures import assert_model_list_valid, get_model_by_id
-
 import pytest

 from server.bootstrap import models as models_module
+from tests.shared_fixtures import assert_model_list_valid, get_model_by_id


 @pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env")
diff --git a/tests/shared_fixtures.py b/tests/shared_fixtures.py
index 55170c24..62f2aa53 100644
--- a/tests/shared_fixtures.py
+++ b/tests/shared_fixtures.py
@@ -515,6 +515,84 @@ def my_env(monkeypatch):
     monkeypatch.setenv("API_SERVER_PORT", str(server_port))


+#################################################
+# Session-scoped Environment Helpers
+#################################################
+# These helpers are for session-scoped fixtures that can't use monkeypatch.
+# They manually save/restore environment state.
+
+
+def save_env_state() -> dict:
+    """Save the current state of test-related environment variables.
+
+    Returns a dict mapping var names to their values (or None if not set).
+    Also captures dynamic OCI_ vars not in our static list.
+
+    Usage:
+        original_env = save_env_state()
+        # ... modify environment ...
+        restore_env_state(original_env)
+    """
+    original_env = {var: os.environ.get(var) for var in ALL_TEST_ENV_VARS}
+
+    # Also capture dynamic OCI_ vars
+    for var in _get_dynamic_oci_vars():
+        original_env[var] = os.environ.get(var)
+
+    return original_env
+
+
+def clear_env_state(original_env: dict) -> None:
+    """Clear all test-related environment variables.
+
+    Clears all vars in ALL_TEST_ENV_VARS plus any dynamic OCI_ vars
+    that were captured in original_env.
+
+    Args:
+        original_env: Dict from save_env_state() (used to get dynamic var names)
+    """
+    for var in ALL_TEST_ENV_VARS:
+        os.environ.pop(var, None)
+
+    # Clear dynamic OCI vars that were in original_env
+    for var in original_env:
+        if var not in ALL_TEST_ENV_VARS:
+            os.environ.pop(var, None)
+
+
+def restore_env_state(original_env: dict) -> None:
+    """Restore environment variables to their original state.
+
+    Args:
+        original_env: Dict from save_env_state()
+    """
+    for var, value in original_env.items():
+        if value is not None:
+            os.environ[var] = value
+        elif var in os.environ:
+            del os.environ[var]
+
+
+def make_auth_headers(auth_token: str, client_id: str) -> dict:
+    """Create standard auth headers dict for testing.
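+
+    Example (illustrative sketch only; the token and client values below are
+    placeholders, and `client` is assumed to be a FastAPI TestClient):
+
+        headers = make_auth_headers("fake-token", "test-client")
+        response = client.get("/v1/models", headers=headers["valid_auth"])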
+
+    Returns a dict with 'no_auth', 'invalid_auth', and 'valid_auth' keys,
+    each containing the appropriate headers for that auth scenario.
+
+    Args:
+        auth_token: Valid authentication token
+        client_id: Client identifier for the client header
+
+    Returns:
+        Dict with auth header configurations for testing
+    """
+    return {
+        "no_auth": {},
+        "invalid_auth": {"Authorization": "Bearer invalid-token", "client": client_id},
+        "valid_auth": {"Authorization": f"Bearer {auth_token}", "client": client_id},
+    }
+
+
 #################################################
 # Vector Store Test Data
 #################################################
diff --git a/tests/unit/common/test_schema.py b/tests/unit/common/test_schema.py
index 0dc0616e..8bd22f95 100644
--- a/tests/unit/common/test_schema.py
+++ b/tests/unit/common/test_schema.py
@@ -40,8 +40,8 @@
     # Completions
     ChatRequest,
     # Testbed
-    TestSets,
-    TestSetQA,
+    QASets,
+    QASetData,
     Evaluation,
     EvaluationReport,
     # Types
@@ -549,23 +549,23 @@ def test_default_model_is_none(self):
         assert request.model is None


-class TestTestbedModels:
-    """Tests for Testbed-related models."""
+class TestQAModels:
+    """Tests for QA testbed-related models."""

-    def test_test_sets_required_fields(self):
-        """TestSets should require tid, name, and created."""
+    def test_qa_sets_required_fields(self):
+        """QASets should require tid, name, and created."""
         with pytest.raises(ValidationError):
-            TestSets()
+            QASets()

-        test_set = TestSets(tid="123", name="Test Set", created="2024-01-01")
-        assert test_set.tid == "123"
+        qa_set = QASets(tid="123", name="Test Set", created="2024-01-01")
+        assert qa_set.tid == "123"

-    def test_test_set_qa_required_fields(self):
-        """TestSetQA should require qa_data."""
+    def test_qa_set_data_required_fields(self):
+        """QASetData should require qa_data."""
         with pytest.raises(ValidationError):
-            TestSetQA()
+            QASetData()

-        qa = TestSetQA(qa_data=[{"q": "question", "a": "answer"}])
+        qa = QASetData(qa_data=[{"q": "question", "a": "answer"}])
         assert len(qa.qa_data) == 1

     def test_evaluation_required_fields(self):
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 3271e5c3..0981c270 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -15,7 +15,7 @@
 import pytest


-def pytest_collection_modifyitems(config, items):
+def pytest_collection_modifyitems(items):
     """Automatically add 'unit' marker to all tests in this directory."""
     for item in items:
         # Check if the test is under test/unit/
diff --git a/tests/unit/server/api/conftest.py b/tests/unit/server/api/conftest.py
index a640aa6c..7317e706 100644
--- a/tests/unit/server/api/conftest.py
+++ b/tests/unit/server/api/conftest.py
@@ -14,14 +14,6 @@

 from unittest.mock import MagicMock, AsyncMock

-# Import constants needed by fixtures in this file
-from tests.shared_fixtures import (
-    TEST_DB_USER,
-    TEST_DB_PASSWORD,
-    TEST_DB_DSN,
-)
-from tests.db_fixtures import TEST_DB_CONFIG
-
 import pytest

 from common.schema import (
@@ -29,6 +21,12 @@
     DatabaseVectorStorage,
     ChatRequest,
 )
+# Import constants needed by fixtures in this file
+from tests.shared_fixtures import (
+    TEST_DB_USER,
+    TEST_DB_PASSWORD,
+    TEST_DB_DSN,
+)


 @pytest.fixture
diff --git a/tests/unit/server/api/utils/test_utils_databases.py b/tests/unit/server/api/utils/test_utils_databases.py
index c340546d..49ceaec4 100644
--- a/tests/unit/server/api/utils/test_utils_databases.py
+++ b/tests/unit/server/api/utils/test_utils_databases.py
@@ -12,8 +12,6 @@

 # pylint: disable=too-few-public-methods

-from tests.db_fixtures import TEST_DB_CONFIG
-from tests.shared_fixtures import TEST_DB_WALLET_PASSWORD
 from unittest.mock import patch, MagicMock

 import pytest
@@ -22,6 +20,8 @@
 from common.schema import DatabaseSettings
 from server.api.utils import databases as utils_databases
 from server.api.utils.databases import DbException, ExistsDatabaseError, UnknownDatabaseError
+from tests.db_fixtures import TEST_DB_CONFIG
+from tests.shared_fixtures import TEST_DB_WALLET_PASSWORD


 class TestDbException:
diff --git a/tests/unit/server/api/utils/test_utils_mcp.py b/tests/unit/server/api/utils/test_utils_mcp.py
index 34bd6a53..ea81be6d 100644
--- a/tests/unit/server/api/utils/test_utils_mcp.py
+++ b/tests/unit/server/api/utils/test_utils_mcp.py
@@ -7,12 +7,12 @@
 """

 import os
-from tests.shared_fixtures import TEST_API_KEY, TEST_API_KEY_ALT
 from unittest.mock import AsyncMock, MagicMock, patch

 import pytest

 from server.api.utils import mcp
+from tests.shared_fixtures import TEST_API_KEY, TEST_API_KEY_ALT


 class TestGetClient:
diff --git a/tests/unit/server/api/utils/test_utils_webscrape.py b/tests/unit/server/api/utils/test_utils_webscrape.py
index cd042992..f8dc5e72 100644
--- a/tests/unit/server/api/utils/test_utils_webscrape.py
+++ b/tests/unit/server/api/utils/test_utils_webscrape.py
@@ -8,13 +8,13 @@

 # pylint: disable=too-few-public-methods

-from test.unit.server.api.conftest import create_mock_aiohttp_session
 from unittest.mock import patch, AsyncMock

 import pytest
 from bs4 import BeautifulSoup

 from server.api.utils import webscrape
+from tests.unit.server.api.conftest import create_mock_aiohttp_session


 class TestNormalizeWs:
diff --git a/tests/unit/server/api/v1/test_v1_embed.py b/tests/unit/server/api/v1/test_v1_embed.py
index aef37044..855c8339 100644
--- a/tests/unit/server/api/v1/test_v1_embed.py
+++ b/tests/unit/server/api/v1/test_v1_embed.py
@@ -10,7 +10,6 @@

 from io import BytesIO
 from pathlib import Path
-from test.unit.server.api.conftest import create_mock_aiohttp_session
 from unittest.mock import patch, MagicMock, AsyncMock
 import json

@@ -21,6 +20,7 @@
 from common.schema import DatabaseVectorStorage, VectorStoreRefreshRequest
 from server.api.v1 import embed
 from server.api.utils.databases import DbException
+from tests.unit.server.api.conftest import create_mock_aiohttp_session


 @pytest.fixture
diff --git a/tests/unit/server/api/v1/test_v1_mcp.py b/tests/unit/server/api/v1/test_v1_mcp.py
index 7e55b6ab..5da618ab 100644
--- a/tests/unit/server/api/v1/test_v1_mcp.py
+++ b/tests/unit/server/api/v1/test_v1_mcp.py
@@ -8,12 +8,12 @@

 # pylint: disable=too-few-public-methods

-from tests.shared_fixtures import TEST_API_KEY
 from unittest.mock import AsyncMock, MagicMock, patch

 import pytest

 from server.api.v1 import mcp
+from tests.shared_fixtures import TEST_API_KEY


 class TestGetMcp:
diff --git a/tests/unit/server/api/v1/test_v1_testbed.py b/tests/unit/server/api/v1/test_v1_testbed.py
index 8ba6d12c..a69afab4 100644
--- a/tests/unit/server/api/v1/test_v1_testbed.py
+++ b/tests/unit/server/api/v1/test_v1_testbed.py
@@ -15,7 +15,7 @@
 import litellm

 from server.api.v1 import testbed
-from common.schema import TestSets, TestSetQA, Evaluation, EvaluationReport
+from common.schema import QASets, QASetData, Evaluation, EvaluationReport


 class TestTestbedTestsets:
@@ -33,8 +33,8 @@ async def test_testbed_testsets_returns_list(
         mock_get_db.return_value = mock_db

         mock_testsets = [
-            TestSets(tid="TS001", name="Test Set 1", created="2024-01-01"),
-            TestSets(tid="TS002", name="Test Set 2", created="2024-01-02"),
+            QASets(tid="TS001", name="Test Set 1", created="2024-01-01"),
+            QASets(tid="TS002", name="Test Set 2", created="2024-01-02"),
         ]
         mock_get_testsets.return_value = mock_testsets

@@ -139,7 +139,7 @@ async def test_testbed_testset_qa_returns_data(
         mock_db.connection = mock_db_connection
         mock_get_db.return_value = mock_db

-        mock_qa = TestSetQA(qa_data=[{"question": "Q1", "answer": "A1"}])
+        mock_qa = QASetData(qa_data=[{"question": "Q1", "answer": "A1"}])
         mock_get_qa.return_value = mock_qa

         result = await testbed.testbed_testset_qa(tid="ts001", client="test_client")
@@ -186,7 +186,7 @@ async def test_testbed_upsert_testsets_success(
         mock_get_db.return_value = mock_db
         mock_jsonl.return_value = [{"question": "Q1", "answer": "A1"}]
         mock_upsert.return_value = "TS001"
-        mock_testset_qa.return_value = TestSetQA(qa_data=[{"question": "Q1"}])
+        mock_testset_qa.return_value = QASetData(qa_data=[{"question": "Q1"}])

         mock_file = UploadFile(file=BytesIO(b'{"question": "Q1"}'), filename="test.jsonl")

         result = await testbed.testbed_upsert_testsets(
             files=[mock_file], name="Test Set", tid=None, client="test_client"
         )

-        assert isinstance(result, TestSetQA)
+        assert isinstance(result, QASetData)
         mock_db_connection.commit.assert_called_once()

     @pytest.mark.asyncio
diff --git a/tests/unit/server/bootstrap/test_bootstrap_databases.py b/tests/unit/server/bootstrap/test_bootstrap_databases.py
index 195a9e75..0db5af42 100644
--- a/tests/unit/server/bootstrap/test_bootstrap_databases.py
+++ b/tests/unit/server/bootstrap/test_bootstrap_databases.py
@@ -10,16 +10,15 @@

 import os

+import pytest
+
+from server.bootstrap import databases as databases_module
 from tests.shared_fixtures import (
     assert_database_list_valid,
     assert_has_default_database,
     get_database_by_name,
 )

-import pytest
-
-from server.bootstrap import databases as databases_module
-

 @pytest.mark.usefixtures("reset_config_store", "clean_env")
 class TestDatabasesMain:
diff --git a/tests/unit/server/bootstrap/test_bootstrap_models.py b/tests/unit/server/bootstrap/test_bootstrap_models.py
index 6635a27b..c638275c 100644
--- a/tests/unit/server/bootstrap/test_bootstrap_models.py
+++ b/tests/unit/server/bootstrap/test_bootstrap_models.py
@@ -11,11 +11,10 @@
 import os
 from unittest.mock import patch

-from tests.shared_fixtures import assert_model_list_valid, get_model_by_id, TEST_API_KEY
-
 import pytest

 from server.bootstrap import models as models_module
+from tests.shared_fixtures import assert_model_list_valid, get_model_by_id, TEST_API_KEY


 @pytest.mark.usefixtures("reset_config_store", "clean_env", "mock_is_url_accessible")

From 4a2e61293716ec813ca37ccc2f2b45401d21b674 Mon Sep 17 00:00:00 2001
From: gotsysdba
Date: Sun, 30 Nov 2025 15:20:52 +0000
Subject: [PATCH 12/20] Cleaner test directory structure without boilerplate
 __init__.py files

---
 pytest.ini | 2 +-
 tests/__init__.py | 1 -
 tests/conftest.py | 12 ++++++----
 tests/integration/__init__.py | 6 ------
 tests/integration/client/__init__.py | 0
 tests/integration/client/conftest.py | 4 ++--
 tests/integration/client/content/__init__.py | 0
 tests/integration/client/content/config/__init__.py | 0
 .../client/content/config/tabs/__init__.py | 0
 .../client/content/config/tabs/test_databases.py | 2 +-
 .../client/content/config/tabs/test_models.py | 2 +-
 .../integration/client/content/config/test_config.py | 2 +-
 tests/integration/client/content/test_chatbot.py | 2 +-
 tests/integration/client/content/test_testbed.py | 2 +-
 tests/integration/client/content/tools/__init__.py | 0
 .../client/content/tools/tabs/__init__.py | 0
 .../client/content/tools/tabs/test_split_embed.py | 2 +-
 tests/integration/client/content/tools/test_tools.py | 2 +-
 tests/integration/client/utils/__init__.py | 0
 tests/integration/common/__init__.py | 6 ------
 tests/integration/common/test_functions.py | 2 +-
 tests/integration/server/__init__.py | 6 ------
 tests/integration/server/api/__init__.py | 6 ------
 tests/integration/server/api/conftest.py | 4 ++--
 tests/integration/server/api/v1/__init__.py | 6 ------
 tests/integration/server/api/v1/test_databases.py | 2 +-
 tests/integration/server/bootstrap/__init__.py | 1 -
 tests/integration/server/bootstrap/conftest.py | 2 +-
 .../server/bootstrap/test_bootstrap_databases.py | 2 +-
 .../server/bootstrap/test_bootstrap_models.py | 2 +-
 tests/unit/__init__.py | 1 -
 tests/unit/client/__init__.py | 0
 tests/unit/client/content/__init__.py | 0
 tests/unit/client/content/config/__init__.py | 0
 tests/unit/client/content/config/tabs/__init__.py | 0
 tests/unit/client/content/tools/__init__.py | 0
 tests/unit/client/content/tools/tabs/__init__.py | 0
 tests/unit/client/utils/__init__.py | 0
 tests/unit/common/__init__.py | 6 ------
 tests/unit/server/__init__.py | 1 -
 tests/unit/server/api/__init__.py | 1 -
 tests/unit/server/api/conftest.py | 2 +-
 tests/unit/server/api/utils/__init__.py | 1 -
 tests/unit/server/api/utils/test_utils_databases.py | 4 ++--
 tests/unit/server/api/utils/test_utils_mcp.py | 2 +-
 tests/unit/server/api/utils/test_utils_webscrape.py | 2 +-
 tests/unit/server/api/v1/__init__.py | 1 -
 tests/unit/server/api/v1/test_v1_embed.py | 2 +-
 tests/unit/server/api/v1/test_v1_mcp.py | 2 +-
 tests/unit/server/bootstrap/__init__.py | 1 -
 .../server/bootstrap/test_bootstrap_databases.py | 2 +-
 tests/unit/server/bootstrap/test_bootstrap_models.py | 2 +-
 52 files changed, 34 insertions(+), 74 deletions(-)
 delete mode 100644 tests/__init__.py
 delete mode 100644 tests/integration/__init__.py
 delete mode 100644 tests/integration/client/__init__.py
 delete mode 100644 tests/integration/client/content/__init__.py
 delete mode 100644 tests/integration/client/content/config/__init__.py
 delete mode 100644 tests/integration/client/content/config/tabs/__init__.py
 delete mode 100644 tests/integration/client/content/tools/__init__.py
 delete mode 100644 tests/integration/client/content/tools/tabs/__init__.py
 delete mode 100644 tests/integration/client/utils/__init__.py
 delete mode 100644 tests/integration/common/__init__.py
 delete mode 100644 tests/integration/server/__init__.py
 delete mode 100644 tests/integration/server/api/__init__.py
 delete mode 100644 tests/integration/server/api/v1/__init__.py
 delete mode 100644 tests/integration/server/bootstrap/__init__.py
 delete mode 100644 tests/unit/__init__.py
 delete mode 100644 tests/unit/client/__init__.py
 delete mode 100644 tests/unit/client/content/__init__.py
 delete mode 100644 tests/unit/client/content/config/__init__.py
 delete mode 100644 tests/unit/client/content/config/tabs/__init__.py
 delete mode 100644 tests/unit/client/content/tools/__init__.py
 delete mode 100644 tests/unit/client/content/tools/tabs/__init__.py
 delete mode 100644 tests/unit/client/utils/__init__.py
 delete mode 100644 tests/unit/common/__init__.py
 delete mode 100644 tests/unit/server/__init__.py
 delete mode 100644 tests/unit/server/api/__init__.py
 delete mode 100644 tests/unit/server/api/utils/__init__.py
 delete mode 100644 tests/unit/server/api/v1/__init__.py
 delete mode 100644 tests/unit/server/bootstrap/__init__.py

diff --git a/pytest.ini b/pytest.ini
index cbfba509..583cdf9e 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -4,7 +4,7 @@
 ; spell-checker: disable

 [pytest]
-pythonpath = src
+pythonpath = src tests
 filterwarnings = ignore::DeprecationWarning
 asyncio_default_fixture_loop_scope = function
diff --git a/tests/__init__.py b/tests/__init__.py
deleted file mode 100644
index 66173aec..00000000
--- a/tests/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Test package
diff --git a/tests/conftest.py b/tests/conftest.py
index 8935079e..0a4840ec 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,19 +5,23 @@
 Root pytest configuration for the test suite.

 This conftest.py uses pytest_plugins to automatically load fixtures from:
-- tests.shared_fixtures: Factory fixtures (make_database, make_model, etc.)
-- tests.db_fixtures: Database container fixtures (db_container, db_connection, etc.)
+- shared_fixtures: Factory fixtures (make_database, make_model, etc.)
+- db_fixtures: Database container fixtures (db_container, db_connection, etc.)

 All fixtures defined in these modules are automatically available to all tests
 without needing explicit imports in child conftest.py files.

 Constants and helper functions (e.g., TEST_DB_CONFIG, assert_model_list_valid)
 still require explicit imports in the test files that use them.
+
+Note: The 'tests' directory is added to pythonpath in pytest.ini, enabling
+direct imports like 'from shared_fixtures import X' instead of 'from tests.shared_fixtures import X'.
+This removes the need for __init__.py files in test directories.
 """

 # pytest_plugins automatically loads fixtures from these modules
 # This replaces scattered "from tests.shared_fixtures import ..." across conftest files
 pytest_plugins = [
-    "tests.shared_fixtures",
-    "tests.db_fixtures",
+    "shared_fixtures",
+    "db_fixtures",
 ]
diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py
deleted file mode 100644
index 2577126f..00000000
--- a/tests/integration/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""
-Copyright (c) 2024, 2025, Oracle and/or its affiliates.
-Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
-
-Integration tests package.
-""" diff --git a/tests/integration/client/__init__.py b/tests/integration/client/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/integration/client/conftest.py b/tests/integration/client/conftest.py index f10c89c8..08e3ea03 100644 --- a/tests/integration/client/conftest.py +++ b/tests/integration/client/conftest.py @@ -29,8 +29,8 @@ import requests # Import constants and helpers needed by fixtures in this file -from tests.db_fixtures import TEST_DB_CONFIG -from tests.shared_fixtures import ( +from db_fixtures import TEST_DB_CONFIG +from shared_fixtures import ( TEST_AUTH_TOKEN, make_auth_headers, save_env_state, diff --git a/tests/integration/client/content/__init__.py b/tests/integration/client/content/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/integration/client/content/config/__init__.py b/tests/integration/client/content/config/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/integration/client/content/config/tabs/__init__.py b/tests/integration/client/content/config/tabs/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/integration/client/content/config/tabs/test_databases.py b/tests/integration/client/content/config/tabs/test_databases.py index df24fc78..828678a4 100644 --- a/tests/integration/client/content/config/tabs/test_databases.py +++ b/tests/integration/client/content/config/tabs/test_databases.py @@ -7,7 +7,7 @@ import pytest -from tests.db_fixtures import TEST_DB_CONFIG +from db_fixtures import TEST_DB_CONFIG ############################################################################# diff --git a/tests/integration/client/content/config/tabs/test_models.py b/tests/integration/client/content/config/tabs/test_models.py index 17d5199f..5319c77b 100644 --- a/tests/integration/client/content/config/tabs/test_models.py +++ b/tests/integration/client/content/config/tabs/test_models.py @@ -8,7 +8,7 @@ import os from unittest.mock import MagicMock, patch -from tests.integration.client.conftest import temporary_sys_path +from integration.client.conftest import temporary_sys_path # Streamlit File ST_FILE = "../src/client/content/config/tabs/models.py" diff --git a/tests/integration/client/content/config/test_config.py b/tests/integration/client/content/config/test_config.py index 0f6e684f..62b062de 100644 --- a/tests/integration/client/content/config/test_config.py +++ b/tests/integration/client/content/config/test_config.py @@ -7,7 +7,7 @@ import streamlit as st -from tests.integration.client.conftest import create_tabs_mock, run_streamlit_test +from integration.client.conftest import create_tabs_mock, run_streamlit_test ############################################################################# diff --git a/tests/integration/client/content/test_chatbot.py b/tests/integration/client/content/test_chatbot.py index 4216349a..7cea20df 100644 --- a/tests/integration/client/content/test_chatbot.py +++ b/tests/integration/client/content/test_chatbot.py @@ -5,7 +5,7 @@ """ # spell-checker: disable -from tests.integration.client.conftest import enable_test_models, run_page_with_models_enabled +from integration.client.conftest import enable_test_models, run_page_with_models_enabled ############################################################################# diff --git a/tests/integration/client/content/test_testbed.py b/tests/integration/client/content/test_testbed.py index f3ccb350..0afb779e 100644 --- a/tests/integration/client/content/test_testbed.py +++ 
b/tests/integration/client/content/test_testbed.py @@ -5,7 +5,7 @@ """ # spell-checker: disable -from tests.integration.client.conftest import run_page_with_models_enabled +from integration.client.conftest import run_page_with_models_enabled ############################################################################# diff --git a/tests/integration/client/content/tools/__init__.py b/tests/integration/client/content/tools/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/integration/client/content/tools/tabs/__init__.py b/tests/integration/client/content/tools/tabs/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/integration/client/content/tools/tabs/test_split_embed.py b/tests/integration/client/content/tools/tabs/test_split_embed.py index 341b5d8b..50d22d5e 100644 --- a/tests/integration/client/content/tools/tabs/test_split_embed.py +++ b/tests/integration/client/content/tools/tabs/test_split_embed.py @@ -9,7 +9,7 @@ import pandas as pd -from tests.integration.client.conftest import enable_test_embed_models +from integration.client.conftest import enable_test_embed_models ############################################################################# diff --git a/tests/integration/client/content/tools/test_tools.py b/tests/integration/client/content/tools/test_tools.py index aef1824e..e30cd003 100644 --- a/tests/integration/client/content/tools/test_tools.py +++ b/tests/integration/client/content/tools/test_tools.py @@ -5,7 +5,7 @@ """ # spell-checker: disable -from tests.integration.client.conftest import create_tabs_mock, run_streamlit_test +from integration.client.conftest import create_tabs_mock, run_streamlit_test ############################################################################# diff --git a/tests/integration/client/utils/__init__.py b/tests/integration/client/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/integration/common/__init__.py b/tests/integration/common/__init__.py deleted file mode 100644 index d63a0614..00000000 --- a/tests/integration/common/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Integration tests for common module. -""" diff --git a/tests/integration/common/test_functions.py b/tests/integration/common/test_functions.py index 02be6fae..167ccf16 100644 --- a/tests/integration/common/test_functions.py +++ b/tests/integration/common/test_functions.py @@ -14,7 +14,7 @@ import pytest from common import functions -from tests.db_fixtures import TEST_DB_CONFIG +from db_fixtures import TEST_DB_CONFIG class TestIsUrlAccessibleIntegration: diff --git a/tests/integration/server/__init__.py b/tests/integration/server/__init__.py deleted file mode 100644 index a242d937..00000000 --- a/tests/integration/server/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. - -Server integration tests package. -""" diff --git a/tests/integration/server/api/__init__.py b/tests/integration/server/api/__init__.py deleted file mode 100644 index c4b92db3..00000000 --- a/tests/integration/server/api/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. 
-Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
-
-Integration tests for common module.
-"""
diff --git a/tests/integration/common/test_functions.py b/tests/integration/common/test_functions.py
index 02be6fae..167ccf16 100644
--- a/tests/integration/common/test_functions.py
+++ b/tests/integration/common/test_functions.py
@@ -14,7 +14,7 @@
 import pytest

 from common import functions
-from tests.db_fixtures import TEST_DB_CONFIG
+from db_fixtures import TEST_DB_CONFIG


 class TestIsUrlAccessibleIntegration:
diff --git a/tests/integration/server/__init__.py b/tests/integration/server/__init__.py
deleted file mode 100644
index a242d937..00000000
--- a/tests/integration/server/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""
-Copyright (c) 2024, 2025, Oracle and/or its affiliates.
-Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
-
-Server integration tests package.
-"""
diff --git a/tests/integration/server/api/__init__.py b/tests/integration/server/api/__init__.py
deleted file mode 100644
index c4b92db3..00000000
--- a/tests/integration/server/api/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""
-Copyright (c) 2024, 2025, Oracle and/or its affiliates.
-Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
-
-Server API integration tests package.
-"""
diff --git a/tests/integration/server/api/conftest.py b/tests/integration/server/api/conftest.py
index 4b8af5e3..e8814b38 100644
--- a/tests/integration/server/api/conftest.py
+++ b/tests/integration/server/api/conftest.py
@@ -28,8 +28,8 @@
 from server.bootstrap.bootstrap import DATABASE_OBJECTS, MODEL_OBJECTS, SETTINGS_OBJECTS

 # Import constants and helpers needed by fixtures in this file
-from tests.db_fixtures import TEST_DB_CONFIG
-from tests.shared_fixtures import (
+from db_fixtures import TEST_DB_CONFIG
+from shared_fixtures import (
     DEFAULT_LL_MODEL_CONFIG,
     TEST_AUTH_TOKEN,
     make_auth_headers,
diff --git a/tests/integration/server/api/v1/__init__.py b/tests/integration/server/api/v1/__init__.py
deleted file mode 100644
index d55308b1..00000000
--- a/tests/integration/server/api/v1/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""
-Copyright (c) 2024, 2025, Oracle and/or its affiliates.
-Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
-
-Server API v1 integration tests package.
-"""
diff --git a/tests/integration/server/api/v1/test_databases.py b/tests/integration/server/api/v1/test_databases.py
index 355148e9..891a390a 100644
--- a/tests/integration/server/api/v1/test_databases.py
+++ b/tests/integration/server/api/v1/test_databases.py
@@ -8,7 +8,7 @@
 These endpoints require authentication.
 """

-from tests.db_fixtures import TEST_DB_CONFIG
+from db_fixtures import TEST_DB_CONFIG


 class TestAuthentication:
diff --git a/tests/integration/server/bootstrap/__init__.py b/tests/integration/server/bootstrap/__init__.py
deleted file mode 100644
index 90dc5216..00000000
--- a/tests/integration/server/bootstrap/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Bootstrap integration test package
diff --git a/tests/integration/server/bootstrap/conftest.py b/tests/integration/server/bootstrap/conftest.py
index c033a29e..8f06ab47 100644
--- a/tests/integration/server/bootstrap/conftest.py
+++ b/tests/integration/server/bootstrap/conftest.py
@@ -21,7 +21,7 @@
 import pytest

 # Import constants needed by fixtures in this file
-from tests.shared_fixtures import (
+from shared_fixtures import (
     DEFAULT_LL_MODEL_CONFIG,
     TEST_INTEGRATION_DB_USER,
     TEST_INTEGRATION_DB_PASSWORD,
diff --git a/tests/integration/server/bootstrap/test_bootstrap_databases.py b/tests/integration/server/bootstrap/test_bootstrap_databases.py
index 8872e23f..32f6492d 100644
--- a/tests/integration/server/bootstrap/test_bootstrap_databases.py
+++ b/tests/integration/server/bootstrap/test_bootstrap_databases.py
@@ -15,7 +15,7 @@
 import pytest

 from server.bootstrap import databases as databases_module
-from tests.shared_fixtures import (
+from shared_fixtures import (
     assert_database_list_valid,
     assert_has_default_database,
     get_database_by_name,
diff --git a/tests/integration/server/bootstrap/test_bootstrap_models.py b/tests/integration/server/bootstrap/test_bootstrap_models.py
index c1a69d23..49d63e23 100644
--- a/tests/integration/server/bootstrap/test_bootstrap_models.py
+++ b/tests/integration/server/bootstrap/test_bootstrap_models.py
@@ -16,7 +16,7 @@
 import pytest

 from server.bootstrap import models as models_module
-from tests.shared_fixtures import assert_model_list_valid, get_model_by_id
+from shared_fixtures import assert_model_list_valid, get_model_by_id


 @pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env")
diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py
deleted file mode 100644
index 06825972..00000000
--- a/tests/unit/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Unit test package
diff --git a/tests/unit/client/__init__.py b/tests/unit/client/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/unit/client/content/__init__.py b/tests/unit/client/content/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/unit/client/content/config/__init__.py b/tests/unit/client/content/config/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/unit/client/content/config/tabs/__init__.py b/tests/unit/client/content/config/tabs/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/unit/client/content/tools/__init__.py b/tests/unit/client/content/tools/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/unit/client/content/tools/tabs/__init__.py b/tests/unit/client/content/tools/tabs/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/unit/client/utils/__init__.py b/tests/unit/client/utils/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/unit/common/__init__.py b/tests/unit/common/__init__.py
deleted file mode 100644
index e4188e30..00000000
--- a/tests/unit/common/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""
-Copyright (c) 2024, 2025, Oracle and/or its affiliates.
-Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
-
-Unit tests for common module.
-"""
diff --git a/tests/unit/server/__init__.py b/tests/unit/server/__init__.py
deleted file mode 100644
index bc4d60b5..00000000
--- a/tests/unit/server/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Server unit test package
diff --git a/tests/unit/server/api/__init__.py b/tests/unit/server/api/__init__.py
deleted file mode 100644
index b4333d68..00000000
--- a/tests/unit/server/api/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# API unit test package
diff --git a/tests/unit/server/api/conftest.py b/tests/unit/server/api/conftest.py
index 7317e706..4a951486 100644
--- a/tests/unit/server/api/conftest.py
+++ b/tests/unit/server/api/conftest.py
@@ -22,7 +22,7 @@
     ChatRequest,
 )
 # Import constants needed by fixtures in this file
-from tests.shared_fixtures import (
+from shared_fixtures import (
     TEST_DB_USER,
     TEST_DB_PASSWORD,
     TEST_DB_DSN,
diff --git a/tests/unit/server/api/utils/__init__.py b/tests/unit/server/api/utils/__init__.py
deleted file mode 100644
index 9d9b7b29..00000000
--- a/tests/unit/server/api/utils/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Utils unit test package
diff --git a/tests/unit/server/api/utils/test_utils_databases.py b/tests/unit/server/api/utils/test_utils_databases.py
index 49ceaec4..5a2afb92 100644
--- a/tests/unit/server/api/utils/test_utils_databases.py
+++ b/tests/unit/server/api/utils/test_utils_databases.py
@@ -20,8 +20,8 @@
 from common.schema import DatabaseSettings
 from server.api.utils import databases as utils_databases
 from server.api.utils.databases import DbException, ExistsDatabaseError, UnknownDatabaseError
-from tests.db_fixtures import TEST_DB_CONFIG
-from tests.shared_fixtures import TEST_DB_WALLET_PASSWORD
+from db_fixtures import TEST_DB_CONFIG
+from shared_fixtures import TEST_DB_WALLET_PASSWORD


 class TestDbException:
diff --git a/tests/unit/server/api/utils/test_utils_mcp.py b/tests/unit/server/api/utils/test_utils_mcp.py
index ea81be6d..93169c25 100644
--- a/tests/unit/server/api/utils/test_utils_mcp.py
+++ b/tests/unit/server/api/utils/test_utils_mcp.py
@@ -12,7 +12,7 @@
 import pytest

 from server.api.utils import mcp
-from tests.shared_fixtures import TEST_API_KEY, TEST_API_KEY_ALT
+from shared_fixtures import TEST_API_KEY, TEST_API_KEY_ALT


 class TestGetClient:
diff --git a/tests/unit/server/api/utils/test_utils_webscrape.py b/tests/unit/server/api/utils/test_utils_webscrape.py
index f8dc5e72..33be5ad1 100644
--- a/tests/unit/server/api/utils/test_utils_webscrape.py
+++ b/tests/unit/server/api/utils/test_utils_webscrape.py
@@ -14,7 +14,7 @@
 from bs4 import BeautifulSoup

 from server.api.utils import webscrape
-from tests.unit.server.api.conftest import create_mock_aiohttp_session
+from unit.server.api.conftest import create_mock_aiohttp_session


 class TestNormalizeWs:
diff --git a/tests/unit/server/api/v1/__init__.py b/tests/unit/server/api/v1/__init__.py
deleted file mode 100644
index a6ad55f3..00000000
--- a/tests/unit/server/api/v1/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# v1 API unit test package
diff --git a/tests/unit/server/api/v1/test_v1_embed.py b/tests/unit/server/api/v1/test_v1_embed.py
index 855c8339..fc431f54 100644
--- a/tests/unit/server/api/v1/test_v1_embed.py
+++ b/tests/unit/server/api/v1/test_v1_embed.py
@@ -20,7 +20,7 @@
 from common.schema import DatabaseVectorStorage, VectorStoreRefreshRequest
 from server.api.v1 import embed
 from server.api.utils.databases import DbException
-from tests.unit.server.api.conftest import create_mock_aiohttp_session
+from unit.server.api.conftest import create_mock_aiohttp_session


 @pytest.fixture
diff --git a/tests/unit/server/api/v1/test_v1_mcp.py b/tests/unit/server/api/v1/test_v1_mcp.py
index 5da618ab..14461014 100644
--- a/tests/unit/server/api/v1/test_v1_mcp.py
+++ b/tests/unit/server/api/v1/test_v1_mcp.py
@@ -13,7 +13,7 @@
 import pytest

 from server.api.v1 import mcp
-from tests.shared_fixtures import TEST_API_KEY
+from shared_fixtures import TEST_API_KEY


 class TestGetMcp:
diff --git a/tests/unit/server/bootstrap/__init__.py b/tests/unit/server/bootstrap/__init__.py
deleted file mode 100644
index 170366b5..00000000
--- a/tests/unit/server/bootstrap/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Bootstrap unit test package
diff --git a/tests/unit/server/bootstrap/test_bootstrap_databases.py b/tests/unit/server/bootstrap/test_bootstrap_databases.py
index 0db5af42..f536aede 100644
--- a/tests/unit/server/bootstrap/test_bootstrap_databases.py
+++ b/tests/unit/server/bootstrap/test_bootstrap_databases.py
@@ -13,7 +13,7 @@
 import pytest

 from server.bootstrap import databases as databases_module
-from tests.shared_fixtures import (
+from shared_fixtures import (
     assert_database_list_valid,
     assert_has_default_database,
     get_database_by_name,
diff --git a/tests/unit/server/bootstrap/test_bootstrap_models.py b/tests/unit/server/bootstrap/test_bootstrap_models.py
index c638275c..8f8b56a7 100644
--- a/tests/unit/server/bootstrap/test_bootstrap_models.py
+++ b/tests/unit/server/bootstrap/test_bootstrap_models.py
@@ -14,7 +14,7 @@
 import pytest

 from server.bootstrap import models as models_module
-from tests.shared_fixtures import assert_model_list_valid, get_model_by_id, TEST_API_KEY
+from shared_fixtures import assert_model_list_valid, get_model_by_id, TEST_API_KEY


 @pytest.mark.usefixtures("reset_config_store", "clean_env", "mock_is_url_accessible")

From 636ff318013629f6557a068a2ebec5cefa3863f7 Mon Sep 17 00:00:00 2001
From: gotsysdba
Date: Sun, 30 Nov 2025 19:42:36 +0000
Subject: [PATCH 13/20] update pytest.ini

---
 pytest.ini | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytest.ini b/pytest.ini
index 583cdf9e..ea04c41f 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -5,6 +5,7 @@

 [pytest]
 pythonpath = src tests
+addopts = --disable-warnings --import-mode=importlib
 filterwarnings = ignore::DeprecationWarning
 asyncio_default_fixture_loop_scope = function

From 5cd9ae0a5660eba1b9852e4f3a439edc82739996 Mon Sep 17 00:00:00 2001
From: gotsysdba
Date: Mon, 1 Dec 2025 12:23:55 +0000
Subject: [PATCH 14/20] Updated tests

---
 src/client/content/config/tabs/settings.py | 20 +-
 src/client/content/testbed.py | 13 +-
 .../client/content/config/tabs/test_models.py | 125 +--
 .../content/config/tabs/test_settings.py | 584 ++------------
 .../content/tools/tabs/test_split_embed.py | 104 ---
 tests/integration/common/test_functions.py | 2 +-
 tests/integration/server/api/conftest.py | 4 +-
 .../bootstrap/test_bootstrap_databases.py | 4 +-
 .../server/bootstrap/test_bootstrap_models.py | 2 +-
 tests/shared_fixtures.py | 35 +
 .../content/config/tabs/test_models_unit.py | 117 +++
 .../content/config/tabs/test_settings_unit.py | 572 ++++++++++++++
 .../content/test_testbed_records_unit.py | 479 ++++++++++++
 .../client/content/test_testbed_ui_unit.py | 174 +++++
 .../unit/client/content/test_testbed_unit.py | 712 +++--------------
 .../tools/tabs/test_split_embed_unit.py | 36 +
 tests/unit/server/api/conftest.py | 12 +-
 .../unit/server/api/utils/test_utils_chat.py | 12 -
 .../server/api/utils/test_utils_databases.py | 16 +-
 .../unit/server/api/utils/test_utils_embed.py | 12 -
 tests/unit/server/api/utils/test_utils_mcp.py | 14 +-
 .../server/api/utils/test_utils_models.py | 12 -
 .../api/utils/test_utils_module_config.py | 47 ++
 tests/unit/server/api/utils/test_utils_oci.py | 12 -
 .../server/api/utils/test_utils_settings.py | 12 -
 .../server/api/utils/test_utils_testbed.py | 12 -
 .../server/api/utils/test_utils_webscrape.py | 2 +-
 tests/unit/server/api/v1/test_v1_chat.py | 28 -
 tests/unit/server/api/v1/test_v1_databases.py | 27 -
 tests/unit/server/api/v1/test_v1_embed.py | 35 +-
 tests/unit/server/api/v1/test_v1_mcp.py | 30 +-
 .../unit/server/api/v1/test_v1_mcp_prompts.py | 27 -
 tests/unit/server/api/v1/test_v1_models.py | 28 -
 .../server/api/v1/test_v1_module_config.py | 100 +++
 tests/unit/server/api/v1/test_v1_oci.py | 30 -
 tests/unit/server/api/v1/test_v1_settings.py | 28 -
 tests/unit/server/api/v1/test_v1_testbed.py | 12 -
 .../bootstrap/test_bootstrap_bootstrap.py | 12 -
 .../bootstrap/test_bootstrap_configfile.py | 13 -
 .../bootstrap/test_bootstrap_databases.py | 16 +-
 .../server/bootstrap/test_bootstrap_models.py | 14 +-
 .../bootstrap/test_bootstrap_module_config.py | 43 ++
 .../server/bootstrap/test_bootstrap_oci.py | 12 -
 .../bootstrap/test_bootstrap_settings.py | 12 -
 44 files changed, 1813 insertions(+), 1800 deletions(-)
 create mode 100644 tests/unit/client/content/config/tabs/test_settings_unit.py
 create mode 100644 tests/unit/client/content/test_testbed_records_unit.py
 create mode 100644 tests/unit/client/content/test_testbed_ui_unit.py
 create mode 100644 tests/unit/server/api/utils/test_utils_module_config.py
 create mode 100644 tests/unit/server/api/v1/test_v1_module_config.py
 create mode 100644 tests/unit/server/bootstrap/test_bootstrap_module_config.py

diff --git a/src/client/content/config/tabs/settings.py b/src/client/content/config/tabs/settings.py
index 69d7cb18..53dac03d 100644
--- a/src/client/content/config/tabs/settings.py
+++ b/src/client/content/config/tabs/settings.py
@@ -363,17 +363,23 @@ def spring_ai_conf_check(ll_model: dict, embed_model: dict) -> str:

 def spring_ai_obaas(src_dir, file_name, provider, ll_config, embed_config):
     """Get the system prompt for SpringAI export"""
+    ## FUTURE FEATURE:
     # Determine which system prompt would be active based on tools_enabled
-    tools_enabled = state.client_settings.get("tools_enabled", [])
+    # tools_enabled = state.client_settings.get("tools_enabled", [])

     # Select prompt name based on tools configuration
-    if not tools_enabled:
-        prompt_name = "optimizer_basic-default"
-        if state.client_settings["vector_search"]["enabled"]:
-            prompt_name = "optimizer_vs-no-tools-default"
+    # if not tools_enabled:
+    #     prompt_name = "optimizer_basic-default"
+    #     if state.client_settings["vector_search"]["enabled"]:
+    #         prompt_name = "optimizer_vs-no-tools-default"
+    # else:
+    #     # Tools are enabled, use tools-default prompt
+    #     prompt_name = "optimizer_tools-default"
+    ## Legacy Feature:
+    if "Vector Search" in state.client_settings.get("tools_enabled", []):
+        prompt_name = "optimizer_vs-no-tools-default"
     else:
-        # Tools are enabled, use tools-default prompt
-        prompt_name = "optimizer_tools-default"
+        prompt_name = "optimizer_basic-default"

     # Find the prompt in configs
     sys_prompt_obj = next((item for item in state.prompt_configs if item["name"] == prompt_name), None)
diff --git a/src/client/content/testbed.py b/src/client/content/testbed.py
index 43bd1aa8..a716980b 100644
--- a/src/client/content/testbed.py
+++ b/src/client/content/testbed.py
@@ -85,9 +85,12 @@ def create_gauge(value):
         st.dataframe(ll_settings_reversed, hide_index=True)
         if report["settings"]["testbed"]["judge_model"]:
             st.markdown(f"**Judge Model**: {report['settings']['testbed']['judge_model']}")
-        # if discovery; then list out the tables that were discovered (MCP implementation)
-        # if report["settings"]["vector_search"].get("discovery"):
-        if report["settings"]["vector_search"]["enabled"]:
+        # Backward compatibility
+        try:
+            vs_enabled = report["settings"]["vector_search"]["enabled"]
+        except KeyError:
+            vs_enabled = "Vector Search" in report["settings"]["tools_enabled"]
+        if vs_enabled:
             st.subheader("Vector Search Settings")
             st.markdown(f"""**Database**: {report["settings"]["database"]["alias"]};
                 **Vector Store**: {report["settings"]["vector_search"]["vector_store"]}
@@ -100,6 +103,8 @@ def create_gauge(value):
             if report["settings"]["vector_search"]["search_type"] == "Similarity":
                 embed_settings.drop(["score_threshold", "fetch_k", "lambda_mult"], axis=1, inplace=True)
             st.dataframe(embed_settings, hide_index=True)
+            # if discovery; then list out the tables that were discovered (MCP implementation)
+            # if report["settings"]["vector_search"].get("discovery"):
         else:
             st.markdown("**Evaluated without Vector Search**")

@@ -516,7 +521,7 @@ def render_evaluation_ui(available_ll_models: list) -> None:
             key="evaluate_button",
             help="Evaluation will automatically save the TestSet to the Database",
             on_click=qa_update_db,
-            disabled=not state.enable_client,
+            disabled=not state.get("enable_client", True),
         ):
             with st.spinner("Starting Q&A evaluation... please be patient.", show_time=True):
                 st_common.clear_state_key("testbed_evaluations")
diff --git a/tests/integration/client/content/config/tabs/test_models.py b/tests/integration/client/content/config/tabs/test_models.py
index 5319c77b..82284b1a 100644
--- a/tests/integration/client/content/config/tabs/test_models.py
+++ b/tests/integration/client/content/config/tabs/test_models.py
@@ -5,10 +5,7 @@
 """
 # spell-checker: disable

-import os
-from unittest.mock import MagicMock, patch
-
-from integration.client.conftest import temporary_sys_path
+from unittest.mock import patch

 # Streamlit File
 ST_FILE = "../src/client/content/config/tabs/models.py"
@@ -491,123 +488,3 @@ def test_render_api_configuration_uses_litellm_default_when_no_saved_value(self,
         else:
             # If no model has api_base, it should be empty string
             assert result_model["api_base"] == ""
-
-
-#############################################################################
-# Test Model CRUD Operations
-#############################################################################
-class TestModelCRUD:
-    """Test model create/patch/delete operations"""
-
-    def test_create_model_success(self, monkeypatch):
-        """Test creating a new model"""
-        with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../../../../src")):
-            from client.content.config.tabs import models
-            from client.utils import api_call
-            import streamlit as st
-
-            # Setup test model
-            test_model = {
-                "id": "new-model",
-                "provider": "openai",
-                "type": "ll",
-                "enabled": True,
-            }
-
-            # Mock API call
-            mock_post = MagicMock()
-            monkeypatch.setattr(api_call, "post", mock_post)
-
-            # Mock st.success
-            mock_success = MagicMock()
-            monkeypatch.setattr(st, "success", mock_success)
-
-            # Call create_model
-            models.create_model(test_model)
-
-            # Verify API was called
-            mock_post.assert_called_once()
-            assert mock_success.called
-
-    def test_patch_model_success(self, monkeypatch):
-        """Test patching an existing model"""
-        with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../../../../src")):
-            from client.content.config.tabs import models
-            from client.utils import api_call
-            import streamlit as st
-            from streamlit import session_state as state
-
-            # Setup test model
-            test_model = {
-                "id": "existing-model",
-                "provider": "openai",
-                "type": "ll",
-                "enabled": False,
-            }
-
-            # Setup state with client settings
-            state.client_settings = {
-                "ll_model": {"model": "openai/existing-model"},
-                "testbed": {
-                    "judge_model": None,
-                    "qa_ll_model": None,
-                    "qa_embed_model": None,
-                },
-            }
-
-            # Mock API call
-            mock_patch = MagicMock()
-            monkeypatch.setattr(api_call, "patch", mock_patch)
-
-            # Mock st.success
-            mock_success = MagicMock()
-            monkeypatch.setattr(st, "success", mock_success)
-
-            # Call patch_model
-            models.patch_model(test_model)
-
-            # Verify API was called
-            mock_patch.assert_called_once()
-            assert mock_success.called
-
-            # Verify model was cleared from client settings since it was disabled
-            assert state.client_settings["ll_model"]["model"] is None
-
-    def test_delete_model_success(self, monkeypatch):
-        """Test deleting a model"""
-        with temporary_sys_path(os.path.join(os.path.dirname(__file__), "../../../../../../src")):
-            from client.content.config.tabs import models
-            from client.utils import api_call
-            import streamlit as st
-            from streamlit import session_state as state
-
-            # Setup state with client settings
-            state.client_settings = {
-                "ll_model": {"model": "openai/test-model"},
-                "testbed": {
-                    "judge_model": None,
-                    "qa_ll_model": None,
-                    "qa_embed_model": None,
-                },
-            }
-
-            # Mock API call
-            mock_delete = MagicMock()
-            monkeypatch.setattr(api_call, "delete", mock_delete)
-
-            # Mock st.success
-            mock_success = MagicMock()
-            monkeypatch.setattr(st, "success", mock_success)
-
-            # Mock sleep to speed up test
-            monkeypatch.setattr("time.sleep", MagicMock())
-
-            # Call delete_model
-            models.delete_model("openai", "test-model")
-
-            # Verify API was called
-            mock_delete.assert_called_once_with(endpoint="v1/models/openai/test-model")
-            assert mock_success.called
-
-            # Verify model was cleared from client settings
-            assert state.client_settings["ll_model"]["model"] is None
diff --git a/tests/integration/client/content/config/tabs/test_settings.py b/tests/integration/client/content/config/tabs/test_settings.py
index ffbd2a8a..ced9c771 100644
--- a/tests/integration/client/content/config/tabs/test_settings.py
+++ b/tests/integration/client/content/config/tabs/test_settings.py
@@ -2,16 +2,20 @@
 """
 Copyright (c) 2024, 2025, Oracle and/or its affiliates.
 Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+
+Integration tests for settings.py that require the actual Streamlit app running.
+These tests use app_test fixture to interact with real session state.
+
+Note: Pure function tests (compare_settings, spring_ai_conf_check, save_settings)
+and mock-heavy tests are in tests/unit/client/content/config/tabs/test_settings_unit.py
 """
 # spell-checker: disable

 import json
-import zipfile
-from pathlib import Path
-from types import SimpleNamespace
-from unittest.mock import patch, MagicMock, mock_open
+from unittest.mock import patch, MagicMock

 import pytest
+from shared_fixtures import call_spring_ai_obaas_with_mocks

 # Streamlit File
 ST_FILE = "../src/client/content/config/tabs/settings.py"
@@ -181,10 +185,10 @@ def test_basic_configuration(self, app_server, app_test):


 #############################################################################
-# Test Functions Directly
+# Test Get/Save Settings with Real State
 #############################################################################
 class TestSettingsGetSave:
-    """Test get_settings and save_settings functions"""
+    """Test get_settings and save_settings functions with real app state"""

     def _setup_get_settings_test(self, app_test, run_app=True):
         """Helper method to set up common test configuration for get_settings tests"""
@@ -223,33 +227,6 @@ def test_get_settings_other_api_error_raises(self, app_server, app_test):
             result = get_settings()
             assert result is not None

-    def test_save_settings(self):
-        """Test save_settings function"""
-        from client.content.config.tabs.settings import save_settings
-
-        test_settings = {"client_settings": {"client": "old-client"}, "other": "data"}
-
-        with patch("client.content.config.tabs.settings.datetime") as mock_datetime:
-            mock_now = MagicMock()
-            mock_now.strftime.return_value = "25-SEP-2024T1430"
-            mock_datetime.now.return_value = mock_now
-
-            result = save_settings(test_settings)
-            result_dict = json.loads(result)
-
-            assert result_dict["client_settings"]["client"] == "25-SEP-2024T1430"
-            assert result_dict["other"] == "data"
-
-    def test_save_settings_no_client_settings(self):
-        """Test save_settings with no client_settings"""
-        from client.content.config.tabs.settings import save_settings
-
-        test_settings = {"other": "data"}
-        result = save_settings(test_settings)
-        result_dict = json.loads(result)
-
-        assert result_dict == {"other": "data"}
-
     def test_apply_uploaded_settings_success(self, app_server, app_test):
         """Test apply_uploaded_settings with successful API call"""
         from client.content.config.tabs.settings import apply_uploaded_settings
@@ -278,339 +255,6 @@ def test_apply_uploaded_settings_api_error(self, app_server, app_test):
             apply_uploaded_settings(uploaded_settings)
             # Just verify it handles errors gracefully
-
-
-#############################################################################
-# Test Spring AI Configuration Functions
-#############################################################################
-class TestSpringAIFunctions:
-    """Test Spring AI configuration and export functions"""
-
-    def _create_mock_session_state(self):
-        """Helper method to create mock session state for spring_ai tests"""
-        return SimpleNamespace(
-            client_settings={
-                "client": "test-client",
-                "database": {"alias": "DEFAULT"},
-                "vector_search": {"enabled": False},
-            },
-            prompt_configs=[
-                {
-                    "name": "optimizer_basic-default",
-                    "title": "Basic Example",
-                    "description": "Basic default prompt",
-                    "tags": [],
-                    "text": "You are a helpful assistant.",
-                }
-            ],
-            database_configs=[{"name": "DEFAULT", "user": "test_user", "password": "test_pass"}],
-        )
-
-    def _setup_get_settings_test(self, app_test, run_app=True):
-        """Helper method to set up common test configuration for get_settings tests"""
-        from client.content.config.tabs.settings import get_settings
-
-        at = app_test(ST_FILE)
-
-        at.session_state.client_settings = {
-            "client": "test-client",
-            "ll_model": {"id": "gpt-4o-mini"},
-            "embed_model": {"id": "text-embedding-3-small"},
-            "database": {"alias": "DEFAULT"},
-            "sys_prompt": {"name": "optimizer_basic-default"},
-            "ctx_prompt": {"name": "optimizer_no-examples"},
-            "vector_search": {"enabled": False},
-        }
-        at.session_state.prompt_configs = [
-            {
-                "name": "optimizer_basic-default",
-                "title": "Basic Example",
-                "description": "Basic default prompt",
-                "tags": [],
-                "default_text": "You are a helpful assistant.",
-                "override_text": None,
-            }
-        ]
-        at.session_state.database_configs = [{"name": "DEFAULT", "user": "test_user", "password": "test_pass"}]
-
-        if run_app:
-            at.run()
-        return get_settings, at
-
-    def test_spring_ai_conf_check_openai(self):
-        """Test spring_ai_conf_check with OpenAI models"""
-        from client.content.config.tabs.settings import spring_ai_conf_check
-
-        ll_model = {"provider": "openai"}
-        embed_model = {"provider": "openai"}
-
-        result = spring_ai_conf_check(ll_model, embed_model)
-        assert result == "openai"
-
-    def test_spring_ai_conf_check_ollama(self):
-        """Test spring_ai_conf_check with Ollama models"""
-        from client.content.config.tabs.settings import spring_ai_conf_check
-
-        ll_model = {"provider": "ollama"}
-        embed_model = {"provider": "ollama"}
-
-        result = spring_ai_conf_check(ll_model, embed_model)
-        assert result == "ollama"
-
-    def test_spring_ai_conf_check_hosted_vllm(self):
-        """Test spring_ai_conf_check with hosted vLLM models"""
-        from client.content.config.tabs.settings import spring_ai_conf_check
-
-        ll_model = {"provider": "hosted_vllm"}
-        embed_model = {"provider": "hosted_vllm"}
-
-        result = spring_ai_conf_check(ll_model, embed_model)
-        assert result == "hosted_vllm"
-
-    def test_spring_ai_conf_check_hybrid(self):
-        """Test spring_ai_conf_check with mixed providers"""
-        from client.content.config.tabs.settings import spring_ai_conf_check
-
-        ll_model = {"provider": "openai"}
-        embed_model = {"provider": "ollama"}
-
-        result = spring_ai_conf_check(ll_model, embed_model)
-        assert result == "hybrid"
-
-    def test_spring_ai_conf_check_empty_models(self):
-        """Test spring_ai_conf_check with empty models"""
-        from client.content.config.tabs.settings import spring_ai_conf_check
-
-        result = spring_ai_conf_check(None, None)
-        assert result == "hybrid"
-
-        result = spring_ai_conf_check({}, {})
-        assert result == "hybrid"
-
-    def test_spring_ai_obaas_shell_template(self):
-        """Test spring_ai_obaas function with shell template"""
-        from client.content.config.tabs.settings import spring_ai_obaas
-
-        mock_session_state = self._create_mock_session_state()
-        mock_template_content = (
-            "Provider: {provider}\nPrompt: {sys_prompt}\n"
-            "LLM: {ll_model}\nEmbed: {vector_search}\nDB: {database_config}"
-        )
-
-        with patch("client.content.config.tabs.settings.state", mock_session_state):
-            with patch("client.content.config.tabs.settings.st_common.state_configs_lookup") as mock_lookup:
-                with patch("builtins.open", mock_open(read_data=mock_template_content)):
-                    mock_lookup.return_value = {"DEFAULT": {"user": "test_user"}}
-
-                    src_dir = Path("/test/path")
-                    result = spring_ai_obaas(
-                        src_dir, "start.sh", "openai", {"model": "gpt-4"}, {"model": "text-embedding-ada-002"}
-                    )
-
-                    assert "Provider: openai" in result
-                    assert "You are a helpful assistant." in result
-                    assert "{'model': 'gpt-4'}" in result
-
-    def test_spring_ai_obaas_non_yaml_file(self):
-        """Test spring_ai_obaas with non-YAML file"""
-        from client.content.config.tabs.settings import spring_ai_obaas
-
-        mock_state = SimpleNamespace(
-            client_settings={
-                "database": {"alias": "DEFAULT"},
-                "vector_search": {"enabled": False},
-            },
-            prompt_configs=[
-                {
-                    "name": "optimizer_basic-default",
-                    "title": "Basic Example",
-                    "description": "Basic default prompt",
-                    "tags": [],
-                    "text": "You are a helpful assistant.",
-                }
-            ],
-        )
-        mock_template_content = (
-            "Provider: {provider}\nPrompt: {sys_prompt}\nLLM: {ll_model}\n"
-            "Embed: {vector_search}\nDB: {database_config}"
-        )
-
-        with patch("client.content.config.tabs.settings.state", mock_state):
-            with patch("client.content.config.tabs.settings.st_common.state_configs_lookup") as mock_lookup:
-                with patch("builtins.open", mock_open(read_data=mock_template_content)):
-                    mock_lookup.return_value = {"DEFAULT": {"user": "test_user"}}
-
-                    src_dir = Path("/test/path")
-                    result = spring_ai_obaas(
-                        src_dir, "start.sh", "openai", {"model": "gpt-4"}, {"model": "text-embedding-ada-002"}
-                    )
-
-                    assert "Provider: openai" in result
-                    assert "You are a helpful assistant." in result
-                    assert "{'model': 'gpt-4'}" in result
-
-    def test_spring_ai_zip_creation(self):
-        """Test spring_ai_zip function creates proper ZIP file"""
-        from client.content.config.tabs.settings import spring_ai_zip
-
-        mock_session_state = self._create_mock_session_state()
-        with patch("client.content.config.tabs.settings.state", mock_session_state):
-            with patch("client.content.config.tabs.settings.st_common.state_configs_lookup") as mock_lookup:
-                with patch("client.content.config.tabs.settings.shutil.copytree"):
-                    with patch("client.content.config.tabs.settings.shutil.copy"):
-                        with patch("client.content.config.tabs.settings.spring_ai_obaas") as mock_obaas:
-                            mock_lookup.return_value = {"DEFAULT": {"user": "test_user"}}
-                            mock_obaas.return_value = "mock content"
-
-                            result = spring_ai_zip("openai", {"model": "gpt-4"}, {"model": "text-embedding-ada-002"})
-
-                            # Verify it's a valid BytesIO object
-                            assert hasattr(result, "read")
-                            assert hasattr(result, "seek")
-
-                            # Verify ZIP content
-                            result.seek(0)
-                            with zipfile.ZipFile(result, "r") as zip_file:
-                                files = zip_file.namelist()
-                                assert "start.sh" in files
-                                assert "src/main/resources/application-obaas.yml" in files
-
-    def test_langchain_mcp_zip_creation(self):
-        """Test langchain_mcp_zip function creates proper ZIP file"""
-        from client.content.config.tabs.settings import langchain_mcp_zip
-
-        test_settings = {"test": "config"}
-
-        with patch("client.content.config.tabs.settings.shutil.copytree"):
-            with patch("client.content.config.tabs.settings.save_settings") as mock_save:
-                with patch("builtins.open", mock_open()):
-                    mock_save.return_value = '{"test": "config"}'
-
-                    result = langchain_mcp_zip(test_settings)
-
-                    # Verify it's a valid BytesIO object
-                    assert hasattr(result, "read")
-                    assert hasattr(result, "seek")
-
-                    # Verify save_settings was called
-                    mock_save.assert_called_once_with(test_settings)
-
-    def test_compare_settings_comprehensive(self):
-        """Test compare_settings function with comprehensive scenarios"""
-        from client.content.config.tabs.settings import compare_settings
-
-        current = {
-            "shared": {"value": "same"},
-            "current_only": {"value": "current"},
-            "different": {"value": "current_val"},
-            "api_key": "current_key",
-            "nested": {"shared": "same", "different": "current_nested"},
-            "list_field": ["a", "b", "c"],
-        }
-
-        uploaded = {
-            "shared": {"value": "same"},
-            "uploaded_only": {"value": "uploaded"},
-            "different": {"value": "uploaded_val"},
-            "api_key": "uploaded_key",
-            "password": "uploaded_pass",
-            "nested": {"shared": "same", "different": "uploaded_nested", "new_field": "new"},
-            "list_field": ["a", "b", "d", "e"],
-        }
-
-        differences = compare_settings(current, uploaded)
-
-        # Check value mismatches
-        assert "different.value" in differences["Value Mismatch"]
-        assert "nested.different" in differences["Value Mismatch"]
-        assert "api_key" in differences["Value Mismatch"]
-
-        # Check missing fields
-        assert "current_only" in differences["Missing in Uploaded"]
-        assert "nested.new_field" in differences["Missing in Current"]
-
-        # Check sensitive key handling
-        assert "password" in differences["Override on Upload"]
-
-        # Check list handling
-        assert "list_field[2]" in differences["Value Mismatch"]
-        assert "list_field[3]" in differences["Missing in Current"]
-
-    def test_compare_settings_client_skip(self):
-        """Test compare_settings skips client_settings.client path"""
-        from client.content.config.tabs.settings import compare_settings
-
-        current = {"client_settings": {"client": "current_client"}}
uploaded = {"client_settings": {"client": "uploaded_client"}} - - differences = compare_settings(current, uploaded) - - # Should be empty since client_settings.client is skipped - assert all(not diff_dict for diff_dict in differences.values()) - - def test_compare_settings_sensitive_key_handling(self): - """Test compare_settings handles sensitive keys correctly""" - from client.content.config.tabs.settings import compare_settings - - current = {"api_key": "current_key", "password": "current_pass", "normal_field": "current_val"} - - uploaded = {"api_key": "uploaded_key", "wallet_password": "uploaded_wallet", "normal_field": "uploaded_val"} - - differences = compare_settings(current, uploaded) - - # Sensitive keys should be in Value Mismatch - assert "api_key" in differences["Value Mismatch"] - - # New sensitive keys should be in Override on Upload - assert "wallet_password" in differences["Override on Upload"] - - # Normal fields should be in Value Mismatch - assert "normal_field" in differences["Value Mismatch"] - - # Current-only sensitive key should be silently updated (not in Missing in Uploaded) - assert "password" not in differences["Missing in Uploaded"] - - def test_spring_ai_obaas_error_handling(self): - """Test spring_ai_obaas function error handling""" - from client.content.config.tabs.settings import spring_ai_obaas - - mock_session_state = self._create_mock_session_state() - with patch("client.content.config.tabs.settings.state", mock_session_state): - with patch("client.content.config.tabs.settings.st_common.state_configs_lookup") as mock_lookup: - mock_lookup.return_value = {"DEFAULT": {"user": "test_user"}} - - # Test file not found - with patch("builtins.open", side_effect=FileNotFoundError("File not found")): - with pytest.raises(FileNotFoundError): - spring_ai_obaas( - Path("/test/path"), - "missing.sh", - "openai", - {"model": "gpt-4"}, - {"model": "text-embedding-ada-002"}, - ) - - def test_spring_ai_obaas_yaml_parsing_error(self): - """Test spring_ai_obaas YAML parsing error handling""" - from client.content.config.tabs.settings import spring_ai_obaas - - mock_session_state = self._create_mock_session_state() - invalid_yaml = "invalid: yaml: content: [" - - with patch("client.content.config.tabs.settings.state", mock_session_state): - with patch("client.content.config.tabs.settings.st_common.state_configs_lookup") as mock_lookup: - with patch("builtins.open", mock_open(read_data=invalid_yaml)): - mock_lookup.return_value = {"DEFAULT": {"user": "test_user"}} - - # Should handle YAML parsing errors gracefully - with pytest.raises(Exception): # Could be yaml.YAMLError or similar - spring_ai_obaas( - Path("/test/path"), - "invalid.yaml", - "openai", - {"model": "gpt-4"}, - {"model": "text-embedding-ada-002"}, - ) - def test_get_settings_default_parameters(self, app_server, app_test): """Test get_settings with default parameters""" assert app_server is not None @@ -620,193 +264,65 @@ def test_get_settings_default_parameters(self, app_server, app_test): result = get_settings() # No parameters assert result is not None - def test_save_settings_with_nested_client_settings(self): - """Test save_settings with nested client_settings structure""" - from client.content.config.tabs.settings import save_settings - - test_settings = { - "client_settings": {"client": "old-client", "nested": {"value": "test"}}, - "other_settings": {"value": "unchanged"}, - } - - with patch("client.content.config.tabs.settings.datetime") as mock_datetime: - mock_now = MagicMock() - 
mock_now.strftime.return_value = "26-SEP-2024T0900" - mock_datetime.now.return_value = mock_now - - result = save_settings(test_settings) - result_dict = json.loads(result) - - # Client should be updated - assert result_dict["client_settings"]["client"] == "26-SEP-2024T0900" - # Nested values should be preserved - assert result_dict["client_settings"]["nested"]["value"] == "test" - # Other settings should be unchanged - assert result_dict["other_settings"]["value"] == "unchanged" - ############################################################################# -# Test Compare Settings Functions +# Test Spring AI Functions with Real State ############################################################################# -class TestCompareSettingsFunctions: - """Test compare_settings utility function""" - - def test_compare_settings_with_none_values(self): - """Test compare_settings with None values""" - from client.content.config.tabs.settings import compare_settings - - current = {"field1": None, "field2": "value"} - uploaded = {"field1": "value", "field2": None} - - differences = compare_settings(current, uploaded) +class TestSpringAIIntegration: + """Integration tests for Spring AI functions using real app state""" - assert "field1" in differences["Value Mismatch"] - assert "field2" in differences["Value Mismatch"] - - def test_compare_settings_empty_structures(self): - """Test compare_settings with empty structures""" - from client.content.config.tabs.settings import compare_settings - - # Test empty dictionaries - differences = compare_settings({}, {}) - assert all(not diff_dict for diff_dict in differences.values()) - - # Test empty lists - differences = compare_settings([], []) - assert all(not diff_dict for diff_dict in differences.values()) - - # Test mixed empty structures - current = {"empty_dict": {}, "empty_list": []} - uploaded = {"empty_dict": {}, "empty_list": []} - differences = compare_settings(current, uploaded) - assert all(not diff_dict for diff_dict in differences.values()) - - def test_compare_settings_ignores_created_timestamps(self): - """Test compare_settings ignores 'created' timestamp fields""" - from client.content.config.tabs.settings import compare_settings - - current = { - "model_configs": [ - {"id": "gpt-4", "created": 1758808962, "model": "gpt-4"}, - {"id": "gpt-3.5", "created": 1758808962, "model": "gpt-3.5-turbo"}, - ], - "client_settings": {"ll_model": {"model": "openai/gpt-4o-mini"}}, - } - - uploaded = { - "model_configs": [ - {"id": "gpt-4", "created": 1758808458, "model": "gpt-4"}, - {"id": "gpt-3.5", "created": 1758808458, "model": "gpt-3.5-turbo"}, - ], - "client_settings": {"ll_model": {"model": None}}, - } - - differences = compare_settings(current, uploaded) - - # 'created' fields should not appear in differences - assert "model_configs[0].created" not in differences["Value Mismatch"] - assert "model_configs[1].created" not in differences["Value Mismatch"] - - # But other fields should still be compared - assert "client_settings.ll_model.model" in differences["Value Mismatch"] - - def test_compare_settings_ignores_nested_created_fields(self): - """Test compare_settings ignores deeply nested 'created' fields""" - from client.content.config.tabs.settings import compare_settings - - current = { - "nested": { - "config": {"created": 123456789, "value": "current"}, - "another": {"created": 987654321, "setting": "test"}, - } - } - - uploaded = { - "nested": { - "config": {"created": 111111111, "value": "current"}, - "another": {"created": 222222222, 
"setting": "changed"}, - } - } - - differences = compare_settings(current, uploaded) - - # 'created' fields should be ignored - assert "nested.config.created" not in differences["Value Mismatch"] - assert "nested.another.created" not in differences["Value Mismatch"] - - # But actual value differences should be detected - assert "nested.another.setting" in differences["Value Mismatch"] - assert differences["Value Mismatch"]["nested.another.setting"]["current"] == "test" - assert differences["Value Mismatch"]["nested.another.setting"]["uploaded"] == "changed" + def test_spring_ai_obaas_with_real_state_basic(self, app_server, app_test): + """Test spring_ai_obaas uses basic prompt with real state when Vector Search not in tools_enabled""" + from client.content.config.tabs.settings import spring_ai_obaas - def test_compare_settings_ignores_created_in_lists(self): - """Test compare_settings ignores 'created' fields within list items""" - from client.content.config.tabs.settings import compare_settings + assert app_server is not None + at = app_test(ST_FILE).run() - current = { - "items": [ - {"name": "item1", "created": 1111, "enabled": True}, - {"name": "item2", "created": 2222, "enabled": False}, - ] - } + # Set up state with tools_enabled NOT containing "Vector Search" + at.session_state.client_settings["tools_enabled"] = ["Other Tool"] + at.session_state.client_settings["database"] = {"alias": "DEFAULT"} - uploaded = { - "items": [ - {"name": "item1", "created": 9999, "enabled": True}, - {"name": "item2", "created": 8888, "enabled": True}, - ] - } + result = call_spring_ai_obaas_with_mocks(at.session_state, "Prompt: {sys_prompt}", spring_ai_obaas) - differences = compare_settings(current, uploaded) + # Should use basic prompt - result should not be None + assert result is not None - # 'created' fields should be ignored - assert "items[0].created" not in differences["Value Mismatch"] - assert "items[1].created" not in differences["Value Mismatch"] + def test_spring_ai_obaas_with_real_state_vector_search(self, app_server, app_test): + """Test spring_ai_obaas uses VS prompt with real state when Vector Search IS in tools_enabled""" + from client.content.config.tabs.settings import spring_ai_obaas - # But other field differences should be detected - assert "items[1].enabled" in differences["Value Mismatch"] - assert differences["Value Mismatch"]["items[1].enabled"]["current"] is False - assert differences["Value Mismatch"]["items[1].enabled"]["uploaded"] is True + assert app_server is not None + at = app_test(ST_FILE).run() - def test_compare_settings_mixed_created_and_regular_fields(self): - """Test compare_settings with a mix of 'created' and regular fields""" - from client.content.config.tabs.settings import compare_settings + # Set up state with tools_enabled containing "Vector Search" + at.session_state.client_settings["tools_enabled"] = ["Vector Search"] + at.session_state.client_settings["database"] = {"alias": "DEFAULT"} - current = { - "config": { - "created": 123456, - "modified": 789012, - "name": "current_config", - "settings": {"created": 345678, "value": "old_value"}, - } - } + result = call_spring_ai_obaas_with_mocks(at.session_state, "Prompt: {sys_prompt}", spring_ai_obaas) - uploaded = { - "config": { - "created": 999999, # Different created - should be ignored - "modified": 888888, # Different modified - should be detected - "name": "current_config", # Same name - no difference - "settings": { - "created": 777777, # Different created - should be ignored - "value": "new_value", # 
Different value - should be detected - }, - } - } + # Should use VS prompt - result should not be None + assert result is not None - differences = compare_settings(current, uploaded) + def test_spring_ai_obaas_tools_enabled_not_set(self, app_server, app_test): + """Test spring_ai_obaas handles missing tools_enabled gracefully""" + from client.content.config.tabs.settings import spring_ai_obaas - # 'created' fields should be ignored - assert "config.created" not in differences["Value Mismatch"] - assert "config.settings.created" not in differences["Value Mismatch"] + assert app_server is not None + at = app_test(ST_FILE).run() - # Regular field differences should be detected - assert "config.modified" in differences["Value Mismatch"] - assert "config.settings.value" in differences["Value Mismatch"] + # Remove tools_enabled if it exists + if "tools_enabled" in at.session_state.client_settings: + del at.session_state.client_settings["tools_enabled"] + at.session_state.client_settings["database"] = {"alias": "DEFAULT"} - # Same values should not appear in differences - assert "config.name" not in differences["Value Mismatch"] + # Should not raise - uses .get() with default empty list + result = call_spring_ai_obaas_with_mocks(at.session_state, "Prompt: {sys_prompt}", spring_ai_obaas) + assert result is not None +############################################################################# +# Test Prompt Config Upload with Real State +############################################################################# class TestPromptConfigUpload: """Test prompt configuration upload scenarios via Streamlit UI""" diff --git a/tests/integration/client/content/tools/tabs/test_split_embed.py b/tests/integration/client/content/tools/tabs/test_split_embed.py index 50d22d5e..82b00501 100644 --- a/tests/integration/client/content/tools/tabs/test_split_embed.py +++ b/tests/integration/client/content/tools/tabs/test_split_embed.py @@ -12,23 +12,6 @@ from integration.client.conftest import enable_test_embed_models -############################################################################# -# Test Helpers -############################################################################# -class MockState: - """Mock session state for testing OCI-related functionality""" - - def __init__(self): - self.client_settings = {"oci": {"auth_profile": "DEFAULT"}} - - def __getitem__(self, key): - return getattr(self, key) - - def get(self, key, default=None): - """Get method for dict-like access""" - return getattr(self, key, default) - - ############################################################################# # Test Streamlit UI ############################################################################# @@ -312,58 +295,6 @@ def test_file_source_radio_with_oke_workload_identity(self, app_server, app_test # OCI may or may not appear depending on namespace availability -############################################################################# -# Test Split & Embed Functions -############################################################################# -class TestSplitEmbedFunctions: - """Test individual functions from split_embed.py""" - - # Streamlit File path - ST_FILE = "../src/client/content/tools/tabs/split_embed.py" - - def test_get_buckets_success(self, monkeypatch): - """Test get_buckets function with successful API call""" - from client.content.tools.tabs.split_embed import get_buckets - - # Mock session state with proper attribute access - monkeypatch.setattr("client.content.tools.tabs.split_embed.state", 
MockState()) - - mock_buckets = ["bucket1", "bucket2", "bucket3"] - monkeypatch.setattr("client.utils.api_call.get", lambda endpoint: mock_buckets) - - result = get_buckets("test-compartment") - assert result == mock_buckets - - def test_get_buckets_api_error(self, monkeypatch): - """Test get_buckets function when API call fails""" - from client.content.tools.tabs.split_embed import get_buckets - from client.utils.api_call import ApiError - - # Mock session state with proper attribute access - monkeypatch.setattr("client.content.tools.tabs.split_embed.state", MockState()) - - def mock_get_with_error(endpoint): - raise ApiError("Access denied") - - monkeypatch.setattr("client.utils.api_call.get", mock_get_with_error) - - result = get_buckets("test-compartment") - assert result == ["No Access to Buckets in this Compartment"] - - def test_get_bucket_objects(self, monkeypatch): - """Test get_bucket_objects function""" - from client.content.tools.tabs.split_embed import get_bucket_objects - - # Mock session state with proper attribute access - monkeypatch.setattr("client.content.tools.tabs.split_embed.state", MockState()) - - mock_objects = ["file1.txt", "file2.pdf", "document.docx"] - monkeypatch.setattr("client.utils.api_call.get", lambda endpoint: mock_objects) - - result = get_bucket_objects("test-bucket") - assert result == mock_objects - - ############################################################################# # Test UI Components ############################################################################# @@ -670,38 +601,3 @@ def test_populate_button_shown_in_create_new_mode(self, app_server, app_test): # NOTE: This may not be present if embedding models aren't accessible # Just checking the button logic - verification happens implicitly via page load pass - - def test_get_compartments(self, monkeypatch): - """Test get_compartments function with successful API call""" - from client.content.tools.tabs.split_embed import get_compartments - - # Mock session state using module-level MockState - monkeypatch.setattr("client.content.tools.tabs.split_embed.state", MockState()) - - # Mock API response - def mock_get(**_kwargs): - return {"comp1": "ocid1.compartment.oc1..test1", "comp2": "ocid1.compartment.oc1..test2"} - - monkeypatch.setattr("client.utils.api_call.get", mock_get) - - result = get_compartments() - assert isinstance(result, dict) - assert len(result) == 2 - assert "comp1" in result - - def test_files_data_editor(self, monkeypatch): - """Test files_data_editor function""" - from client.content.tools.tabs.split_embed import files_data_editor - - # Create test dataframe - test_df = pd.DataFrame({"File": ["file1.txt", "file2.txt"], "Process": [True, False]}) - - # Mock st.data_editor - def mock_data_editor(data, **_kwargs): - return data - - monkeypatch.setattr("streamlit.data_editor", mock_data_editor) - - result = files_data_editor(test_df, key="test_key") - assert isinstance(result, pd.DataFrame) - assert len(result) == 2 diff --git a/tests/integration/common/test_functions.py b/tests/integration/common/test_functions.py index 167ccf16..05c2a10b 100644 --- a/tests/integration/common/test_functions.py +++ b/tests/integration/common/test_functions.py @@ -12,9 +12,9 @@ import tempfile import pytest +from db_fixtures import TEST_DB_CONFIG from common import functions -from db_fixtures import TEST_DB_CONFIG class TestIsUrlAccessibleIntegration: diff --git a/tests/integration/server/api/conftest.py b/tests/integration/server/api/conftest.py index e8814b38..43ece217 100644 --- 
a/tests/integration/server/api/conftest.py +++ b/tests/integration/server/api/conftest.py @@ -25,8 +25,6 @@ import numpy as np import pytest from fastapi.testclient import TestClient - -from server.bootstrap.bootstrap import DATABASE_OBJECTS, MODEL_OBJECTS, SETTINGS_OBJECTS # Import constants and helpers needed by fixtures in this file from db_fixtures import TEST_DB_CONFIG from shared_fixtures import ( @@ -38,6 +36,8 @@ restore_env_state, ) +from server.bootstrap.bootstrap import DATABASE_OBJECTS, MODEL_OBJECTS, SETTINGS_OBJECTS + # Test configuration - extends shared DB config with integration-specific settings TEST_CONFIG = { "client": "integration_test", diff --git a/tests/integration/server/bootstrap/test_bootstrap_databases.py b/tests/integration/server/bootstrap/test_bootstrap_databases.py index 32f6492d..63fcca60 100644 --- a/tests/integration/server/bootstrap/test_bootstrap_databases.py +++ b/tests/integration/server/bootstrap/test_bootstrap_databases.py @@ -13,14 +13,14 @@ import os import pytest - -from server.bootstrap import databases as databases_module from shared_fixtures import ( assert_database_list_valid, assert_has_default_database, get_database_by_name, ) +from server.bootstrap import databases as databases_module + @pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") class TestDatabasesBootstrapWithConfig: diff --git a/tests/integration/server/bootstrap/test_bootstrap_models.py b/tests/integration/server/bootstrap/test_bootstrap_models.py index 49d63e23..8991c0e3 100644 --- a/tests/integration/server/bootstrap/test_bootstrap_models.py +++ b/tests/integration/server/bootstrap/test_bootstrap_models.py @@ -14,9 +14,9 @@ from unittest.mock import patch import pytest +from shared_fixtures import assert_model_list_valid, get_model_by_id from server.bootstrap import models as models_module -from shared_fixtures import assert_model_list_valid, get_model_by_id @pytest.mark.usefixtures("reset_config_store", "clean_bootstrap_env") diff --git a/tests/shared_fixtures.py b/tests/shared_fixtures.py index 62f2aa53..5fc14500 100644 --- a/tests/shared_fixtures.py +++ b/tests/shared_fixtures.py @@ -593,6 +593,41 @@ def make_auth_headers(auth_token: str, client_id: str) -> dict: } +################################################# +# Spring AI Test Helpers +################################################# + + +def call_spring_ai_obaas_with_mocks(mock_state, template_content, spring_ai_obaas_func): + """Call spring_ai_obaas with standard mocking setup. + + This helper encapsulates the common patching pattern for spring_ai_obaas tests, + reducing code duplication between unit and integration tests. 
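+
+    Illustrative usage (a sketch; assumes spring_ai_obaas has been imported
+    from client.content.config.tabs.settings and mock_state mimics the
+    Streamlit session state, as in the tests that call this helper):
+
+        result = call_spring_ai_obaas_with_mocks(
+            mock_state, "Prompt: {sys_prompt}", spring_ai_obaas
+        )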
+ + Args: + mock_state: The state object to use (mock or real session_state) + template_content: The template file content to return from mock open + spring_ai_obaas_func: The spring_ai_obaas function to call + + Returns: + The result from calling spring_ai_obaas + """ + # pylint: disable=import-outside-toplevel + from unittest.mock import patch, mock_open + + with patch("client.content.config.tabs.settings.state", mock_state): + with patch("client.content.config.tabs.settings.st_common.state_configs_lookup") as mock_lookup: + with patch("builtins.open", mock_open(read_data=template_content)): + mock_lookup.return_value = {"DEFAULT": {"user": "test_user"}} + return spring_ai_obaas_func( + Path("/test/path"), + "start.sh", + "openai", + {"model": "gpt-4"}, + {"model": "text-embedding-ada-002"}, + ) + + ################################################# # Vector Store Test Data ################################################# diff --git a/tests/unit/client/content/config/tabs/test_models_unit.py b/tests/unit/client/content/config/tabs/test_models_unit.py index bc5736a6..b86c62e3 100644 --- a/tests/unit/client/content/config/tabs/test_models_unit.py +++ b/tests/unit/client/content/config/tabs/test_models_unit.py @@ -275,3 +275,120 @@ def test_clear_client_models_no_match(self): # Verify nothing was changed assert state.client_settings["ll_model"]["model"] == "openai/gpt-4" + + +############################################################################# +# Test Model CRUD Operations +############################################################################# +class TestModelCRUD: + """Test model create/patch/delete operations""" + + def test_create_model_success(self, monkeypatch): + """Test creating a new model""" + from client.content.config.tabs import models + from client.utils import api_call + import streamlit as st + + # Setup test model + test_model = { + "id": "new-model", + "provider": "openai", + "type": "ll", + "enabled": True, + } + + # Mock API call + mock_post = MagicMock() + monkeypatch.setattr(api_call, "post", mock_post) + + # Mock st.success + mock_success = MagicMock() + monkeypatch.setattr(st, "success", mock_success) + + # Call create_model + models.create_model(test_model) + + # Verify API was called + mock_post.assert_called_once() + assert mock_success.called + + def test_patch_model_success(self, monkeypatch): + """Test patching an existing model""" + from client.content.config.tabs import models + from client.utils import api_call + import streamlit as st + from streamlit import session_state as state + + # Setup test model + test_model = { + "id": "existing-model", + "provider": "openai", + "type": "ll", + "enabled": False, + } + + # Setup state with client settings + state.client_settings = { + "ll_model": {"model": "openai/existing-model"}, + "testbed": { + "judge_model": None, + "qa_ll_model": None, + "qa_embed_model": None, + }, + } + + # Mock API call + mock_patch = MagicMock() + monkeypatch.setattr(api_call, "patch", mock_patch) + + # Mock st.success + mock_success = MagicMock() + monkeypatch.setattr(st, "success", mock_success) + + # Call patch_model + models.patch_model(test_model) + + # Verify API was called + mock_patch.assert_called_once() + assert mock_success.called + + # Verify model was cleared from client settings since it was disabled + assert state.client_settings["ll_model"]["model"] is None + + def test_delete_model_success(self, monkeypatch): + """Test deleting a model""" + from client.content.config.tabs import models + from client.utils 
import api_call + import streamlit as st + from streamlit import session_state as state + + # Setup state with client settings + state.client_settings = { + "ll_model": {"model": "openai/test-model"}, + "testbed": { + "judge_model": None, + "qa_ll_model": None, + "qa_embed_model": None, + }, + } + + # Mock API call + mock_delete = MagicMock() + monkeypatch.setattr(api_call, "delete", mock_delete) + + # Mock st.success + mock_success = MagicMock() + monkeypatch.setattr(st, "success", mock_success) + + # Mock sleep to speed up test + monkeypatch.setattr("time.sleep", MagicMock()) + + # Call delete_model + models.delete_model("openai", "test-model") + + # Verify API was called + mock_delete.assert_called_once_with(endpoint="v1/models/openai/test-model") + assert mock_success.called + + # Verify model was cleared from client settings + assert state.client_settings["ll_model"]["model"] is None diff --git a/tests/unit/client/content/config/tabs/test_settings_unit.py b/tests/unit/client/content/config/tabs/test_settings_unit.py new file mode 100644 index 00000000..073eace1 --- /dev/null +++ b/tests/unit/client/content/config/tabs/test_settings_unit.py @@ -0,0 +1,572 @@ +# pylint: disable=protected-access,import-error,import-outside-toplevel +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for settings.py functions that don't require server integration. +These tests use mocks to isolate the functions under test. +""" +# spell-checker: disable + +import json +import zipfile +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import patch, MagicMock, mock_open + +import pytest +from shared_fixtures import call_spring_ai_obaas_with_mocks + + +############################################################################# +# Test Spring AI Configuration Check Function +############################################################################# +class TestSpringAIConfCheck: + """Test spring_ai_conf_check function - pure function tests""" + + def test_spring_ai_conf_check_openai(self): + """Test spring_ai_conf_check with OpenAI models""" + from client.content.config.tabs.settings import spring_ai_conf_check + + ll_model = {"provider": "openai"} + embed_model = {"provider": "openai"} + + result = spring_ai_conf_check(ll_model, embed_model) + assert result == "openai" + + def test_spring_ai_conf_check_ollama(self): + """Test spring_ai_conf_check with Ollama models""" + from client.content.config.tabs.settings import spring_ai_conf_check + + ll_model = {"provider": "ollama"} + embed_model = {"provider": "ollama"} + + result = spring_ai_conf_check(ll_model, embed_model) + assert result == "ollama" + + def test_spring_ai_conf_check_hosted_vllm(self): + """Test spring_ai_conf_check with hosted vLLM models""" + from client.content.config.tabs.settings import spring_ai_conf_check + + ll_model = {"provider": "hosted_vllm"} + embed_model = {"provider": "hosted_vllm"} + + result = spring_ai_conf_check(ll_model, embed_model) + assert result == "hosted_vllm" + + def test_spring_ai_conf_check_hybrid(self): + """Test spring_ai_conf_check with mixed providers""" + from client.content.config.tabs.settings import spring_ai_conf_check + + ll_model = {"provider": "openai"} + embed_model = {"provider": "ollama"} + + result = spring_ai_conf_check(ll_model, embed_model) + assert result == "hybrid" + + def test_spring_ai_conf_check_empty_models(self): + """Test 
spring_ai_conf_check with empty models""" + from client.content.config.tabs.settings import spring_ai_conf_check + + result = spring_ai_conf_check(None, None) + assert result == "hybrid" + + result = spring_ai_conf_check({}, {}) + assert result == "hybrid" + + +############################################################################# +# Test Spring AI OBaaS Function +############################################################################# +class TestSpringAIObaas: + """Test spring_ai_obaas function with mocked state""" + + def _create_mock_session_state(self, tools_enabled=None): + """Helper method to create mock session state for spring_ai tests""" + client_settings = { + "client": "test-client", + "database": {"alias": "DEFAULT"}, + "vector_search": {"enabled": False}, + } + if tools_enabled is not None: + client_settings["tools_enabled"] = tools_enabled + + return SimpleNamespace( + client_settings=client_settings, + prompt_configs=[ + { + "name": "optimizer_basic-default", + "title": "Basic Example", + "description": "Basic default prompt", + "tags": [], + "text": "You are a helpful assistant.", + }, + { + "name": "optimizer_vs-no-tools-default", + "title": "VS No Tools", + "description": "Vector search prompt without tools", + "tags": [], + "text": "You are a vector search assistant.", + }, + ], + database_configs=[{"name": "DEFAULT", "user": "test_user", "password": "test_pass"}], + ) + + def test_spring_ai_obaas_shell_template(self): + """Test spring_ai_obaas function with shell template""" + from client.content.config.tabs.settings import spring_ai_obaas + + mock_session_state = self._create_mock_session_state() + template = ( + "Provider: {provider}\nPrompt: {sys_prompt}\nLLM: {ll_model}\n" + "Embed: {vector_search}\nDB: {database_config}" + ) + + result = call_spring_ai_obaas_with_mocks(mock_session_state, template, spring_ai_obaas) + + assert "Provider: openai" in result + assert "You are a helpful assistant." in result + assert "{'model': 'gpt-4'}" in result + + def test_spring_ai_obaas_with_vector_search_tool_enabled(self): + """Test spring_ai_obaas uses vs-no-tools-default prompt when Vector Search is in tools_enabled""" + from client.content.config.tabs.settings import spring_ai_obaas + + mock_state = self._create_mock_session_state(tools_enabled=["Vector Search"]) + + result = call_spring_ai_obaas_with_mocks(mock_state, "Prompt: {sys_prompt}", spring_ai_obaas) + + # Should use the vector search prompt when "Vector Search" is in tools_enabled + assert "You are a vector search assistant." in result + + def test_spring_ai_obaas_without_vector_search_tool(self): + """Test spring_ai_obaas uses basic-default prompt when Vector Search is NOT in tools_enabled""" + from client.content.config.tabs.settings import spring_ai_obaas + + mock_state = self._create_mock_session_state(tools_enabled=["Other Tool"]) + + result = call_spring_ai_obaas_with_mocks(mock_state, "Prompt: {sys_prompt}", spring_ai_obaas) + + # Should use the basic prompt when "Vector Search" is NOT in tools_enabled + assert "You are a helpful assistant." 
in result + + def test_spring_ai_obaas_with_empty_tools_enabled(self): + """Test spring_ai_obaas uses basic-default prompt when tools_enabled is empty""" + from client.content.config.tabs.settings import spring_ai_obaas + + mock_state = self._create_mock_session_state(tools_enabled=[]) + + result = call_spring_ai_obaas_with_mocks(mock_state, "Prompt: {sys_prompt}", spring_ai_obaas) + + # Should use the basic prompt when tools_enabled is empty + assert "You are a helpful assistant." in result + + def test_spring_ai_obaas_error_handling(self): + """Test spring_ai_obaas function error handling""" + from client.content.config.tabs.settings import spring_ai_obaas + + mock_session_state = self._create_mock_session_state() + with patch("client.content.config.tabs.settings.state", mock_session_state): + with patch("client.content.config.tabs.settings.st_common.state_configs_lookup") as mock_lookup: + mock_lookup.return_value = {"DEFAULT": {"user": "test_user"}} + + # Test file not found + with patch("builtins.open", side_effect=FileNotFoundError("File not found")): + with pytest.raises(FileNotFoundError): + spring_ai_obaas( + Path("/test/path"), + "missing.sh", + "openai", + {"model": "gpt-4"}, + {"model": "text-embedding-ada-002"}, + ) + + def test_spring_ai_obaas_yaml_parsing_error(self): + """Test spring_ai_obaas YAML parsing error handling""" + from client.content.config.tabs.settings import spring_ai_obaas + + mock_session_state = self._create_mock_session_state() + invalid_yaml = "invalid: yaml: content: [" + + with patch("client.content.config.tabs.settings.state", mock_session_state): + with patch("client.content.config.tabs.settings.st_common.state_configs_lookup") as mock_lookup: + with patch("builtins.open", mock_open(read_data=invalid_yaml)): + mock_lookup.return_value = {"DEFAULT": {"user": "test_user"}} + + # Should handle YAML parsing errors gracefully + with pytest.raises(Exception): # Could be yaml.YAMLError or similar + spring_ai_obaas( + Path("/test/path"), + "invalid.yaml", + "openai", + {"model": "gpt-4"}, + {"model": "text-embedding-ada-002"}, + ) + + +############################################################################# +# Test Spring AI ZIP Creation +############################################################################# +class TestSpringAIZip: + """Test spring_ai_zip and langchain_mcp_zip functions""" + + def _create_mock_session_state(self): + """Helper method to create mock session state""" + return SimpleNamespace( + client_settings={ + "client": "test-client", + "database": {"alias": "DEFAULT"}, + "vector_search": {"enabled": False}, + }, + prompt_configs=[ + { + "name": "optimizer_basic-default", + "title": "Basic Example", + "description": "Basic default prompt", + "tags": [], + "text": "You are a helpful assistant.", + } + ], + database_configs=[{"name": "DEFAULT", "user": "test_user", "password": "test_pass"}], + ) + + def test_spring_ai_zip_creation(self): + """Test spring_ai_zip function creates proper ZIP file""" + from client.content.config.tabs.settings import spring_ai_zip + + mock_session_state = self._create_mock_session_state() + with patch("client.content.config.tabs.settings.state", mock_session_state): + with patch("client.content.config.tabs.settings.st_common.state_configs_lookup") as mock_lookup: + with patch("client.content.config.tabs.settings.shutil.copytree"): + with patch("client.content.config.tabs.settings.shutil.copy"): + with patch("client.content.config.tabs.settings.spring_ai_obaas") as mock_obaas: + mock_lookup.return_value 
= {"DEFAULT": {"user": "test_user"}} + mock_obaas.return_value = "mock content" + + result = spring_ai_zip("openai", {"model": "gpt-4"}, {"model": "text-embedding-ada-002"}) + + # Verify it's a valid BytesIO object + assert hasattr(result, "read") + assert hasattr(result, "seek") + + # Verify ZIP content + result.seek(0) + with zipfile.ZipFile(result, "r") as zip_file: + files = zip_file.namelist() + assert "start.sh" in files + assert "src/main/resources/application-obaas.yml" in files + + def test_langchain_mcp_zip_creation(self): + """Test langchain_mcp_zip function creates proper ZIP file""" + from client.content.config.tabs.settings import langchain_mcp_zip + + test_settings = {"test": "config"} + + with patch("client.content.config.tabs.settings.shutil.copytree"): + with patch("client.content.config.tabs.settings.save_settings") as mock_save: + with patch("builtins.open", mock_open()): + mock_save.return_value = '{"test": "config"}' + + result = langchain_mcp_zip(test_settings) + + # Verify it's a valid BytesIO object + assert hasattr(result, "read") + assert hasattr(result, "seek") + + # Verify save_settings was called + mock_save.assert_called_once_with(test_settings) + + +############################################################################# +# Test Save Settings Function +############################################################################# +class TestSaveSettings: + """Test save_settings function - pure function tests""" + + def test_save_settings(self): + """Test save_settings function""" + from client.content.config.tabs.settings import save_settings + + test_settings = {"client_settings": {"client": "old-client"}, "other": "data"} + + with patch("client.content.config.tabs.settings.datetime") as mock_datetime: + mock_now = MagicMock() + mock_now.strftime.return_value = "25-SEP-2024T1430" + mock_datetime.now.return_value = mock_now + + result = save_settings(test_settings) + result_dict = json.loads(result) + + assert result_dict["client_settings"]["client"] == "25-SEP-2024T1430" + assert result_dict["other"] == "data" + + def test_save_settings_no_client_settings(self): + """Test save_settings with no client_settings""" + from client.content.config.tabs.settings import save_settings + + test_settings = {"other": "data"} + result = save_settings(test_settings) + result_dict = json.loads(result) + + assert result_dict == {"other": "data"} + + def test_save_settings_with_nested_client_settings(self): + """Test save_settings with nested client_settings structure""" + from client.content.config.tabs.settings import save_settings + + test_settings = { + "client_settings": {"client": "old-client", "nested": {"value": "test"}}, + "other_settings": {"value": "unchanged"}, + } + + with patch("client.content.config.tabs.settings.datetime") as mock_datetime: + mock_now = MagicMock() + mock_now.strftime.return_value = "26-SEP-2024T0900" + mock_datetime.now.return_value = mock_now + + result = save_settings(test_settings) + result_dict = json.loads(result) + + # Client should be updated + assert result_dict["client_settings"]["client"] == "26-SEP-2024T0900" + # Nested values should be preserved + assert result_dict["client_settings"]["nested"]["value"] == "test" + # Other settings should be unchanged + assert result_dict["other_settings"]["value"] == "unchanged" + + +############################################################################# +# Test Compare Settings Function +############################################################################# +class 
TestCompareSettings: + """Test compare_settings function - pure function tests""" + + def test_compare_settings_comprehensive(self): + """Test compare_settings function with comprehensive scenarios""" + from client.content.config.tabs.settings import compare_settings + + current = { + "shared": {"value": "same"}, + "current_only": {"value": "current"}, + "different": {"value": "current_val"}, + "api_key": "current_key", + "nested": {"shared": "same", "different": "current_nested"}, + "list_field": ["a", "b", "c"], + } + + uploaded = { + "shared": {"value": "same"}, + "uploaded_only": {"value": "uploaded"}, + "different": {"value": "uploaded_val"}, + "api_key": "uploaded_key", + "password": "uploaded_pass", + "nested": {"shared": "same", "different": "uploaded_nested", "new_field": "new"}, + "list_field": ["a", "b", "d", "e"], + } + + differences = compare_settings(current, uploaded) + + # Check value mismatches + assert "different.value" in differences["Value Mismatch"] + assert "nested.different" in differences["Value Mismatch"] + assert "api_key" in differences["Value Mismatch"] + + # Check missing fields + assert "current_only" in differences["Missing in Uploaded"] + assert "nested.new_field" in differences["Missing in Current"] + + # Check sensitive key handling + assert "password" in differences["Override on Upload"] + + # Check list handling + assert "list_field[2]" in differences["Value Mismatch"] + assert "list_field[3]" in differences["Missing in Current"] + + def test_compare_settings_client_skip(self): + """Test compare_settings skips client_settings.client path""" + from client.content.config.tabs.settings import compare_settings + + current = {"client_settings": {"client": "current_client"}} + uploaded = {"client_settings": {"client": "uploaded_client"}} + + differences = compare_settings(current, uploaded) + + # Should be empty since client_settings.client is skipped + assert all(not diff_dict for diff_dict in differences.values()) + + def test_compare_settings_sensitive_key_handling(self): + """Test compare_settings handles sensitive keys correctly""" + from client.content.config.tabs.settings import compare_settings + + current = {"api_key": "current_key", "password": "current_pass", "normal_field": "current_val"} + + uploaded = {"api_key": "uploaded_key", "wallet_password": "uploaded_wallet", "normal_field": "uploaded_val"} + + differences = compare_settings(current, uploaded) + + # Sensitive keys should be in Value Mismatch + assert "api_key" in differences["Value Mismatch"] + + # New sensitive keys should be in Override on Upload + assert "wallet_password" in differences["Override on Upload"] + + # Normal fields should be in Value Mismatch + assert "normal_field" in differences["Value Mismatch"] + + # Current-only sensitive key should be silently updated (not in Missing in Uploaded) + assert "password" not in differences["Missing in Uploaded"] + + def test_compare_settings_with_none_values(self): + """Test compare_settings with None values""" + from client.content.config.tabs.settings import compare_settings + + current = {"field1": None, "field2": "value"} + uploaded = {"field1": "value", "field2": None} + + differences = compare_settings(current, uploaded) + + assert "field1" in differences["Value Mismatch"] + assert "field2" in differences["Value Mismatch"] + + def test_compare_settings_empty_structures(self): + """Test compare_settings with empty structures""" + from client.content.config.tabs.settings import compare_settings + + # Test empty dictionaries + 
differences = compare_settings({}, {}) + assert all(not diff_dict for diff_dict in differences.values()) + + # Test empty lists + differences = compare_settings([], []) + assert all(not diff_dict for diff_dict in differences.values()) + + # Test mixed empty structures + current = {"empty_dict": {}, "empty_list": []} + uploaded = {"empty_dict": {}, "empty_list": []} + differences = compare_settings(current, uploaded) + assert all(not diff_dict for diff_dict in differences.values()) + + def test_compare_settings_ignores_created_timestamps(self): + """Test compare_settings ignores 'created' timestamp fields""" + from client.content.config.tabs.settings import compare_settings + + current = { + "model_configs": [ + {"id": "gpt-4", "created": 1758808962, "model": "gpt-4"}, + {"id": "gpt-3.5", "created": 1758808962, "model": "gpt-3.5-turbo"}, + ], + "client_settings": {"ll_model": {"model": "openai/gpt-4o-mini"}}, + } + + uploaded = { + "model_configs": [ + {"id": "gpt-4", "created": 1758808458, "model": "gpt-4"}, + {"id": "gpt-3.5", "created": 1758808458, "model": "gpt-3.5-turbo"}, + ], + "client_settings": {"ll_model": {"model": None}}, + } + + differences = compare_settings(current, uploaded) + + # 'created' fields should not appear in differences + assert "model_configs[0].created" not in differences["Value Mismatch"] + assert "model_configs[1].created" not in differences["Value Mismatch"] + + # But other fields should still be compared + assert "client_settings.ll_model.model" in differences["Value Mismatch"] + + def test_compare_settings_ignores_nested_created_fields(self): + """Test compare_settings ignores deeply nested 'created' fields""" + from client.content.config.tabs.settings import compare_settings + + current = { + "nested": { + "config": {"created": 123456789, "value": "current"}, + "another": {"created": 987654321, "setting": "test"}, + } + } + + uploaded = { + "nested": { + "config": {"created": 111111111, "value": "current"}, + "another": {"created": 222222222, "setting": "changed"}, + } + } + + differences = compare_settings(current, uploaded) + + # 'created' fields should be ignored + assert "nested.config.created" not in differences["Value Mismatch"] + assert "nested.another.created" not in differences["Value Mismatch"] + + # But actual value differences should be detected + assert "nested.another.setting" in differences["Value Mismatch"] + assert differences["Value Mismatch"]["nested.another.setting"]["current"] == "test" + assert differences["Value Mismatch"]["nested.another.setting"]["uploaded"] == "changed" + + def test_compare_settings_ignores_created_in_lists(self): + """Test compare_settings ignores 'created' fields within list items""" + from client.content.config.tabs.settings import compare_settings + + current = { + "items": [ + {"name": "item1", "created": 1111, "enabled": True}, + {"name": "item2", "created": 2222, "enabled": False}, + ] + } + + uploaded = { + "items": [ + {"name": "item1", "created": 9999, "enabled": True}, + {"name": "item2", "created": 8888, "enabled": True}, + ] + } + + differences = compare_settings(current, uploaded) + + # 'created' fields should be ignored + assert "items[0].created" not in differences["Value Mismatch"] + assert "items[1].created" not in differences["Value Mismatch"] + + # But other field differences should be detected + assert "items[1].enabled" in differences["Value Mismatch"] + assert differences["Value Mismatch"]["items[1].enabled"]["current"] is False + assert differences["Value 
Mismatch"]["items[1].enabled"]["uploaded"] is True + + def test_compare_settings_mixed_created_and_regular_fields(self): + """Test compare_settings with a mix of 'created' and regular fields""" + from client.content.config.tabs.settings import compare_settings + + current = { + "config": { + "created": 123456, + "modified": 789012, + "name": "current_config", + "settings": {"created": 345678, "value": "old_value"}, + } + } + + uploaded = { + "config": { + "created": 999999, # Different created - should be ignored + "modified": 888888, # Different modified - should be detected + "name": "current_config", # Same name - no difference + "settings": { + "created": 777777, # Different created - should be ignored + "value": "new_value", # Different value - should be detected + }, + } + } + + differences = compare_settings(current, uploaded) + + # 'created' fields should be ignored + assert "config.created" not in differences["Value Mismatch"] + assert "config.settings.created" not in differences["Value Mismatch"] + + # Regular field differences should be detected + assert "config.modified" in differences["Value Mismatch"] + assert "config.settings.value" in differences["Value Mismatch"] + + # Same values should not appear in differences + assert "config.name" not in differences["Value Mismatch"] diff --git a/tests/unit/client/content/test_testbed_records_unit.py b/tests/unit/client/content/test_testbed_records_unit.py new file mode 100644 index 00000000..9edd7733 --- /dev/null +++ b/tests/unit/client/content/test_testbed_records_unit.py @@ -0,0 +1,479 @@ +# pylint: disable=protected-access,import-error,import-outside-toplevel +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for testbed.py record management functions. +Extracted from test_testbed_unit.py to reduce file size. 
+""" +# spell-checker: disable + +import sys +from unittest.mock import MagicMock + + +############################################################################# +# Test qa_update_db Function +############################################################################# +class TestQAUpdateDB: + """Test qa_update_db function""" + + def test_qa_update_db_success(self, monkeypatch): + """Test qa_update_db successfully updates database""" + from client.content import testbed + from client.utils import api_call, st_common + from streamlit import session_state as state + + # Setup state + state.testbed = {"testset_id": "test123", "qa_index": 0} + state.selected_new_testset_name = "Updated Test Set" + state.testbed_qa = [ + {"question": "Q1", "reference_answer": "A1"}, + {"question": "Q2", "reference_answer": "A2"}, + ] + state["selected_q_0"] = "Q1" + state["selected_a_0"] = "A1" + + # Mock API call + mock_post = MagicMock(return_value={"status": "success"}) + monkeypatch.setattr(api_call, "post", mock_post) + + # Mock get_testbed_db_testsets + mock_get_testsets = MagicMock(return_value={"testsets": []}) + testbed.get_testbed_db_testsets = mock_get_testsets + testbed.get_testbed_db_testsets.clear = MagicMock() + + # Mock clear_state_key + monkeypatch.setattr(st_common, "clear_state_key", MagicMock()) + + # Call qa_update_db + testbed.qa_update_db() + + # Verify API was called correctly + assert mock_post.called + call_args = mock_post.call_args + assert call_args[1]["endpoint"] == "v1/testbed/testset_load" + assert call_args[1]["params"]["name"] == "Updated Test Set" + assert call_args[1]["params"]["tid"] == "test123" + + def test_qa_update_db_clears_cache(self, monkeypatch): + """Test qa_update_db clears testbed cache""" + from client.content import testbed + from client.utils import api_call, st_common + from streamlit import session_state as state + + # Setup state + state.testbed = {"testset_id": "test123", "qa_index": 0} + state.selected_new_testset_name = "Test Set" + state.testbed_qa = [{"question": "Q1", "reference_answer": "A1"}] + state["selected_q_0"] = "Q1" + state["selected_a_0"] = "A1" + + # Mock functions + monkeypatch.setattr(api_call, "post", MagicMock()) + mock_clear_state = MagicMock() + monkeypatch.setattr(st_common, "clear_state_key", mock_clear_state) + + mock_clear_cache = MagicMock() + testbed.get_testbed_db_testsets = MagicMock(return_value={"testsets": []}) + testbed.get_testbed_db_testsets.clear = mock_clear_cache + + # Call qa_update_db + testbed.qa_update_db() + + # Verify cache was cleared + mock_clear_state.assert_called_with("testbed_db_testsets") + mock_clear_cache.assert_called_once() + + +############################################################################# +# Test qa_delete Function +############################################################################# +class TestQADelete: + """Test qa_delete function""" + + def test_qa_delete_success(self, monkeypatch): + """Test qa_delete successfully deletes testset""" + from client.content import testbed + from client.utils import api_call + from streamlit import session_state as state + import streamlit as st + + # Setup state + state.testbed = {"testset_id": "test123", "testset_name": "My Test Set"} + + # Mock API call + mock_delete = MagicMock() + monkeypatch.setattr(api_call, "delete", mock_delete) + + # Mock reset_testset + mock_reset = MagicMock() + monkeypatch.setattr(testbed, "reset_testset", mock_reset) + + # Mock st.success + mock_success = MagicMock() + monkeypatch.setattr(st, "success", 
mock_success)
+
+        # Call qa_delete
+        testbed.qa_delete()
+
+        # Verify delete was called
+        mock_delete.assert_called_once_with(endpoint="v1/testbed/testset_delete/test123")
+
+        # Verify success message shown
+        assert mock_success.called
+        success_msg = mock_success.call_args[0][0]
+        assert "My Test Set" in success_msg
+
+        # Verify reset_testset called with cache=True
+        mock_reset.assert_called_once_with(True)
+
+    def test_qa_delete_api_error(self, monkeypatch):
+        """Test qa_delete when API call fails"""
+        from client.content import testbed
+        from client.utils import api_call
+        from streamlit import session_state as state
+        import streamlit as st
+
+        # Setup state
+        state.testbed = {"testset_id": "test123", "testset_name": "My Test Set"}
+
+        # Mock API call to raise error
+        def mock_delete(endpoint):
+            raise api_call.ApiError("Delete failed")
+
+        monkeypatch.setattr(api_call, "delete", mock_delete)
+
+        # Mock st.error
+        mock_error = MagicMock()
+        monkeypatch.setattr(st, "error", mock_error)
+
+        # Call qa_delete - should handle error gracefully
+        testbed.qa_delete()
+
+        # Reaching this point means the ApiError was handled internally;
+        # graceful completion is the behaviour under test
+        assert True  # qa_delete must not propagate the exception
+
+
+#############################################################################
+# Test update_record Function
+#############################################################################
+class TestUpdateRecord:
+    """Test update_record function"""
+
+    def test_update_record_forward(self, monkeypatch):
+        """Test update_record with forward direction"""
+        # Mock st.fragment to be a no-op decorator BEFORE importing testbed
+        import streamlit as st
+
+        monkeypatch.setattr(st, "fragment", lambda: lambda func: func)
+
+        # Force reload of testbed module and all client.content modules to pick up the mocked decorator
+        modules_to_delete = [k for k in sys.modules if k.startswith("client.content")]
+        for mod in modules_to_delete:
+            del sys.modules[mod]
+
+        from client.content import testbed
+        from streamlit import session_state as state
+
+        # Setup state
+        state.testbed = {"qa_index": 0}
+        state.testbed_qa = [
+            {"question": "Q1", "reference_answer": "A1", "reference_context": "", "metadata": ""},
+            {"question": "Q2", "reference_answer": "A2", "reference_context": "", "metadata": ""},
+        ]
+        state["selected_q_0"] = "Q1 Updated"
+        state["selected_a_0"] = "A1 Updated"
+        state["selected_c_0"] = ""
+        state["selected_m_0"] = ""
+
+        # Call update_record with direction=1 (forward)
+        testbed.update_record(direction=1)
+
+        # Verify record was updated
+        assert state.testbed_qa[0]["question"] == "Q1 Updated"
+        assert state.testbed_qa[0]["reference_answer"] == "A1 Updated"
+
+        # Verify index moved forward
+        assert state.testbed["qa_index"] == 1
+
+    def test_update_record_backward(self, monkeypatch):
+        """Test update_record with backward direction"""
+        # Mock st.fragment to be a no-op decorator BEFORE importing testbed
+        import streamlit as st
+
+        monkeypatch.setattr(st, "fragment", lambda: lambda func: func)
+
+        # Force reload of testbed module and all client.content modules to pick up the mocked decorator
+        modules_to_delete = [k for k in sys.modules if k.startswith("client.content")]
+        for mod in modules_to_delete:
+            del sys.modules[mod]
+
+        from client.content import testbed
+        from streamlit import session_state as state
+
+        # Setup state
+        state.testbed = {"qa_index": 1}
+        state.testbed_qa = [
+            {"question": "Q1", "reference_answer": "A1", "reference_context": "", "metadata": ""},
+            {"question": "Q2", "reference_answer": "A2", "reference_context": 
"", "metadata": ""}, + ] + state["selected_q_1"] = "Q2 Updated" + state["selected_a_1"] = "A2 Updated" + state["selected_c_1"] = "" + state["selected_m_1"] = "" + + # Call update_record with direction=-1 (backward) + testbed.update_record(direction=-1) + + # Verify record was updated + assert state.testbed_qa[1]["question"] == "Q2 Updated" + assert state.testbed_qa[1]["reference_answer"] == "A2 Updated" + + # Verify index moved backward + assert state.testbed["qa_index"] == 0 + + def test_update_record_no_direction(self, monkeypatch): + """Test update_record with no direction (stays in place)""" + # Mock st.fragment to be a no-op decorator BEFORE importing testbed + import streamlit as st + + monkeypatch.setattr(st, "fragment", lambda: lambda func: func) + + # Force reload of testbed module and all client.content modules to pick up the mocked decorator + modules_to_delete = [k for k in sys.modules if k.startswith("client.content")] + for mod in modules_to_delete: + del sys.modules[mod] + + from client.content import testbed + from streamlit import session_state as state + + # Setup state + state.testbed = {"qa_index": 1} + state.testbed_qa = [ + {"question": "Q1", "reference_answer": "A1", "reference_context": "", "metadata": ""}, + {"question": "Q2", "reference_answer": "A2", "reference_context": "", "metadata": ""}, + ] + state["selected_q_1"] = "Q2 Modified" + state["selected_a_1"] = "A2 Modified" + state["selected_c_1"] = "" + state["selected_m_1"] = "" + + # Call update_record with direction=0 (no movement) + testbed.update_record(direction=0) + + # Verify record was updated + assert state.testbed_qa[1]["question"] == "Q2 Modified" + assert state.testbed_qa[1]["reference_answer"] == "A2 Modified" + + # Verify index stayed the same + assert state.testbed["qa_index"] == 1 + + +############################################################################# +# Test delete_record Function +############################################################################# +class TestDeleteRecord: + """Test delete_record function""" + + def test_delete_record_middle(self, monkeypatch): + """Test deleting a record from the middle""" + # Mock st.fragment to be a no-op decorator BEFORE importing testbed + import streamlit as st + + monkeypatch.setattr(st, "fragment", lambda: lambda func: func) + + # Force reload of testbed module and all client.content modules to pick up the mocked decorator + modules_to_delete = [k for k in sys.modules if k.startswith("client.content")] + for mod in modules_to_delete: + del sys.modules[mod] + + from client.content import testbed + from streamlit import session_state as state + + # Setup state with 3 records, index at 1 + state.testbed = {"qa_index": 1} + state.testbed_qa = [ + {"question": "Q1", "reference_answer": "A1"}, + {"question": "Q2", "reference_answer": "A2"}, + {"question": "Q3", "reference_answer": "A3"}, + ] + + # Delete record at index 1 + testbed.delete_record() + + # Verify record was deleted + assert len(state.testbed_qa) == 2 + assert state.testbed_qa[0]["question"] == "Q1" + assert state.testbed_qa[1]["question"] == "Q3" + + # Verify index stayed at 1 (still valid, now points to Q3) + assert state.testbed["qa_index"] == 1 + + def test_delete_record_first(self, monkeypatch): + """Test deleting the first record (index 0)""" + # Mock st.fragment to be a no-op decorator BEFORE importing testbed + import streamlit as st + + monkeypatch.setattr(st, "fragment", lambda: lambda func: func) + + # Force reload of testbed module and all client.content modules 
to pick up the mocked decorator + modules_to_delete = [k for k in sys.modules if k.startswith("client.content")] + for mod in modules_to_delete: + del sys.modules[mod] + + from client.content import testbed + from streamlit import session_state as state + + # Setup state with index at 0 + state.testbed = {"qa_index": 0} + state.testbed_qa = [ + {"question": "Q1", "reference_answer": "A1"}, + {"question": "Q2", "reference_answer": "A2"}, + ] + + # Delete record at index 0 + testbed.delete_record() + + # Verify record was deleted + assert len(state.testbed_qa) == 1 + assert state.testbed_qa[0]["question"] == "Q2" + + # Verify index stayed at 0 (doesn't go negative) + assert state.testbed["qa_index"] == 0 + + def test_delete_record_last(self, monkeypatch): + """Test deleting the last record""" + # Mock st.fragment to be a no-op decorator BEFORE importing testbed + import streamlit as st + + monkeypatch.setattr(st, "fragment", lambda: lambda func: func) + + # Force reload of testbed module and all client.content modules to pick up the mocked decorator + modules_to_delete = [k for k in sys.modules if k.startswith("client.content")] + for mod in modules_to_delete: + del sys.modules[mod] + + from client.content import testbed + from streamlit import session_state as state + + # Setup state with index at last position + state.testbed = {"qa_index": 2} + state.testbed_qa = [ + {"question": "Q1", "reference_answer": "A1"}, + {"question": "Q2", "reference_answer": "A2"}, + {"question": "Q3", "reference_answer": "A3"}, + ] + + # Delete record at index 2 + testbed.delete_record() + + # Verify record was deleted + assert len(state.testbed_qa) == 2 + + # Verify index moved back + assert state.testbed["qa_index"] == 1 + + +############################################################################# +# Test qa_update_gui Function +############################################################################# +class TestQAUpdateGUI: + """Test qa_update_gui function""" + + def test_qa_update_gui_multiple_records(self, monkeypatch): + """Test qa_update_gui with multiple records""" + from client.content import testbed + from streamlit import session_state as state + import streamlit as st + + # Setup state + state.testbed = {"qa_index": 1} + qa_testset = [ + {"question": "Q1", "reference_answer": "A1", "reference_context": "C1", "metadata": "M1"}, + {"question": "Q2", "reference_answer": "A2", "reference_context": "C2", "metadata": "M2"}, + {"question": "Q3", "reference_answer": "A3", "reference_context": "C3", "metadata": "M3"}, + ] + + # Mock streamlit functions + mock_write = MagicMock() + mock_columns = MagicMock(return_value=[MagicMock(), MagicMock(), MagicMock(), MagicMock()]) + mock_text_area = MagicMock() + mock_text_input = MagicMock() + + monkeypatch.setattr(st, "write", mock_write) + monkeypatch.setattr(st, "columns", mock_columns) + monkeypatch.setattr(st, "text_area", mock_text_area) + monkeypatch.setattr(st, "text_input", mock_text_input) + + # Call qa_update_gui + testbed.qa_update_gui(qa_testset) + + # Verify record counter was displayed + mock_write.assert_called_once() + assert "2/3" in mock_write.call_args[0][0] + + # Verify text areas were created + assert mock_text_area.call_count >= 3 # Question, Answer, Context + + def test_qa_update_gui_single_record(self, monkeypatch): + """Test qa_update_gui with single record (delete disabled)""" + from client.content import testbed + from streamlit import session_state as state + import streamlit as st + + # Setup state with single record + 
state.testbed = {"qa_index": 0} + qa_testset = [ + {"question": "Q1", "reference_answer": "A1", "reference_context": "C1", "metadata": "M1"}, + ] + + # Mock streamlit functions + mock_button_col = MagicMock() + mock_columns = MagicMock(return_value=[MagicMock(), MagicMock(), MagicMock(), mock_button_col]) + + monkeypatch.setattr(st, "write", MagicMock()) + monkeypatch.setattr(st, "columns", mock_columns) + monkeypatch.setattr(st, "text_area", MagicMock()) + monkeypatch.setattr(st, "text_input", MagicMock()) + + # Call qa_update_gui + testbed.qa_update_gui(qa_testset) + + # Verify delete button is disabled + delete_button_call = mock_button_col.button.call_args + assert delete_button_call[1]["disabled"] is True + + def test_qa_update_gui_navigation_buttons(self, monkeypatch): + """Test qa_update_gui navigation button states""" + from client.content import testbed + from streamlit import session_state as state + import streamlit as st + + # Setup state at first record + state.testbed = {"qa_index": 0} + qa_testset = [ + {"question": "Q1", "reference_answer": "A1", "reference_context": "C1", "metadata": "M1"}, + {"question": "Q2", "reference_answer": "A2", "reference_context": "C2", "metadata": "M2"}, + ] + + # Mock streamlit functions + prev_col = MagicMock() + next_col = MagicMock() + mock_columns = MagicMock(return_value=[prev_col, next_col, MagicMock(), MagicMock()]) + + monkeypatch.setattr(st, "write", MagicMock()) + monkeypatch.setattr(st, "columns", mock_columns) + monkeypatch.setattr(st, "text_area", MagicMock()) + monkeypatch.setattr(st, "text_input", MagicMock()) + + # Call qa_update_gui + testbed.qa_update_gui(qa_testset) + + # Verify Previous button is disabled at first record + prev_button_call = prev_col.button.call_args + assert prev_button_call[1]["disabled"] is True + + # Verify Next button is enabled + next_button_call = next_col.button.call_args + assert next_button_call[1]["disabled"] is False diff --git a/tests/unit/client/content/test_testbed_ui_unit.py b/tests/unit/client/content/test_testbed_ui_unit.py new file mode 100644 index 00000000..bf655ee9 --- /dev/null +++ b/tests/unit/client/content/test_testbed_ui_unit.py @@ -0,0 +1,174 @@ +# pylint: disable=protected-access,import-error,import-outside-toplevel +""" +Copyright (c) 2024, 2025, Oracle and/or its affiliates. +Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. + +Unit tests for testbed.py UI rendering functions. +Extracted from test_testbed_unit.py to reduce file size. 
+""" +# spell-checker: disable + +from unittest.mock import MagicMock + + +############################################################################# +# Test render_existing_testset_ui Function +############################################################################# +class TestRenderExistingTestsetUI: + """Test render_existing_testset_ui function""" + + def test_render_existing_testset_ui_database_with_selection(self, monkeypatch): + """Test render_existing_testset_ui correctly extracts testset_id when database test set is selected""" + from client.content import testbed + import streamlit as st + from streamlit import session_state as state + + # Mock session state + state.testbed_db_testsets = [ + {"tid": "test1", "name": "Test Set 1", "created": "2024-01-01 10:00:00"}, + {"tid": "test2", "name": "Test Set 2", "created": "2024-01-02 11:00:00"}, + ] + state.testbed = {"uploader_key": 1} + + # Mock streamlit components + mock_radio = MagicMock(return_value="Database") + mock_selectbox = MagicMock(return_value="Test Set 1 -- Created: 2024-01-01 10:00:00") + mock_file_uploader = MagicMock() + + monkeypatch.setattr(st, "radio", mock_radio) + monkeypatch.setattr(st, "selectbox", mock_selectbox) + monkeypatch.setattr(st, "file_uploader", mock_file_uploader) + + # Call the function + testset_sources = ["Database", "Local"] + source, endpoint, disabled, testset_id = testbed.render_existing_testset_ui(testset_sources) + + # Verify the return values + assert source == "Database", "Should return Database as source" + assert endpoint == "v1/testbed/testset_qa", "Should return correct endpoint for database" + assert disabled is False, "Button should not be disabled when test set is selected" + assert testset_id == "test1", f"Should extract correct testset_id 'test1', got {testset_id}" + + def test_render_existing_testset_ui_database_no_selection(self, monkeypatch): + """Test render_existing_testset_ui when no database test set is selected""" + from client.content import testbed + import streamlit as st + from streamlit import session_state as state + + # Mock session state + state.testbed_db_testsets = [ + {"tid": "test1", "name": "Test Set 1", "created": "2024-01-01 10:00:00"}, + ] + state.testbed = {"uploader_key": 1} + + # Mock streamlit components + mock_radio = MagicMock(return_value="Database") + mock_selectbox = MagicMock(return_value=None) # No selection + mock_file_uploader = MagicMock() + + monkeypatch.setattr(st, "radio", mock_radio) + monkeypatch.setattr(st, "selectbox", mock_selectbox) + monkeypatch.setattr(st, "file_uploader", mock_file_uploader) + + # Call the function + testset_sources = ["Database", "Local"] + source, endpoint, disabled, testset_id = testbed.render_existing_testset_ui(testset_sources) + + # Verify the return values + assert source == "Database", "Should return Database as source" + assert endpoint == "v1/testbed/testset_qa", "Should return correct endpoint" + assert disabled is True, "Button should be disabled when no test set is selected" + assert testset_id is None, "Should return None for testset_id when nothing selected" + + def test_render_existing_testset_ui_local_mode_no_files(self, monkeypatch): + """Test render_existing_testset_ui in Local mode with no files uploaded""" + from client.content import testbed + import streamlit as st + from streamlit import session_state as state + + # Mock session state + state.testbed = {"uploader_key": 1} + state.testbed_db_testsets = [] + + # Mock streamlit components + mock_radio = MagicMock(return_value="Local") + 
mock_selectbox = MagicMock() + mock_file_uploader = MagicMock(return_value=[]) # No files uploaded + + monkeypatch.setattr(st, "radio", mock_radio) + monkeypatch.setattr(st, "selectbox", mock_selectbox) + monkeypatch.setattr(st, "file_uploader", mock_file_uploader) + + # Call the function + testset_sources = ["Database", "Local"] + source, endpoint, disabled, testset_id = testbed.render_existing_testset_ui(testset_sources) + + # Verify the return values + assert source == "Local", "Should return Local as source" + assert endpoint == "v1/testbed/testset_load", "Should return correct endpoint for local" + assert disabled is True, "Button should be disabled when no files uploaded" + assert testset_id is None, "Should return None for testset_id in Local mode" + + def test_render_existing_testset_ui_local_mode_with_files(self, monkeypatch): + """Test render_existing_testset_ui in Local mode with files uploaded""" + from client.content import testbed + import streamlit as st + from streamlit import session_state as state + + # Mock session state + state.testbed = {"uploader_key": 1} + state.testbed_db_testsets = [] + + # Mock streamlit components + mock_radio = MagicMock(return_value="Local") + mock_selectbox = MagicMock() + mock_file_uploader = MagicMock(return_value=["file1.json", "file2.json"]) # Files uploaded + + monkeypatch.setattr(st, "radio", mock_radio) + monkeypatch.setattr(st, "selectbox", mock_selectbox) + monkeypatch.setattr(st, "file_uploader", mock_file_uploader) + + # Call the function + testset_sources = ["Database", "Local"] + source, endpoint, disabled, testset_id = testbed.render_existing_testset_ui(testset_sources) + + # Verify the return values + assert source == "Local", "Should return Local as source" + assert endpoint == "v1/testbed/testset_load", "Should return correct endpoint for local" + assert disabled is False, "Button should be enabled when files are uploaded" + assert testset_id is None, "Should return None for testset_id in Local mode" + + def test_render_existing_testset_ui_with_multiple_testsets(self, monkeypatch): + """Test render_existing_testset_ui correctly identifies testset when multiple exist with same name""" + from client.content import testbed + import streamlit as st + from streamlit import session_state as state + + # Mock session state with multiple test sets (some with same name) + state.testbed_db_testsets = [ + {"tid": "test1", "name": "Production Tests", "created": "2024-01-01 10:00:00"}, + { + "tid": "test2", + "name": "Production Tests", + "created": "2024-01-02 11:00:00", + }, # Same name, different date + {"tid": "test3", "name": "Dev Tests", "created": "2024-01-03 12:00:00"}, + ] + state.testbed = {"uploader_key": 1} + + # Mock streamlit components - select the second "Production Tests" + mock_radio = MagicMock(return_value="Database") + mock_selectbox = MagicMock(return_value="Production Tests -- Created: 2024-01-02 11:00:00") + mock_file_uploader = MagicMock() + + monkeypatch.setattr(st, "radio", mock_radio) + monkeypatch.setattr(st, "selectbox", mock_selectbox) + monkeypatch.setattr(st, "file_uploader", mock_file_uploader) + + # Call the function + testset_sources = ["Database", "Local"] + _, _, disabled, testset_id = testbed.render_existing_testset_ui(testset_sources) + + # Verify it extracted the correct testset_id (test2, not test1) + assert testset_id == "test2", f"Should extract 'test2' for second Production Tests, got {testset_id}" + assert disabled is False, "Button should not be disabled" diff --git 
a/tests/unit/client/content/test_testbed_unit.py b/tests/unit/client/content/test_testbed_unit.py index fabb8b9c..13beba3c 100644 --- a/tests/unit/client/content/test_testbed_unit.py +++ b/tests/unit/client/content/test_testbed_unit.py @@ -3,17 +3,19 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -Additional tests for testbed.py to increase coverage from 36% to 85%+ +Unit tests for testbed.py evaluation_report function. + +Note: Other testbed tests are split across: +- test_testbed_records_unit.py: qa_update_db, qa_delete, update_record, delete_record, qa_update_gui +- test_testbed_ui_unit.py: render_existing_testset_ui """ # spell-checker: disable -import sys from unittest.mock import MagicMock import plotly.graph_objects as go - ############################################################################# # Test evaluation_report Function ############################################################################# @@ -61,6 +63,7 @@ def test_create_gauge_function(self, monkeypatch): # Reload testbed to apply the mock decorator import importlib + importlib.reload(testbed) mock_plotly_chart = MagicMock() @@ -124,6 +127,7 @@ def test_evaluation_report_with_eid(self, monkeypatch): # Reload testbed to apply the mock decorator import importlib + importlib.reload(testbed) mock_columns = MagicMock(return_value=[MagicMock(), MagicMock(), MagicMock()]) @@ -188,6 +192,7 @@ def test_evaluation_report_with_vector_search_enabled(self, monkeypatch): # Reload testbed to apply the mock decorator import importlib + importlib.reload(testbed) mock_markdown = MagicMock() @@ -255,6 +260,7 @@ def test_evaluation_report_with_mmr_search_type(self, monkeypatch): # Reload testbed to apply the mock decorator import importlib + importlib.reload(testbed) mock_dataframe = MagicMock() @@ -279,626 +285,140 @@ def test_evaluation_report_with_mmr_search_type(self, monkeypatch): ############################################################################# -# Test qa_update_db Function +# Test evaluation_report backward compatibility ############################################################################# -class TestQAUpdateDB: - """Test qa_update_db function""" - - def test_qa_update_db_success(self, monkeypatch): - """Test qa_update_db successfully updates database""" - from client.content import testbed - from client.utils import api_call, st_common - from streamlit import session_state as state - - # Setup state - state.testbed = {"testset_id": "test123", "qa_index": 0} - state.selected_new_testset_name = "Updated Test Set" - state.testbed_qa = [ - {"question": "Q1", "reference_answer": "A1"}, - {"question": "Q2", "reference_answer": "A2"}, - ] - state["selected_q_0"] = "Q1" - state["selected_a_0"] = "A1" - - # Mock API call - mock_post = MagicMock(return_value={"status": "success"}) - monkeypatch.setattr(api_call, "post", mock_post) - - # Mock get_testbed_db_testsets - mock_get_testsets = MagicMock(return_value={"testsets": []}) - testbed.get_testbed_db_testsets = mock_get_testsets - testbed.get_testbed_db_testsets.clear = MagicMock() - - # Mock clear_state_key - monkeypatch.setattr(st_common, "clear_state_key", MagicMock()) - - # Call qa_update_db - testbed.qa_update_db() - - # Verify API was called correctly - assert mock_post.called - call_args = mock_post.call_args - assert call_args[1]["endpoint"] == "v1/testbed/testset_load" - assert call_args[1]["params"]["name"] == "Updated Test Set" - assert 
call_args[1]["params"]["tid"] == "test123" - - def test_qa_update_db_clears_cache(self, monkeypatch): - """Test qa_update_db clears testbed cache""" - from client.content import testbed - from client.utils import api_call, st_common - from streamlit import session_state as state - - # Setup state - state.testbed = {"testset_id": "test123", "qa_index": 0} - state.selected_new_testset_name = "Test Set" - state.testbed_qa = [{"question": "Q1", "reference_answer": "A1"}] - state["selected_q_0"] = "Q1" - state["selected_a_0"] = "A1" - - # Mock functions - monkeypatch.setattr(api_call, "post", MagicMock()) - mock_clear_state = MagicMock() - monkeypatch.setattr(st_common, "clear_state_key", mock_clear_state) - - mock_clear_cache = MagicMock() - testbed.get_testbed_db_testsets = MagicMock(return_value={"testsets": []}) - testbed.get_testbed_db_testsets.clear = mock_clear_cache - - # Call qa_update_db - testbed.qa_update_db() - - # Verify cache was cleared - mock_clear_state.assert_called_with("testbed_db_testsets") - mock_clear_cache.assert_called_once() - - -############################################################################# -# Test qa_delete Function -############################################################################# -class TestQADelete: - """Test qa_delete function""" - - def test_qa_delete_success(self, monkeypatch): - """Test qa_delete successfully deletes testset""" - from client.content import testbed - from client.utils import api_call - from streamlit import session_state as state - import streamlit as st - - # Setup state - state.testbed = { - "testset_id": "test123", - "testset_name": "My Test Set" - } - - # Mock API call - mock_delete = MagicMock() - monkeypatch.setattr(api_call, "delete", mock_delete) - - # Mock reset_testset - mock_reset = MagicMock() - monkeypatch.setattr(testbed, "reset_testset", mock_reset) +class TestEvaluationReportBackwardCompatibility: + """Test evaluation_report backward compatibility when vector_search.enabled is missing""" - # Mock st.success - mock_success = MagicMock() - monkeypatch.setattr(st, "success", mock_success) - - # Call qa_delete - testbed.qa_delete() - - # Verify delete was called - mock_delete.assert_called_once_with(endpoint="v1/testbed/testset_delete/test123") - - # Verify success message shown - assert mock_success.called - success_msg = mock_success.call_args[0][0] - assert "My Test Set" in success_msg - - # Verify reset_testset called with cache=True - mock_reset.assert_called_once_with(True) - - def test_qa_delete_api_error(self, monkeypatch): - """Test qa_delete when API call fails""" + def test_evaluation_report_fallback_to_tools_enabled(self, monkeypatch): + """Test evaluation_report falls back to tools_enabled when vector_search.enabled is missing""" from client.content import testbed - from client.utils import api_call - from streamlit import session_state as state - import streamlit as st - # Setup state - state.testbed = { - "testset_id": "test123", - "testset_name": "My Test Set" + # Create report WITHOUT vector_search.enabled but WITH tools_enabled containing "Vector Search" + mock_report = { + "settings": { + "ll_model": { + "model": "gpt-4", + "temperature": 0.7, + "streaming": False, + "chat_history": False, + "max_input_tokens": 1000, + "max_tokens": 500, + }, + "testbed": {"judge_model": None}, + "tools_enabled": ["Vector Search"], # Vector Search enabled via tools_enabled + "database": {"alias": "DEFAULT"}, + "vector_search": { + # NO "enabled" key - tests backward compatibility + "vector_store": 
"my_vs", + "alias": "my_alias", + "search_type": "Similarity", + "score_threshold": 0.7, + "fetch_k": 10, + "lambda_mult": 0.5, + "top_k": 5, + "grading": True, + }, + }, + "correctness": 0.85, + "correct_by_topic": [], + "failures": [], + "report": [], } - # Mock API call to raise error - def mock_delete(endpoint): - raise api_call.ApiError("Delete failed") - - monkeypatch.setattr(api_call, "delete", mock_delete) - - # Mock st.error - mock_error = MagicMock() - monkeypatch.setattr(st, "error", mock_error) - - # Call qa_delete - should handle error gracefully - testbed.qa_delete() - - # Verify error was logged - assert True # Function should complete without raising exception - - -############################################################################# -# Test update_record Function -############################################################################# -class TestUpdateRecord: - """Test update_record function""" - - def test_update_record_forward(self, monkeypatch): - """Test update_record with forward direction""" - # Mock st.fragment to be a no-op decorator BEFORE importing testbed - import streamlit as st - monkeypatch.setattr(st, "fragment", lambda: lambda func: func) - - # Force reload of testbed module and all client.content modules to pick up the mocked decorator - modules_to_delete = [k for k in sys.modules if k.startswith('client.content')] - for mod in modules_to_delete: - del sys.modules[mod] - - from client.content import testbed - from streamlit import session_state as state - - # Setup state - state.testbed = {"qa_index": 0} - state.testbed_qa = [ - {"question": "Q1", "reference_answer": "A1", "reference_context": "", "metadata": ""}, - {"question": "Q2", "reference_answer": "A2", "reference_context": "", "metadata": ""}, - ] - state["selected_q_0"] = "Q1 Updated" - state["selected_a_0"] = "A1 Updated" - state["selected_c_0"] = "" - state["selected_m_0"] = "" - - # Call update_record with direction=1 (forward) - testbed.update_record(direction=1) - - # Verify record was updated - assert state.testbed_qa[0]["question"] == "Q1 Updated" - assert state.testbed_qa[0]["reference_answer"] == "A1 Updated" - - # Verify index moved forward - assert state.testbed["qa_index"] == 1 - - def test_update_record_backward(self, monkeypatch): - """Test update_record with backward direction""" - # Mock st.fragment to be a no-op decorator BEFORE importing testbed - import streamlit as st - monkeypatch.setattr(st, "fragment", lambda: lambda func: func) - - # Force reload of testbed module and all client.content modules to pick up the mocked decorator - modules_to_delete = [k for k in sys.modules if k.startswith('client.content')] - for mod in modules_to_delete: - del sys.modules[mod] - - from client.content import testbed - from streamlit import session_state as state - - # Setup state - state.testbed = {"qa_index": 1} - state.testbed_qa = [ - {"question": "Q1", "reference_answer": "A1", "reference_context": "", "metadata": ""}, - {"question": "Q2", "reference_answer": "A2", "reference_context": "", "metadata": ""}, - ] - state["selected_q_1"] = "Q2 Updated" - state["selected_a_1"] = "A2 Updated" - state["selected_c_1"] = "" - state["selected_m_1"] = "" - - # Call update_record with direction=-1 (backward) - testbed.update_record(direction=-1) - - # Verify record was updated - assert state.testbed_qa[1]["question"] == "Q2 Updated" - assert state.testbed_qa[1]["reference_answer"] == "A2 Updated" - - # Verify index moved backward - assert state.testbed["qa_index"] == 0 - - def 
test_update_record_no_direction(self, monkeypatch): - """Test update_record with no direction (stays in place)""" - # Mock st.fragment to be a no-op decorator BEFORE importing testbed - import streamlit as st - monkeypatch.setattr(st, "fragment", lambda: lambda func: func) - - # Force reload of testbed module and all client.content modules to pick up the mocked decorator - modules_to_delete = [k for k in sys.modules if k.startswith('client.content')] - for mod in modules_to_delete: - del sys.modules[mod] - - from client.content import testbed - from streamlit import session_state as state - - # Setup state - state.testbed = {"qa_index": 1} - state.testbed_qa = [ - {"question": "Q1", "reference_answer": "A1", "reference_context": "", "metadata": ""}, - {"question": "Q2", "reference_answer": "A2", "reference_context": "", "metadata": ""}, - ] - state["selected_q_1"] = "Q2 Modified" - state["selected_a_1"] = "A2 Modified" - state["selected_c_1"] = "" - state["selected_m_1"] = "" - - # Call update_record with direction=0 (no movement) - testbed.update_record(direction=0) - - # Verify record was updated - assert state.testbed_qa[1]["question"] == "Q2 Modified" - assert state.testbed_qa[1]["reference_answer"] == "A2 Modified" - - # Verify index stayed the same - assert state.testbed["qa_index"] == 1 - - -############################################################################# -# Test delete_record Function -############################################################################# -class TestDeleteRecord: - """Test delete_record function""" - - def test_delete_record_middle(self, monkeypatch): - """Test deleting a record from the middle""" - # Mock st.fragment to be a no-op decorator BEFORE importing testbed - import streamlit as st - monkeypatch.setattr(st, "fragment", lambda: lambda func: func) - - # Force reload of testbed module and all client.content modules to pick up the mocked decorator - modules_to_delete = [k for k in sys.modules if k.startswith('client.content')] - for mod in modules_to_delete: - del sys.modules[mod] - - from client.content import testbed - from streamlit import session_state as state - - # Setup state with 3 records, index at 1 - state.testbed = {"qa_index": 1} - state.testbed_qa = [ - {"question": "Q1", "reference_answer": "A1"}, - {"question": "Q2", "reference_answer": "A2"}, - {"question": "Q3", "reference_answer": "A3"}, - ] - - # Delete record at index 1 - testbed.delete_record() - - # Verify record was deleted - assert len(state.testbed_qa) == 2 - assert state.testbed_qa[0]["question"] == "Q1" - assert state.testbed_qa[1]["question"] == "Q3" - - # Verify index stayed at 1 (still valid, now points to Q3) - assert state.testbed["qa_index"] == 1 - - def test_delete_record_first(self, monkeypatch): - """Test deleting the first record (index 0)""" - # Mock st.fragment to be a no-op decorator BEFORE importing testbed - import streamlit as st - monkeypatch.setattr(st, "fragment", lambda: lambda func: func) - - # Force reload of testbed module and all client.content modules to pick up the mocked decorator - modules_to_delete = [k for k in sys.modules if k.startswith('client.content')] - for mod in modules_to_delete: - del sys.modules[mod] - - from client.content import testbed - from streamlit import session_state as state - - # Setup state with index at 0 - state.testbed = {"qa_index": 0} - state.testbed_qa = [ - {"question": "Q1", "reference_answer": "A1"}, - {"question": "Q2", "reference_answer": "A2"}, - ] - - # Delete record at index 0 - 
testbed.delete_record() - - # Verify record was deleted - assert len(state.testbed_qa) == 1 - assert state.testbed_qa[0]["question"] == "Q2" - - # Verify index stayed at 0 (doesn't go negative) - assert state.testbed["qa_index"] == 0 - - def test_delete_record_last(self, monkeypatch): - """Test deleting the last record""" - # Mock st.fragment to be a no-op decorator BEFORE importing testbed + # Mock streamlit functions import streamlit as st - monkeypatch.setattr(st, "fragment", lambda: lambda func: func) - - # Force reload of testbed module and all client.content modules to pick up the mocked decorator - modules_to_delete = [k for k in sys.modules if k.startswith('client.content')] - for mod in modules_to_delete: - del sys.modules[mod] - - from client.content import testbed - from streamlit import session_state as state - # Setup state with index at last position - state.testbed = {"qa_index": 2} - state.testbed_qa = [ - {"question": "Q1", "reference_answer": "A1"}, - {"question": "Q2", "reference_answer": "A2"}, - {"question": "Q3", "reference_answer": "A3"}, - ] - - # Delete record at index 2 - testbed.delete_record() - - # Verify record was deleted - assert len(state.testbed_qa) == 2 - - # Verify index moved back - assert state.testbed["qa_index"] == 1 - - -############################################################################# -# Test qa_update_gui Function -############################################################################# -class TestQAUpdateGUI: - """Test qa_update_gui function""" + # Mock st.dialog decorator + original_dialog = st.dialog + mock_dialog = MagicMock(side_effect=lambda *args, **kwargs: lambda f: f) + monkeypatch.setattr(st, "dialog", mock_dialog) - def test_qa_update_gui_multiple_records(self, monkeypatch): - """Test qa_update_gui with multiple records""" - from client.content import testbed - from streamlit import session_state as state - import streamlit as st + # Reload testbed to apply the mock decorator + import importlib - # Setup state - state.testbed = {"qa_index": 1} - qa_testset = [ - {"question": "Q1", "reference_answer": "A1", "reference_context": "C1", "metadata": "M1"}, - {"question": "Q2", "reference_answer": "A2", "reference_context": "C2", "metadata": "M2"}, - {"question": "Q3", "reference_answer": "A3", "reference_context": "C3", "metadata": "M3"}, - ] + importlib.reload(testbed) - # Mock streamlit functions - mock_write = MagicMock() - mock_columns = MagicMock(return_value=[MagicMock(), MagicMock(), MagicMock(), MagicMock()]) - mock_text_area = MagicMock() - mock_text_input = MagicMock() + mock_markdown = MagicMock() + mock_columns = MagicMock(return_value=[MagicMock(), MagicMock(), MagicMock()]) - monkeypatch.setattr(st, "write", mock_write) + monkeypatch.setattr(st, "markdown", mock_markdown) + monkeypatch.setattr(st, "plotly_chart", MagicMock()) + monkeypatch.setattr(st, "subheader", MagicMock()) + monkeypatch.setattr(st, "dataframe", MagicMock()) monkeypatch.setattr(st, "columns", mock_columns) - monkeypatch.setattr(st, "text_area", mock_text_area) - monkeypatch.setattr(st, "text_input", mock_text_input) - - # Call qa_update_gui - testbed.qa_update_gui(qa_testset) - - # Verify record counter was displayed - mock_write.assert_called_once() - assert "2/3" in mock_write.call_args[0][0] - - # Verify text areas were created - assert mock_text_area.call_count >= 3 # Question, Answer, Context - - def test_qa_update_gui_single_record(self, monkeypatch): - """Test qa_update_gui with single record (delete disabled)""" - from client.content 
import testbed - from streamlit import session_state as state - import streamlit as st - # Setup state with single record - state.testbed = {"qa_index": 0} - qa_testset = [ - {"question": "Q1", "reference_answer": "A1", "reference_context": "C1", "metadata": "M1"}, - ] - - # Mock streamlit functions - mock_button_col = MagicMock() - mock_columns = MagicMock(return_value=[MagicMock(), MagicMock(), MagicMock(), mock_button_col]) - - monkeypatch.setattr(st, "write", MagicMock()) - monkeypatch.setattr(st, "columns", mock_columns) - monkeypatch.setattr(st, "text_area", MagicMock()) - monkeypatch.setattr(st, "text_input", MagicMock()) + # Call evaluation_report - should NOT raise KeyError + testbed.evaluation_report(report=mock_report) - # Call qa_update_gui - testbed.qa_update_gui(qa_testset) + # Verify vector search info was displayed (backward compatibility worked) + calls = [str(call) for call in mock_markdown.call_args_list] + assert any("DEFAULT" in str(call) for call in calls), "Should display database info for vector search" + assert any("my_vs" in str(call) for call in calls), "Should display vector store info" - # Verify delete button is disabled - delete_button_call = mock_button_col.button.call_args - assert delete_button_call[1]["disabled"] is True + # Restore original dialog decorator and reload + monkeypatch.setattr(st, "dialog", original_dialog) + importlib.reload(testbed) - def test_qa_update_gui_navigation_buttons(self, monkeypatch): - """Test qa_update_gui navigation button states""" + def test_evaluation_report_fallback_vs_not_in_tools(self, monkeypatch): + """Test evaluation_report shows 'without Vector Search' when tools_enabled doesn't contain Vector Search""" from client.content import testbed - from streamlit import session_state as state - import streamlit as st - # Setup state at first record - state.testbed = {"qa_index": 0} - qa_testset = [ - {"question": "Q1", "reference_answer": "A1", "reference_context": "C1", "metadata": "M1"}, - {"question": "Q2", "reference_answer": "A2", "reference_context": "C2", "metadata": "M2"}, - ] + # Create report WITHOUT vector_search.enabled and WITHOUT Vector Search in tools_enabled + mock_report = { + "settings": { + "ll_model": { + "model": "gpt-4", + "temperature": 0.7, + "streaming": False, + "chat_history": False, + "max_input_tokens": 1000, + "max_tokens": 500, + }, + "testbed": {"judge_model": None}, + "tools_enabled": ["Other Tool"], # Vector Search NOT in tools_enabled + "vector_search": { + # NO "enabled" key - tests backward compatibility + "vector_store": "my_vs", + }, + }, + "correctness": 0.85, + "correct_by_topic": [], + "failures": [], + "report": [], + } # Mock streamlit functions - prev_col = MagicMock() - next_col = MagicMock() - mock_columns = MagicMock(return_value=[prev_col, next_col, MagicMock(), MagicMock()]) - - monkeypatch.setattr(st, "write", MagicMock()) - monkeypatch.setattr(st, "columns", mock_columns) - monkeypatch.setattr(st, "text_area", MagicMock()) - monkeypatch.setattr(st, "text_input", MagicMock()) - - # Call qa_update_gui - testbed.qa_update_gui(qa_testset) - - # Verify Previous button is disabled at first record - prev_button_call = prev_col.button.call_args - assert prev_button_call[1]["disabled"] is True - - # Verify Next button is enabled - next_button_call = next_col.button.call_args - assert next_button_call[1]["disabled"] is False - - -############################################################################# -# Test render_existing_testset_ui Function 
-############################################################################# -class TestRenderExistingTestsetUI: - """Test render_existing_testset_ui function""" - - def test_render_existing_testset_ui_database_with_selection(self, monkeypatch): - """Test render_existing_testset_ui correctly extracts testset_id when database test set is selected""" - from client.content import testbed - import streamlit as st - from streamlit import session_state as state - - # Mock session state - state.testbed_db_testsets = [ - {"tid": "test1", "name": "Test Set 1", "created": "2024-01-01 10:00:00"}, - {"tid": "test2", "name": "Test Set 2", "created": "2024-01-02 11:00:00"}, - ] - state.testbed = {"uploader_key": 1} - - # Mock streamlit components - mock_radio = MagicMock(return_value="Database") - mock_selectbox = MagicMock(return_value="Test Set 1 -- Created: 2024-01-01 10:00:00") - mock_file_uploader = MagicMock() - - monkeypatch.setattr(st, "radio", mock_radio) - monkeypatch.setattr(st, "selectbox", mock_selectbox) - monkeypatch.setattr(st, "file_uploader", mock_file_uploader) - - # Call the function - testset_sources = ["Database", "Local"] - source, endpoint, disabled, testset_id = testbed.render_existing_testset_ui(testset_sources) - - # Verify the return values - assert source == "Database", "Should return Database as source" - assert endpoint == "v1/testbed/testset_qa", "Should return correct endpoint for database" - assert disabled is False, "Button should not be disabled when test set is selected" - assert testset_id == "test1", f"Should extract correct testset_id 'test1', got {testset_id}" - - def test_render_existing_testset_ui_database_no_selection(self, monkeypatch): - """Test render_existing_testset_ui when no database test set is selected""" - from client.content import testbed - import streamlit as st - from streamlit import session_state as state - - # Mock session state - state.testbed_db_testsets = [ - {"tid": "test1", "name": "Test Set 1", "created": "2024-01-01 10:00:00"}, - ] - state.testbed = {"uploader_key": 1} - - # Mock streamlit components - mock_radio = MagicMock(return_value="Database") - mock_selectbox = MagicMock(return_value=None) # No selection - mock_file_uploader = MagicMock() - - monkeypatch.setattr(st, "radio", mock_radio) - monkeypatch.setattr(st, "selectbox", mock_selectbox) - monkeypatch.setattr(st, "file_uploader", mock_file_uploader) - - # Call the function - testset_sources = ["Database", "Local"] - source, endpoint, disabled, testset_id = testbed.render_existing_testset_ui(testset_sources) - - # Verify the return values - assert source == "Database", "Should return Database as source" - assert endpoint == "v1/testbed/testset_qa", "Should return correct endpoint" - assert disabled is True, "Button should be disabled when no test set is selected" - assert testset_id is None, "Should return None for testset_id when nothing selected" - - def test_render_existing_testset_ui_local_mode_no_files(self, monkeypatch): - """Test render_existing_testset_ui in Local mode with no files uploaded""" - from client.content import testbed import streamlit as st - from streamlit import session_state as state - # Mock session state - state.testbed = {"uploader_key": 1} - state.testbed_db_testsets = [] - - # Mock streamlit components - mock_radio = MagicMock(return_value="Local") - mock_selectbox = MagicMock() - mock_file_uploader = MagicMock(return_value=[]) # No files uploaded - - monkeypatch.setattr(st, "radio", mock_radio) - monkeypatch.setattr(st, "selectbox", 
mock_selectbox) - monkeypatch.setattr(st, "file_uploader", mock_file_uploader) - - # Call the function - testset_sources = ["Database", "Local"] - source, endpoint, disabled, testset_id = testbed.render_existing_testset_ui(testset_sources) - - # Verify the return values - assert source == "Local", "Should return Local as source" - assert endpoint == "v1/testbed/testset_load", "Should return correct endpoint for local" - assert disabled is True, "Button should be disabled when no files uploaded" - assert testset_id is None, "Should return None for testset_id in Local mode" + # Mock st.dialog decorator + original_dialog = st.dialog + mock_dialog = MagicMock(side_effect=lambda *args, **kwargs: lambda f: f) + monkeypatch.setattr(st, "dialog", mock_dialog) - def test_render_existing_testset_ui_local_mode_with_files(self, monkeypatch): - """Test render_existing_testset_ui in Local mode with files uploaded""" - from client.content import testbed - import streamlit as st - from streamlit import session_state as state + # Reload testbed to apply the mock decorator + import importlib - # Mock session state - state.testbed = {"uploader_key": 1} - state.testbed_db_testsets = [] + importlib.reload(testbed) - # Mock streamlit components - mock_radio = MagicMock(return_value="Local") - mock_selectbox = MagicMock() - mock_file_uploader = MagicMock(return_value=["file1.json", "file2.json"]) # Files uploaded + mock_markdown = MagicMock() + mock_columns = MagicMock(return_value=[MagicMock(), MagicMock(), MagicMock()]) - monkeypatch.setattr(st, "radio", mock_radio) - monkeypatch.setattr(st, "selectbox", mock_selectbox) - monkeypatch.setattr(st, "file_uploader", mock_file_uploader) + monkeypatch.setattr(st, "markdown", mock_markdown) + monkeypatch.setattr(st, "plotly_chart", MagicMock()) + monkeypatch.setattr(st, "subheader", MagicMock()) + monkeypatch.setattr(st, "dataframe", MagicMock()) + monkeypatch.setattr(st, "columns", mock_columns) - # Call the function - testset_sources = ["Database", "Local"] - source, endpoint, disabled, testset_id = testbed.render_existing_testset_ui(testset_sources) + # Call evaluation_report - should NOT raise KeyError + testbed.evaluation_report(report=mock_report) - # Verify the return values - assert source == "Local", "Should return Local as source" - assert endpoint == "v1/testbed/testset_load", "Should return correct endpoint for local" - assert disabled is False, "Button should be enabled when files are uploaded" - assert testset_id is None, "Should return None for testset_id in Local mode" + # Verify "without Vector Search" message was displayed + calls = [str(call) for call in mock_markdown.call_args_list] + assert any("without Vector Search" in str(call) for call in calls), ( + "Should display 'without Vector Search' when VS not enabled" + ) - def test_render_existing_testset_ui_with_multiple_testsets(self, monkeypatch): - """Test render_existing_testset_ui correctly identifies testset when multiple exist with same name""" - from client.content import testbed - import streamlit as st - from streamlit import session_state as state - - # Mock session state with multiple test sets (some with same name) - state.testbed_db_testsets = [ - {"tid": "test1", "name": "Production Tests", "created": "2024-01-01 10:00:00"}, - {"tid": "test2", "name": "Production Tests", "created": "2024-01-02 11:00:00"}, # Same name, different date - {"tid": "test3", "name": "Dev Tests", "created": "2024-01-03 12:00:00"}, - ] - state.testbed = {"uploader_key": 1} - - # Mock streamlit components - 
select the second "Production Tests" - mock_radio = MagicMock(return_value="Database") - mock_selectbox = MagicMock(return_value="Production Tests -- Created: 2024-01-02 11:00:00") - mock_file_uploader = MagicMock() - - monkeypatch.setattr(st, "radio", mock_radio) - monkeypatch.setattr(st, "selectbox", mock_selectbox) - monkeypatch.setattr(st, "file_uploader", mock_file_uploader) - - # Call the function - testset_sources = ["Database", "Local"] - _, _, disabled, testset_id = testbed.render_existing_testset_ui(testset_sources) - - # Verify it extracted the correct testset_id (test2, not test1) - assert testset_id == "test2", f"Should extract 'test2' for second Production Tests, got {testset_id}" - assert disabled is False, "Button should not be disabled" + # Restore original dialog decorator and reload + monkeypatch.setattr(st, "dialog", original_dialog) + importlib.reload(testbed) diff --git a/tests/unit/client/content/tools/tabs/test_split_embed_unit.py b/tests/unit/client/content/tools/tabs/test_split_embed_unit.py index fcfd6b9d..9a4c813d 100644 --- a/tests/unit/client/content/tools/tabs/test_split_embed_unit.py +++ b/tests/unit/client/content/tools/tabs/test_split_embed_unit.py @@ -124,6 +124,24 @@ def test_get_buckets_success(self, monkeypatch): assert isinstance(result, list) assert len(result) == 3 + def test_get_buckets_api_error(self, monkeypatch): + """Test get_buckets function when API call fails""" + from client.content.tools.tabs.split_embed import get_buckets + from client.utils import api_call + from client.utils.api_call import ApiError + from streamlit import session_state as state + + # Setup state with OCI config + state.client_settings = {"oci": {"auth_profile": "DEFAULT"}} + + def mock_get_with_error(endpoint): + raise ApiError("Access denied") + + monkeypatch.setattr(api_call, "get", mock_get_with_error) + + result = get_buckets("test-compartment") + assert result == ["No Access to Buckets in this Compartment"] + def test_get_bucket_objects_success(self, monkeypatch): """Test get_bucket_objects with successful API call""" from client.content.tools.tabs.split_embed import get_bucket_objects @@ -200,6 +218,24 @@ def test_files_data_frame_with_process(self): assert "Process" in result.columns assert bool(result["Process"][0]) is True + def test_files_data_editor(self, monkeypatch): + """Test files_data_editor function""" + from client.content.tools.tabs.split_embed import files_data_editor + import streamlit as st + + # Create test dataframe + test_df = pd.DataFrame({"File": ["file1.txt", "file2.txt"], "Process": [True, False]}) + + # Mock st.data_editor to return the input data + def mock_data_editor(data, **_kwargs): + return data + + monkeypatch.setattr(st, "data_editor", mock_data_editor) + + result = files_data_editor(test_df, key="test_key") + assert isinstance(result, pd.DataFrame) + assert len(result) == 2 + ############################################################################# # Test Chunk Size/Overlap Functions diff --git a/tests/unit/server/api/conftest.py b/tests/unit/server/api/conftest.py index 4a951486..9088301c 100644 --- a/tests/unit/server/api/conftest.py +++ b/tests/unit/server/api/conftest.py @@ -15,12 +15,6 @@ from unittest.mock import MagicMock, AsyncMock import pytest - -from common.schema import ( - DatabaseAuth, - DatabaseVectorStorage, - ChatRequest, -) # Import constants needed by fixtures in this file from shared_fixtures import ( TEST_DB_USER, @@ -28,6 +22,12 @@ TEST_DB_DSN, ) +from common.schema import ( + DatabaseAuth, + 
DatabaseVectorStorage, + ChatRequest, +) + @pytest.fixture def make_database_auth(): diff --git a/tests/unit/server/api/utils/test_utils_chat.py b/tests/unit/server/api/utils/test_utils_chat.py index 22f952a6..cb3b6dcc 100644 --- a/tests/unit/server/api/utils/test_utils_chat.py +++ b/tests/unit/server/api/utils/test_utils_chat.py @@ -298,15 +298,3 @@ async def mock_astream(**kwargs): pass assert captured_kwargs["config"]["metadata"]["streaming"] is True - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(utils_chat, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert utils_chat.logger.name == "api.utils.chat" diff --git a/tests/unit/server/api/utils/test_utils_databases.py b/tests/unit/server/api/utils/test_utils_databases.py index 5a2afb92..9e84f1c2 100644 --- a/tests/unit/server/api/utils/test_utils_databases.py +++ b/tests/unit/server/api/utils/test_utils_databases.py @@ -16,12 +16,12 @@ import pytest import oracledb +from db_fixtures import TEST_DB_CONFIG +from shared_fixtures import TEST_DB_WALLET_PASSWORD from common.schema import DatabaseSettings from server.api.utils import databases as utils_databases from server.api.utils.databases import DbException, ExistsDatabaseError, UnknownDatabaseError -from db_fixtures import TEST_DB_CONFIG -from shared_fixtures import TEST_DB_WALLET_PASSWORD class TestDbException: @@ -645,15 +645,3 @@ def test_get_vs_parses_genai_comment(self, db_connection): except oracledb.DatabaseError: pass cursor.close() - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(utils_databases, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert utils_databases.logger.name == "api.utils.database" diff --git a/tests/unit/server/api/utils/test_utils_embed.py b/tests/unit/server/api/utils/test_utils_embed.py index dbaf2c4d..968e7f09 100644 --- a/tests/unit/server/api/utils/test_utils_embed.py +++ b/tests/unit/server/api/utils/test_utils_embed.py @@ -791,15 +791,3 @@ def test_get_document_loader_txt(self, tmp_path): _, split = utils_embed._get_document_loader(str(test_file), "txt") assert split is True - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(utils_embed, "logger") - - def test_logger_name(self): - """Logger should have correct name.""" - assert utils_embed.logger.name == "api.utils.embed" diff --git a/tests/unit/server/api/utils/test_utils_mcp.py b/tests/unit/server/api/utils/test_utils_mcp.py index 93169c25..901e9930 100644 --- a/tests/unit/server/api/utils/test_utils_mcp.py +++ b/tests/unit/server/api/utils/test_utils_mcp.py @@ -10,9 +10,9 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from shared_fixtures import TEST_API_KEY, TEST_API_KEY_ALT from server.api.utils import mcp -from shared_fixtures import TEST_API_KEY, TEST_API_KEY_ALT class TestGetClient: @@ -180,15 +180,3 @@ async def test_list_prompts_closes_client_on_exception(self, mock_client_class): await mcp.list_prompts(mock_mcp_engine) mock_client.close.assert_called_once() - - -class TestLoggerConfiguration: - """Tests for logger configuration.""" - - def test_logger_exists(self): - """Logger should be configured.""" - assert hasattr(mcp, "logger") - - def 
test_logger_name(self):
-        """Logger should have correct name."""
-        assert mcp.logger.name == "api.utils.mcp"
diff --git a/tests/unit/server/api/utils/test_utils_models.py b/tests/unit/server/api/utils/test_utils_models.py
index 8616ca9e..2bb5b2b4 100644
--- a/tests/unit/server/api/utils/test_utils_models.py
+++ b/tests/unit/server/api/utils/test_utils_models.py
@@ -419,15 +419,3 @@ def test_process_model_entry_handles_exception(self, mock_litellm):
         result = utils_models._process_model_entry("bad-model", type_to_modes, allowed_modes, "openai")
 
         assert result == {"key": "bad-model"}
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured."""
-        assert hasattr(utils_models, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert utils_models.logger.name == "api.utils.models"
diff --git a/tests/unit/server/api/utils/test_utils_module_config.py b/tests/unit/server/api/utils/test_utils_module_config.py
new file mode 100644
index 00000000..f50a51f0
--- /dev/null
+++ b/tests/unit/server/api/utils/test_utils_module_config.py
@@ -0,0 +1,47 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+
+Consolidated tests for API utils module configuration (loggers).
+These parameterized tests replace individual boilerplate tests in each module file.
+"""
+
+import pytest
+
+from server.api.utils import chat as utils_chat
+from server.api.utils import databases as utils_databases
+from server.api.utils import embed as utils_embed
+from server.api.utils import mcp
+from server.api.utils import models as utils_models
+from server.api.utils import oci as utils_oci
+from server.api.utils import settings as utils_settings
+from server.api.utils import testbed as utils_testbed
+
+
+# Module configurations for parameterized tests
+API_UTILS_MODULES = [
+    pytest.param(utils_chat, "api.utils.chat", id="chat"),
+    pytest.param(utils_databases, "api.utils.database", id="databases"),
+    pytest.param(utils_embed, "api.utils.embed", id="embed"),
+    pytest.param(mcp, "api.utils.mcp", id="mcp"),
+    pytest.param(utils_models, "api.utils.models", id="models"),
+    pytest.param(utils_oci, "api.utils.oci", id="oci"),
+    pytest.param(utils_settings, "api.core.settings", id="settings"),
+    pytest.param(utils_testbed, "api.utils.testbed", id="testbed"),
+]
+
+
+class TestLoggerConfiguration:
+    """Parameterized tests for logger configuration across all API utils modules."""
+
+    @pytest.mark.parametrize("module,_logger_name", API_UTILS_MODULES)
+    def test_logger_exists(self, module, _logger_name):
+        """Each API utils module should have a logger configured."""
+        assert hasattr(module, "logger"), f"{module.__name__} should have 'logger'"
+
+    @pytest.mark.parametrize("module,expected_name", API_UTILS_MODULES)
+    def test_logger_name(self, module, expected_name):
+        """Each API utils module logger should have the correct name."""
+        assert module.logger.name == expected_name, (
+            f"{module.__name__} logger name should be '{expected_name}', got '{module.logger.name}'"
+        )
diff --git a/tests/unit/server/api/utils/test_utils_oci.py b/tests/unit/server/api/utils/test_utils_oci.py
index f59a855f..4473eec1 100644
--- a/tests/unit/server/api/utils/test_utils_oci.py
+++ b/tests/unit/server/api/utils/test_utils_oci.py
@@ -804,15 +804,3 @@ def test_get_raises_when_derived_profile_not_found(self, mock_oci, mock_settings
             utils_oci.get(client="test_client")
 
         assert "No settings found for client" in str(exc_info.value)
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured."""
-        assert hasattr(utils_oci, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert utils_oci.logger.name == "api.utils.oci"
diff --git a/tests/unit/server/api/utils/test_utils_settings.py b/tests/unit/server/api/utils/test_utils_settings.py
index 42299559..e9ba3d27 100644
--- a/tests/unit/server/api/utils/test_utils_settings.py
+++ b/tests/unit/server/api/utils/test_utils_settings.py
@@ -386,15 +386,3 @@ def test_read_config_from_json_file_success(self, mock_open, mock_access, mock_i
         result = utils_settings.read_config_from_json_file()
 
         assert result is not None
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured."""
-        assert hasattr(utils_settings, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert utils_settings.logger.name == "api.core.settings"
diff --git a/tests/unit/server/api/utils/test_utils_testbed.py b/tests/unit/server/api/utils/test_utils_testbed.py
index f68ab797..99834d1e 100644
--- a/tests/unit/server/api/utils/test_utils_testbed.py
+++ b/tests/unit/server/api/utils/test_utils_testbed.py
@@ -310,15 +310,3 @@ def test_process_report_success(self, mock_pickle, mock_execute, make_settings):
 
         assert result.eid == "eid123"
         assert result.correctness == 0.85
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured."""
-        assert hasattr(utils_testbed, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert utils_testbed.logger.name == "api.utils.testbed"
diff --git a/tests/unit/server/api/utils/test_utils_webscrape.py b/tests/unit/server/api/utils/test_utils_webscrape.py
index 33be5ad1..2e873f43 100644
--- a/tests/unit/server/api/utils/test_utils_webscrape.py
+++ b/tests/unit/server/api/utils/test_utils_webscrape.py
@@ -12,9 +12,9 @@
 
 import pytest
 from bs4 import BeautifulSoup
+from unit.server.api.conftest import create_mock_aiohttp_session
 
 from server.api.utils import webscrape
-from unit.server.api.conftest import create_mock_aiohttp_session
 
 
 class TestNormalizeWs:
diff --git a/tests/unit/server/api/v1/test_v1_chat.py b/tests/unit/server/api/v1/test_v1_chat.py
index 9c0b9b75..8952a196 100644
--- a/tests/unit/server/api/v1/test_v1_chat.py
+++ b/tests/unit/server/api/v1/test_v1_chat.py
@@ -228,31 +228,3 @@ async def test_chat_history_return_empty_history(self, mock_convert, mock_graph)
         result = await chat.chat_history_return(client="test_client")
 
         assert result == []
-
-
-class TestRouterConfiguration:
-    """Tests for router configuration."""
-
-    def test_auth_router_exists(self):
-        """The auth router should be defined."""
-        assert hasattr(chat, "auth")
-
-    def test_auth_router_has_routes(self):
-        """The auth router should have registered routes."""
-        routes = [route.path for route in chat.auth.routes]
-
-        assert "/completions" in routes
-        assert "/streams" in routes
-        assert "/history" in routes
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured."""
-        assert hasattr(chat, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert chat.logger.name == "endpoints.v1.chat"
diff --git a/tests/unit/server/api/v1/test_v1_databases.py b/tests/unit/server/api/v1/test_v1_databases.py
index cafe6ed1..07ec0c54 100644
--- a/tests/unit/server/api/v1/test_v1_databases.py
+++ b/tests/unit/server/api/v1/test_v1_databases.py
@@ -265,30 +265,3 @@ async def test_databases_update_disconnects_other_databases(
 
         # Verify: OTHER_DB is now disconnected
         assert other_db.connected is False
-
-
-class TestRouterConfiguration:
-    """Tests for router configuration."""
-
-    def test_auth_router_exists(self):
-        """The auth router should be defined."""
-        assert hasattr(databases, "auth")
-
-    def test_auth_router_has_routes(self):
-        """The auth router should have registered routes."""
-        routes = [route.path for route in databases.auth.routes]
-
-        assert "" in routes
-        assert "/{name}" in routes
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured."""
-        assert hasattr(databases, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert databases.logger.name == "endpoints.v1.databases"
diff --git a/tests/unit/server/api/v1/test_v1_embed.py b/tests/unit/server/api/v1/test_v1_embed.py
index fc431f54..42070cdc 100644
--- a/tests/unit/server/api/v1/test_v1_embed.py
+++ b/tests/unit/server/api/v1/test_v1_embed.py
@@ -16,11 +16,11 @@
 import pytest
 from fastapi import HTTPException, UploadFile
 from pydantic import HttpUrl
+from unit.server.api.conftest import create_mock_aiohttp_session
 
 from common.schema import DatabaseVectorStorage, VectorStoreRefreshRequest
 from server.api.v1 import embed
 from server.api.utils.databases import DbException
-from unit.server.api.conftest import create_mock_aiohttp_session
 
 
 @pytest.fixture
@@ -736,36 +736,3 @@ async def test_refresh_vector_store_raises_500_on_generic_exception(
 
         assert exc_info.value.status_code == 500
         assert "Embedding service unavailable" in exc_info.value.detail
-
-
-class TestRouterConfiguration:
-    """Tests for router configuration."""
-
-    def test_auth_router_exists(self):
-        """The auth router should be defined."""
-        assert hasattr(embed, "auth")
-
-    def test_auth_router_has_routes(self):
-        """The auth router should have registered routes."""
-        routes = [route.path for route in embed.auth.routes]
-
-        assert "/{vs}" in routes
-        assert "/{vs}/files" in routes
-        assert "/comment" in routes
-        assert "/sql/store" in routes
-        assert "/web/store" in routes
-        assert "/local/store" in routes
-        assert "/" in routes
-        assert "/refresh" in routes
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured."""
-        assert hasattr(embed, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert embed.logger.name == "api.v1.embed"
diff --git a/tests/unit/server/api/v1/test_v1_mcp.py b/tests/unit/server/api/v1/test_v1_mcp.py
index 14461014..e0270f84 100644
--- a/tests/unit/server/api/v1/test_v1_mcp.py
+++ b/tests/unit/server/api/v1/test_v1_mcp.py
@@ -11,9 +11,9 @@
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
+from shared_fixtures import TEST_API_KEY
 
 from server.api.v1 import mcp
-from shared_fixtures import TEST_API_KEY
 
 
 class TestGetMcp:
@@ -141,31 +141,3 @@ async def test_mcp_list_resources_returns_empty_list(self, mock_client_class, mo
         result = await mcp.mcp_list_resources(mcp_engine=mock_fastmcp)
 
         assert result == []
-
-
-class TestRouterConfiguration:
-    """Tests for router configuration."""
-
-    def test_auth_router_exists(self):
-        """The auth router should be defined."""
-        assert hasattr(mcp, "auth")
-
-    def test_auth_router_has_routes(self):
-        """The auth router should have registered routes."""
-        routes = [route.path for route in mcp.auth.routes]
-
-        assert "/client" in routes
-        assert "/tools" in routes
-        assert "/resources" in routes
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured."""
-        assert hasattr(mcp, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert mcp.logger.name == "api.v1.mcp"
diff --git a/tests/unit/server/api/v1/test_v1_mcp_prompts.py b/tests/unit/server/api/v1/test_v1_mcp_prompts.py
index 46c518c8..1a7a76dc 100644
--- a/tests/unit/server/api/v1/test_v1_mcp_prompts.py
+++ b/tests/unit/server/api/v1/test_v1_mcp_prompts.py
@@ -200,30 +200,3 @@ async def test_mcp_update_prompt_none_instructions(self, mock_fastmcp):
             await mcp_prompts.mcp_update_prompt(name="test-prompt", payload=payload, mcp_engine=mock_fastmcp)
 
         assert exc_info.value.status_code == 400
-
-
-class TestRouterConfiguration:
-    """Tests for router configuration."""
-
-    def test_auth_router_exists(self):
-        """The auth router should be defined."""
-        assert hasattr(mcp_prompts, "auth")
-
-    def test_auth_router_has_routes(self):
-        """The auth router should have registered routes."""
-        routes = [route.path for route in mcp_prompts.auth.routes]
-
-        assert "/prompts" in routes
-        assert "/prompts/{name}" in routes
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured."""
-        assert hasattr(mcp_prompts, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert mcp_prompts.logger.name == "api.v1.mcp_prompts"
diff --git a/tests/unit/server/api/v1/test_v1_models.py b/tests/unit/server/api/v1/test_v1_models.py
index 6a4f721e..2daeab63 100644
--- a/tests/unit/server/api/v1/test_v1_models.py
+++ b/tests/unit/server/api/v1/test_v1_models.py
@@ -224,31 +224,3 @@ async def test_models_delete_response_contains_message(self, mock_delete):
         body = json.loads(result.body)
 
         assert "openai/gpt-4" in body["message"]
-
-
-class TestRouterConfiguration:
-    """Tests for router configuration."""
-
-    def test_auth_router_exists(self):
-        """The auth router should be defined."""
-        assert hasattr(models, "auth")
-
-    def test_auth_router_has_routes(self):
-        """The auth router should have registered routes."""
-        routes = [route.path for route in models.auth.routes]
-
-        assert "" in routes
-        assert "/supported" in routes
-        assert "/{model_provider}/{model_id:path}" in routes
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured."""
-        assert hasattr(models, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert models.logger.name == "endpoints.v1.models"
diff --git a/tests/unit/server/api/v1/test_v1_module_config.py b/tests/unit/server/api/v1/test_v1_module_config.py
new file mode 100644
index 00000000..b69f091d
--- /dev/null
+++ b/tests/unit/server/api/v1/test_v1_module_config.py
@@ -0,0 +1,100 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+
+Consolidated tests for API v1 module configuration (routers and loggers).
+These parameterized tests replace individual boilerplate tests in each module file.
+"""
+
+import pytest
+
+from server.api.v1 import chat
+from server.api.v1 import databases
+from server.api.v1 import embed
+from server.api.v1 import mcp
+from server.api.v1 import mcp_prompts
+from server.api.v1 import models
+from server.api.v1 import oci
+from server.api.v1 import settings
+from server.api.v1 import testbed
+
+
+# Module configurations for parameterized tests
+API_V1_MODULES = [
+    pytest.param(chat, "endpoints.v1.chat", id="chat"),
+    pytest.param(databases, "endpoints.v1.databases", id="databases"),
+    pytest.param(embed, "api.v1.embed", id="embed"),
+    pytest.param(mcp, "api.v1.mcp", id="mcp"),
+    pytest.param(mcp_prompts, "api.v1.mcp_prompts", id="mcp_prompts"),
+    pytest.param(models, "endpoints.v1.models", id="models"),
+    pytest.param(oci, "endpoints.v1.oci", id="oci"),
+    pytest.param(settings, "endpoints.v1.settings", id="settings"),
+    pytest.param(testbed, "endpoints.v1.testbed", id="testbed"),
+]
+
+# Expected routes for each module
+MODULE_ROUTES = {
+    "chat": ["/completions", "/streams", "/history"],
+    "databases": ["", "/{name}"],
+    "embed": ["/{vs}", "/{vs}/files", "/comment", "/sql/store", "/web/store", "/local/store", "/", "/refresh"],
+    "mcp": ["/client", "/tools", "/resources"],
+    "mcp_prompts": ["/prompts", "/prompts/{name}"],
+    "models": ["", "/supported", "/{model_provider}/{model_id:path}"],
+    "oci": ["", "/{auth_profile}", "/regions/{auth_profile}", "/genai/{auth_profile}", "/compartments/{auth_profile}"],
+    "settings": ["", "/load/file", "/load/json"],
+    "testbed": [
+        "/testsets",
+        "/evaluations",
+        "/evaluation",
+        "/testset_qa",
+        "/testset_delete/{tid}",
+        "/testset_load",
+        "/testset_generate",
+        "/evaluate",
+    ],
+}
+
+
+class TestRouterConfiguration:
+    """Parameterized tests for router configuration across all API v1 modules."""
+
+    @pytest.mark.parametrize("module,_logger_name", API_V1_MODULES)
+    def test_auth_router_exists(self, module, _logger_name):
+        """Each API v1 module should have an auth router defined."""
+        assert hasattr(module, "auth"), f"{module.__name__} should have 'auth' router"
+
+    @pytest.mark.parametrize(
+        "module,expected_routes",
+        [
+            pytest.param(chat, MODULE_ROUTES["chat"], id="chat"),
+            pytest.param(databases, MODULE_ROUTES["databases"], id="databases"),
+            pytest.param(embed, MODULE_ROUTES["embed"], id="embed"),
+            pytest.param(mcp, MODULE_ROUTES["mcp"], id="mcp"),
+            pytest.param(mcp_prompts, MODULE_ROUTES["mcp_prompts"], id="mcp_prompts"),
+            pytest.param(models, MODULE_ROUTES["models"], id="models"),
+            pytest.param(oci, MODULE_ROUTES["oci"], id="oci"),
+            pytest.param(settings, MODULE_ROUTES["settings"], id="settings"),
+            pytest.param(testbed, MODULE_ROUTES["testbed"], id="testbed"),
+        ],
+    )
+    def test_auth_router_has_routes(self, module, expected_routes):
+        """Each API v1 module should have the expected routes registered."""
+        routes = [route.path for route in module.auth.routes]
+        for expected_route in expected_routes:
+            assert expected_route in routes, f"{module.__name__} missing route: {expected_route}"
+
+
+class TestLoggerConfiguration:
+    """Parameterized tests for logger configuration across all API v1 modules."""
+
+    @pytest.mark.parametrize("module,_logger_name", API_V1_MODULES)
+    def test_logger_exists(self, module, _logger_name):
+        """Each API v1 module should have a logger configured."""
+        assert hasattr(module, "logger"), f"{module.__name__} should have 'logger'"
+
+    @pytest.mark.parametrize("module,expected_name", API_V1_MODULES)
+    def test_logger_name(self, module, expected_name):
+        """Each API v1 module logger should have the correct name."""
+        assert module.logger.name == expected_name, (
+            f"{module.__name__} logger name should be '{expected_name}', got '{module.logger.name}'"
+        )
diff --git a/tests/unit/server/api/v1/test_v1_oci.py b/tests/unit/server/api/v1/test_v1_oci.py
index 4402e96c..ae5402fe 100644
--- a/tests/unit/server/api/v1/test_v1_oci.py
+++ b/tests/unit/server/api/v1/test_v1_oci.py
@@ -330,33 +330,3 @@ async def test_oci_create_genai_models_raises_on_oci_exception(
             await oci.oci_create_genai_models(auth_profile="DEFAULT")
 
         assert exc_info.value.status_code == 500
-
-
-class TestRouterConfiguration:
-    """Tests for router configuration."""
-
-    def test_auth_router_exists(self):
-        """The auth router should be defined."""
-        assert hasattr(oci, "auth")
-
-    def test_auth_router_has_routes(self):
-        """The auth router should have registered routes."""
-        routes = [route.path for route in oci.auth.routes]
-
-        assert "" in routes
-        assert "/{auth_profile}" in routes
-        assert "/regions/{auth_profile}" in routes
-        assert "/genai/{auth_profile}" in routes
-        assert "/compartments/{auth_profile}" in routes
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured."""
-        assert hasattr(oci, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert oci.logger.name == "endpoints.v1.oci"
diff --git a/tests/unit/server/api/v1/test_v1_settings.py b/tests/unit/server/api/v1/test_v1_settings.py
index 348613a1..15508c62 100644
--- a/tests/unit/server/api/v1/test_v1_settings.py
+++ b/tests/unit/server/api/v1/test_v1_settings.py
@@ -296,31 +296,3 @@ def test_incl_readonly_param_true(self):
         """_incl_readonly_param should return True when set."""
         result = settings._incl_readonly_param(incl_readonly=True)
         assert result is True
-
-
-class TestRouterConfiguration:
-    """Tests for router configuration."""
-
-    def test_auth_router_exists(self):
-        """The auth router should be defined."""
-        assert hasattr(settings, "auth")
-
-    def test_auth_router_has_routes(self):
-        """The auth router should have registered routes."""
-        routes = [route.path for route in settings.auth.routes]
-
-        assert "" in routes  # Get, Update, Create
-        assert "/load/file" in routes
-        assert "/load/json" in routes
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured."""
-        assert hasattr(settings, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert settings.logger.name == "endpoints.v1.settings"
diff --git a/tests/unit/server/api/v1/test_v1_testbed.py b/tests/unit/server/api/v1/test_v1_testbed.py
index a69afab4..ebc18dcd 100644
--- a/tests/unit/server/api/v1/test_v1_testbed.py
+++ b/tests/unit/server/api/v1/test_v1_testbed.py
@@ -609,15 +609,3 @@ async def test_testbed_evaluate_raises_500_on_correctness_key_error(
 
         assert exc_info.value.status_code == 500
         assert "correctness" in str(exc_info.value.detail)
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured."""
-        assert hasattr(testbed, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert testbed.logger.name == "endpoints.v1.testbed"
diff --git a/tests/unit/server/bootstrap/test_bootstrap_bootstrap.py b/tests/unit/server/bootstrap/test_bootstrap_bootstrap.py
index 542baee1..f4eb0be2 100644
--- a/tests/unit/server/bootstrap/test_bootstrap_bootstrap.py
+++ b/tests/unit/server/bootstrap/test_bootstrap_bootstrap.py
@@ -169,15 +169,3 @@ def test_stores_settings_results(self, make_settings):
             importlib.reload(bootstrap)
 
         assert len(bootstrap.SETTINGS_OBJECTS) == 2
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured in bootstrap module."""
-        assert hasattr(bootstrap, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert bootstrap.logger.name == "bootstrap"
diff --git a/tests/unit/server/bootstrap/test_bootstrap_configfile.py b/tests/unit/server/bootstrap/test_bootstrap_configfile.py
index 12c505ca..da15ec80 100644
--- a/tests/unit/server/bootstrap/test_bootstrap_configfile.py
+++ b/tests/unit/server/bootstrap/test_bootstrap_configfile.py
@@ -16,7 +16,6 @@
 
 import pytest
 
-from server.bootstrap import configfile
 from server.bootstrap.configfile import config_file_path
 
 
@@ -215,15 +214,3 @@ def test_config_file_path_parent_is_server_directory(self):
         path_obj = Path(path)
         # Should be under server/etc/configuration.json
         assert path_obj.parent.name == "etc"
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured in configfile module."""
-        assert hasattr(configfile, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert configfile.logger.name == "bootstrap.configfile"
diff --git a/tests/unit/server/bootstrap/test_bootstrap_databases.py b/tests/unit/server/bootstrap/test_bootstrap_databases.py
index f536aede..113eb65a 100644
--- a/tests/unit/server/bootstrap/test_bootstrap_databases.py
+++ b/tests/unit/server/bootstrap/test_bootstrap_databases.py
@@ -11,14 +11,14 @@
 import os
 
 import pytest
-
-from server.bootstrap import databases as databases_module
 from shared_fixtures import (
     assert_database_list_valid,
     assert_has_default_database,
     get_database_by_name,
 )
 
+from server.bootstrap import databases as databases_module
+
 
 @pytest.mark.usefixtures("reset_config_store", "clean_env")
 class TestDatabasesMain:
@@ -204,15 +204,3 @@ def test_main_callable_directly(self):
         """main() should be callable when running as script."""
         result = databases_module.main()
         assert result is not None
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured in databases module."""
-        assert hasattr(databases_module, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert databases_module.logger.name == "bootstrap.databases"
diff --git a/tests/unit/server/bootstrap/test_bootstrap_models.py b/tests/unit/server/bootstrap/test_bootstrap_models.py
index 8f8b56a7..d6f29c5b 100644
--- a/tests/unit/server/bootstrap/test_bootstrap_models.py
+++ b/tests/unit/server/bootstrap/test_bootstrap_models.py
@@ -12,9 +12,9 @@
 from unittest.mock import patch
 
 import pytest
+from shared_fixtures import assert_model_list_valid, get_model_by_id, TEST_API_KEY
 
 from server.bootstrap import models as models_module
-from shared_fixtures import assert_model_list_valid, get_model_by_id, TEST_API_KEY
 
 
 @pytest.mark.usefixtures("reset_config_store", "clean_env", "mock_is_url_accessible")
@@ -398,15 +398,3 @@ def test_main_callable_directly(self):
         """main() should be callable when running as script."""
         result = models_module.main()
         assert result is not None
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured in models module."""
-        assert hasattr(models_module, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert models_module.logger.name == "bootstrap.models"
diff --git a/tests/unit/server/bootstrap/test_bootstrap_module_config.py b/tests/unit/server/bootstrap/test_bootstrap_module_config.py
new file mode 100644
index 00000000..99953a69
--- /dev/null
+++ b/tests/unit/server/bootstrap/test_bootstrap_module_config.py
@@ -0,0 +1,43 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+
+Consolidated tests for bootstrap module configuration (loggers).
+These parameterized tests replace individual boilerplate tests in each module file.
+"""
+
+import pytest
+
+from server.bootstrap import bootstrap
+from server.bootstrap import configfile
+from server.bootstrap import databases as databases_module
+from server.bootstrap import models as models_module
+from server.bootstrap import oci as oci_module
+from server.bootstrap import settings as settings_module
+
+
+# Module configurations for parameterized tests
+BOOTSTRAP_MODULES = [
+    pytest.param(bootstrap, "bootstrap", id="bootstrap"),
+    pytest.param(configfile, "bootstrap.configfile", id="configfile"),
+    pytest.param(databases_module, "bootstrap.databases", id="databases"),
+    pytest.param(models_module, "bootstrap.models", id="models"),
+    pytest.param(oci_module, "bootstrap.oci", id="oci"),
+    pytest.param(settings_module, "bootstrap.settings", id="settings"),
+]
+
+
+class TestLoggerConfiguration:
+    """Parameterized tests for logger configuration across all bootstrap modules."""
+
+    @pytest.mark.parametrize("module,_logger_name", BOOTSTRAP_MODULES)
+    def test_logger_exists(self, module, _logger_name):
+        """Each bootstrap module should have a logger configured."""
+        assert hasattr(module, "logger"), f"{module.__name__} should have 'logger'"
+
+    @pytest.mark.parametrize("module,expected_name", BOOTSTRAP_MODULES)
+    def test_logger_name(self, module, expected_name):
+        """Each bootstrap module logger should have the correct name."""
+        assert module.logger.name == expected_name, (
+            f"{module.__name__} logger name should be '{expected_name}', got '{module.logger.name}'"
+        )
diff --git a/tests/unit/server/bootstrap/test_bootstrap_oci.py b/tests/unit/server/bootstrap/test_bootstrap_oci.py
index 89242c8e..09a0f332 100644
--- a/tests/unit/server/bootstrap/test_bootstrap_oci.py
+++ b/tests/unit/server/bootstrap/test_bootstrap_oci.py
@@ -315,15 +315,3 @@ def test_main_callable_directly(self):
         with patch("oci.config.from_file", side_effect=oci.exceptions.ConfigFileNotFound()):
             result = oci_module.main()
         assert result is not None
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured in oci module."""
-        assert hasattr(oci_module, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert oci_module.logger.name == "bootstrap.oci"
diff --git a/tests/unit/server/bootstrap/test_bootstrap_settings.py b/tests/unit/server/bootstrap/test_bootstrap_settings.py
index 5bac59b8..514f37b0 100644
--- a/tests/unit/server/bootstrap/test_bootstrap_settings.py
+++ b/tests/unit/server/bootstrap/test_bootstrap_settings.py
@@ -129,15 +129,3 @@ def test_main_callable_directly(self):
         # This tests the if __name__ == "__main__" block indirectly
         result = settings_module.main()
         assert result is not None
-
-
-class TestLoggerConfiguration:
-    """Tests for logger configuration."""
-
-    def test_logger_exists(self):
-        """Logger should be configured in settings module."""
-        assert hasattr(settings_module, "logger")
-
-    def test_logger_name(self):
-        """Logger should have correct name."""
-        assert settings_module.logger.name == "bootstrap.settings"

From f7d3589045ff32fe600f7796afcc3569acd8d43d Mon Sep 17 00:00:00 2001
From: gotsysdba
Date: Wed, 3 Dec 2025 08:29:34 +0000
Subject: [PATCH 15/20] Fix for when API server crashes

---
 src/client/content/chatbot.py                 |   4 +-
 src/client/content/testbed.py                 |   4 +-
 src/client/utils/api_call.py                  |  93 ++++---
 src/client/utils/st_common.py                 |  58 -----
 src/client/utils/tool_options.py              |  68 +++++
 .../unit/client/content/test_chatbot_unit.py  |  13 +-
 tests/unit/client/utils/test_api_call_unit.py | 232 ++++++++++++++++++
 7 files changed, 351 insertions(+), 121 deletions(-)
 create mode 100644 src/client/utils/tool_options.py
 create mode 100644 tests/unit/client/utils/test_api_call_unit.py

diff --git a/src/client/content/chatbot.py b/src/client/content/chatbot.py
index cd49ee9c..1a7c3271 100644
--- a/src/client/content/chatbot.py
+++ b/src/client/content/chatbot.py
@@ -16,7 +16,7 @@
 from streamlit import session_state as state
 
 from client.content.config.tabs.models import get_models
-from client.utils import st_common, api_call, client, vs_options
+from client.utils import st_common, api_call, client, vs_options, tool_options
 from client.utils.st_footer import render_chat_footer
 from common import logging_config
 
@@ -82,7 +82,7 @@ def setup_sidebar():
             st.stop()
 
     state.enable_client = True
-    st_common.tools_sidebar()
+    tool_options.tools_sidebar()
     st_common.history_sidebar()
     st_common.ll_sidebar()
     vs_options.vector_search_sidebar()
diff --git a/src/client/content/testbed.py b/src/client/content/testbed.py
index a716980b..6f020798 100644
--- a/src/client/content/testbed.py
+++ b/src/client/content/testbed.py
@@ -17,7 +17,7 @@
 
 from client.content.config.tabs.models import get_models
 
-from client.utils import st_common, api_call, vs_options
+from client.utils import st_common, api_call, vs_options, tool_options
 from common import logging_config
 
 
@@ -496,7 +496,7 @@ def render_evaluation_ui(available_ll_models: list) -> None:
     st.subheader("Q&A Evaluation", divider="red")
     st.info("Use the sidebar settings for chatbot evaluation parameters", icon="⬅️")
 
-    st_common.tools_sidebar()
+    tool_options.tools_sidebar()
     st_common.ll_sidebar()
     vs_options.vector_search_sidebar()
 
     st.write("Choose a model to judge the correctness of the chatbot answer, then start evaluation.")
diff --git a/src/client/utils/api_call.py b/src/client/utils/api_call.py
index 9a844e10..be33a115 100644
--- a/src/client/utils/api_call.py
+++ b/src/client/utils/api_call.py
@@ -42,19 +42,6 @@ def sanitize_sensitive_data(data):
     return data
 
 
-def _handle_json_response(response, method: str):
-    """Parse JSON response and handle parsing errors."""
-    try:
-        data = response.json()
-        logger.debug("%s Data: %s", method, data)
-        return response
-    except (json.JSONDecodeError, ValueError) as json_ex:
-        error_msg = f"Server returned invalid JSON response. Status: {response.status_code}"
-        logger.error("Response text: %s", response.text[:500])
-        error_msg += f". Response preview: {response.text[:200]}"
-        raise ApiError(error_msg) from json_ex
-
-
 def _handle_http_error(ex: requests.exceptions.HTTPError):
     """Extract error message from HTTP error response."""
     try:
@@ -66,6 +53,12 @@ def _handle_http_error(ex: requests.exceptions.HTTPError):
     return failure
 
 
+def _error_response(message: str) -> dict:
+    """Display error to user and return error dict."""
+    st.error(f"API Error: {message}")
+    return {"error": message}
+
+
 def send_request(
     method: str,
     endpoint: str,
@@ -75,30 +68,26 @@ def send_request(
     retries: int = 3,
     backoff_factor: float = 2.0,
 ) -> dict:
-    """Send API requests with retry logic."""
+    """Send API requests with retry logic. Returns JSON response or error dict."""
+    method_map = {"GET": requests.get, "POST": requests.post, "PATCH": requests.patch, "DELETE": requests.delete}
+    if method not in method_map:
+        return _error_response(f"Unsupported HTTP method: {method}")
+
     url = urljoin(f"{state.server['url']}:{state.server['port']}/", endpoint)
     payload = payload or {}
-    token = state.server["key"]
-    headers = {"Authorization": f"Bearer {token}"}
-    # Send client in header if it exists
+    headers = {"Authorization": f"Bearer {state.server['key']}"}
     if getattr(state, "client_settings", {}).get("client"):
         headers["Client"] = state.client_settings["client"]
 
-    method_map = {"GET": requests.get, "POST": requests.post, "PATCH": requests.patch, "DELETE": requests.delete}
-
-    if method not in method_map:
-        raise ApiError(f"Unsupported HTTP method: {method}")
-
-    args = {
+    args = {k: v for k, v in {
         "url": url,
         "headers": headers,
         "timeout": timeout,
        "params": params,
         "files": payload.get("files") if method == "POST" else None,
         "json": payload.get("json") if method in ["POST", "PATCH"] else None,
-    }
-    args = {k: v for k, v in args.items() if v is not None}
-    # Avoid logging out binary data in files
+    }.items() if v is not None}
+
     log_args = sanitize_sensitive_data(args.copy())
     try:
         if log_args.get("files"):
@@ -106,37 +95,40 @@ def send_request(
     except (ValueError, IndexError):
         pass
     logger.info("%s Request: %s", method, log_args)
+
+    result = None
     for attempt in range(retries + 1):
         try:
             response = method_map[method](**args)
             logger.info("%s Response: %s", method, response)
             response.raise_for_status()
-            return _handle_json_response(response, method)
+            result = response.json()
+            logger.debug("%s Data: %s", method, result)
+            break
 
         except requests.exceptions.HTTPError as ex:
             logger.error("HTTP Error: %s", ex)
-            raise ApiError(_handle_http_error(ex)) from ex
+            result = _error_response(_handle_http_error(ex))
+            break
 
         except requests.exceptions.ConnectionError as ex:
             logger.error("Attempt %d: Connection Error: %s", attempt + 1, ex)
             if attempt < retries:
-                sleep_time = backoff_factor * (2**attempt)
-                logger.info("Retrying in %.1f seconds...", sleep_time)
-                time.sleep(sleep_time)
+                time.sleep(backoff_factor * (2**attempt))
                 continue
-            raise ApiError(f"Connection failed after {retries + 1} attempts: {str(ex)}") from ex
+            result = _error_response(f"Connection failed after {retries + 1} attempts")
 
-        except requests.exceptions.RequestException as ex:
-            logger.error("Request Error: %s", ex)
-            raise ApiError(f"Request failed: {str(ex)}") from ex
+        except (requests.exceptions.RequestException, json.JSONDecodeError, ValueError) as ex:
+            logger.error("Request/JSON Error: %s", ex)
+            result = _error_response(f"Request failed: {str(ex)}")
+            break
 
-    raise ApiError("An unexpected error occurred.")
+    return result if result is not None else _error_response("An unexpected error occurred.")
 
 
-def get(endpoint: str, params: Optional[dict] = None, retries: int = 3, backoff_factor: float = 2.0) -> json:
+def get(endpoint: str, params: Optional[dict] = None, retries: int = 3, backoff_factor: float = 2.0) -> dict:
     """GET Requests"""
-    response = send_request("GET", endpoint, params=params, retries=retries, backoff_factor=backoff_factor)
-    return response.json()
+    return send_request("GET", endpoint, params=params, retries=retries, backoff_factor=backoff_factor)
 
 
 def post(
@@ -146,9 +138,9 @@ def post(
     timeout: int = 60,
     retries: int = 5,
     backoff_factor: float = 1.5,
-) -> json:
+) -> dict:
     """POST Requests"""
-    response = send_request(
+    return send_request(
         "POST",
         endpoint,
         params=params,
@@ -157,7 +149,6 @@ def post(
         retries=retries,
         backoff_factor=backoff_factor,
     )
-    return response.json()
 
 
 def patch(
@@ -168,9 +159,9 @@ def patch(
     retries: int = 5,
     backoff_factor: float = 1.5,
     toast=True,
-) -> None:
+) -> dict:
     """PATCH Requests"""
-    response = send_request(
+    result = send_request(
         "PATCH",
         endpoint,
         params=params,
@@ -179,16 +170,16 @@ def patch(
         retries=retries,
         backoff_factor=backoff_factor,
     )
-    if toast:
+    if toast and "error" not in result:
         st.toast("Update Successful.", icon="✅")
         time.sleep(1)
-    return response.json()
+    return result
 
 
-def delete(endpoint: str, timeout: int = 60, retries: int = 5, backoff_factor: float = 1.5, toast=True) -> None:
+def delete(endpoint: str, timeout: int = 60, retries: int = 5, backoff_factor: float = 1.5, toast=True) -> dict:
     """DELETE Requests"""
-    response = send_request("DELETE", endpoint, timeout=timeout, retries=retries, backoff_factor=backoff_factor)
-    success = response.json()["message"]
-    if toast:
-        st.toast(success, icon="✅")
+    result = send_request("DELETE", endpoint, timeout=timeout, retries=retries, backoff_factor=backoff_factor)
+    if toast and "error" not in result:
+        st.toast(result.get("message", "Deleted."), icon="✅")
         time.sleep(1)
+    return result
diff --git a/src/client/utils/st_common.py b/src/client/utils/st_common.py
index 0d0eff19..57aec03c 100644
--- a/src/client/utils/st_common.py
+++ b/src/client/utils/st_common.py
@@ -232,61 +232,3 @@ def ll_sidebar() -> None:
         key="selected_ll_model_presence_penalty",
         on_change=update_client_settings("ll_model"),
     )
-
-
-#####################################################
-# Tools Options
-#####################################################
-def tools_sidebar() -> None:
-    """Tools Sidebar Settings"""
-
-    # Setup Tool Box
-    state.tool_box = {
-        "LLM Only": {"description": "Do not use tools", "enabled": True},
-        "Vector Search": {"description": "Use AI with Unstructured Data", "enabled": True},
-        "NL2SQL": {"description": "Use AI with Structured Data", "enabled": True},
-    }
-
-    def _update_set_tool():
-        """Update user settings as to which tool is being used"""
-        state.client_settings["tools_enabled"] = [state.selected_tool]
-
-    def _disable_tool(tool: str, reason: str = None) -> None:
-        """Disable a tool in the tool box"""
-        if reason:
-            logger.debug("%s Disabled (%s)", tool, reason)
-            st.warning(f"{reason}. Disabling {tool}.", icon="⚠️")
-        state.tool_box[tool]["enabled"] = False
-
-    if not is_db_configured():
-        logger.debug("Vector Search/NL2SQL Disabled (Database not configured)")
-        st.warning("Database is not configured. Disabling Vector Search and NL2SQL tools.", icon="⚠️")
-        _disable_tool("Vector Search")
-        _disable_tool("NL2SQL")
-    else:
-        # Check to enable Vector Store
-        embed_models_enabled = enabled_models_lookup("embed")
-        db_alias = state.client_settings.get("database", {}).get("alias")
-        database_lookup = state_configs_lookup("database_configs", "name")
-        if not embed_models_enabled:
-            _disable_tool("Vector Search", "No embedding models are configured and/or enabled.")
-        elif not database_lookup[db_alias].get("vector_stores"):
-            _disable_tool("Vector Search", "Database has no vector stores.")
-        else:
-            # Check if any vector stores use an enabled embedding model
-            vector_stores = database_lookup[db_alias].get("vector_stores", [])
-            usable_vector_stores = [vs for vs in vector_stores if vs.get("model") in embed_models_enabled]
-            if not usable_vector_stores:
-                _disable_tool("Vector Search", "No vector stores match the enabled embedding models")
-
-    tool_box = [key for key, val in state.tool_box.items() if val["enabled"]]
-    current_tool = state.client_settings["tools_enabled"][0]
-    tool_index = tool_box.index(current_tool) if current_tool in tool_box else 0
-    st.sidebar.selectbox(
-        "Tool Selection",
-        tool_box,
-        index=tool_index,
-        label_visibility="collapsed",
-        on_change=_update_set_tool,
-        key="selected_tool",
-    )
diff --git a/src/client/utils/tool_options.py b/src/client/utils/tool_options.py
new file mode 100644
index 00000000..fb190ccb
--- /dev/null
+++ b/src/client/utils/tool_options.py
@@ -0,0 +1,68 @@
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+# spell-checker:ignore selectbox
+
+import streamlit as st
+from streamlit import session_state as state
+
+from client.utils import st_common
+from common import logging_config
+
+logger = logging_config.logging.getLogger("client.utils.tool_options")
+
+
+def tools_sidebar() -> None:
+    """Tools Sidebar Settings"""
+
+    # Setup Tool Box
+    state.tool_box = {
+        "LLM Only": {"description": "Do not use tools", "enabled": True},
+        "Vector Search": {"description": "Use AI with Unstructured Data", "enabled": True},
+        "NL2SQL": {"description": "Use AI with Structured Data", "enabled": True},
+    }
+
+    def _update_set_tool():
+        """Update user settings as to which tool is being used"""
+        state.client_settings["tools_enabled"] = [state.selected_tool]
+
+    def _disable_tool(tool: str, reason: str = None) -> None:
+        """Disable a tool in the tool box"""
+        if reason:
+            logger.debug("%s Disabled (%s)", tool, reason)
+            st.warning(f"{reason}. Disabling {tool}.", icon="⚠️")
+        state.tool_box[tool]["enabled"] = False
+
+    if not st_common.is_db_configured():
+        logger.debug("Vector Search/NL2SQL Disabled (Database not configured)")
+        st.warning("Database is not configured. Disabling Vector Search and NL2SQL tools.", icon="⚠️")
+        _disable_tool("Vector Search")
+        _disable_tool("NL2SQL")
+    else:
+        # Check to enable Vector Store
+        embed_models_enabled = st_common.enabled_models_lookup("embed")
+        db_alias = state.client_settings.get("database", {}).get("alias")
+        database_lookup = st_common.state_configs_lookup("database_configs", "name")
+        if not embed_models_enabled:
+            _disable_tool("Vector Search", "No embedding models are configured and/or enabled.")
+        elif not database_lookup[db_alias].get("vector_stores"):
+            _disable_tool("Vector Search", "Database has no vector stores.")
+        else:
+            # Check if any vector stores use an enabled embedding model
+            vector_stores = database_lookup[db_alias].get("vector_stores", [])
+            usable_vector_stores = [vs for vs in vector_stores if vs.get("model") in embed_models_enabled]
+            if not usable_vector_stores:
+                _disable_tool("Vector Search", "No vector stores match the enabled embedding models")
+
+    tool_box = [key for key, val in state.tool_box.items() if val["enabled"]]
+    current_tool = state.client_settings["tools_enabled"][0]
+    tool_index = tool_box.index(current_tool) if current_tool in tool_box else 0
+    st.sidebar.selectbox(
+        "Tool Selection",
+        tool_box,
+        index=tool_index,
+        label_visibility="collapsed",
+        on_change=_update_set_tool,
+        key="selected_tool",
+    )
diff --git a/tests/unit/client/content/test_chatbot_unit.py b/tests/unit/client/content/test_chatbot_unit.py
index a2683a23..def715a9 100644
--- a/tests/unit/client/content/test_chatbot_unit.py
+++ b/tests/unit/client/content/test_chatbot_unit.py
@@ -11,7 +11,6 @@
 
 import pytest
 
-
 #############################################################################
 # Test show_vector_search_refs Function
 #############################################################################
@@ -135,14 +134,14 @@ def test_setup_sidebar_no_models(self, monkeypatch):
     def test_setup_sidebar_with_models(self, monkeypatch):
         """Test setup_sidebar with enabled language models"""
         from client.content import chatbot
-        from client.utils import st_common, vs_options
+        from client.utils import st_common, vs_options, tool_options
         from streamlit import session_state as state
 
         # Mock enabled_models_lookup to return models
         monkeypatch.setattr(st_common, "enabled_models_lookup", lambda x: {"gpt-4": {}})
 
         # Mock sidebar functions
-        monkeypatch.setattr(st_common, "tools_sidebar", MagicMock())
+        monkeypatch.setattr(tool_options, "tools_sidebar", MagicMock())
         monkeypatch.setattr(st_common, "history_sidebar", MagicMock())
         monkeypatch.setattr(st_common, "ll_sidebar", MagicMock())
         monkeypatch.setattr(vs_options, "vector_search_sidebar", MagicMock())
@@ -159,7 +158,7 @@ def test_setup_sidebar_with_models(self, monkeypatch):
     def test_setup_sidebar_client_disabled(self, monkeypatch):
         """Test setup_sidebar when client gets disabled"""
         from client.content import chatbot
-        from client.utils import st_common, vs_options
+        from client.utils import st_common, vs_options, tool_options
         from streamlit import session_state as state
         import streamlit as st
@@ -169,7 +168,7 @@ def test_setup_sidebar_client_disabled(self, monkeypatch):
         def disable_client():
             state.enable_client = False
 
-        monkeypatch.setattr(st_common, "tools_sidebar", disable_client)
+        monkeypatch.setattr(tool_options, "tools_sidebar", disable_client)
         monkeypatch.setattr(st_common, "history_sidebar", MagicMock())
         monkeypatch.setattr(st_common, "ll_sidebar", MagicMock())
         monkeypatch.setattr(vs_options, "vector_search_sidebar", MagicMock())
@@ -219,9 +218,7 @@ def test_create_client_new(self, monkeypatch):
         assert state.user_client == mock_client_instance
 
         # Verify Client was called with correct parameters
-        mock_client_class.assert_called_once_with(
-            server=state.server, settings=state.client_settings, timeout=1200
-        )
+        mock_client_class.assert_called_once_with(server=state.server, settings=state.client_settings, timeout=1200)
 
     def test_create_client_existing(self):
         """Test getting existing client"""
diff --git a/tests/unit/client/utils/test_api_call_unit.py b/tests/unit/client/utils/test_api_call_unit.py
new file mode 100644
index 00000000..8a750ce2
--- /dev/null
+++ b/tests/unit/client/utils/test_api_call_unit.py
@@ -0,0 +1,232 @@
+# pylint: disable=protected-access,import-error,import-outside-toplevel
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+
+Unit tests for api_call module - focusing on graceful error handling when
+API server is disconnected or returns errors.
+"""
+# spell-checker: disable
+
+from unittest.mock import MagicMock
+
+import requests
+
+
+#############################################################################
+# Test Graceful Error Handling on Server Errors
+#############################################################################
+class TestGracefulErrorHandling:
+    """Test that API call functions handle server errors gracefully."""
+
+    def test_get_handles_http_500_gracefully(self, app_server, monkeypatch):
+        """Test that get() handles HTTP 500 errors gracefully without raising."""
+        assert app_server is not None
+
+        from client.utils import api_call
+        import streamlit as st
+
+        # Mock the requests.get to raise HTTPError with 500 status
+        mock_response = MagicMock()
+        mock_response.status_code = 500
+        mock_response.text = "Internal Server Error"
+        mock_response.json.return_value = {"detail": "Internal Server Error"}
+        mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
+            response=mock_response
+        )
+
+        mock_get = MagicMock(return_value=mock_response)
+        monkeypatch.setattr(requests, "get", mock_get)
+
+        # Mock st.error to capture error display
+        mock_error = MagicMock()
+        monkeypatch.setattr(st, "error", mock_error)
+
+        # Call get() - should NOT raise, should return error dict and show error
+        result = api_call.get(endpoint="v1/test", retries=0)
+
+        # Should return error dict (not raise)
+        assert "error" in result
+
+        # Should have shown error to user
+        assert mock_error.called
+
+    def test_delete_handles_http_500_gracefully(self, app_server, monkeypatch):
+        """Test that delete() handles HTTP 500 errors gracefully without raising."""
+        assert app_server is not None
+
+        from client.utils import api_call
+        import streamlit as st
+
+        # Mock the requests.delete to raise HTTPError with 500 status
+        mock_response = MagicMock()
+        mock_response.status_code = 500
+        mock_response.text = "Internal Server Error"
+        mock_response.json.return_value = {"detail": "Internal Server Error"}
+        mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
+            response=mock_response
+        )
+
+        mock_delete = MagicMock(return_value=mock_response)
+        monkeypatch.setattr(requests, "delete", mock_delete)
+
+        # Mock st.error to capture error display
+        mock_error = MagicMock()
+        monkeypatch.setattr(st, "error", mock_error)
+
+        # Call delete() - should NOT raise, should return error dict and show error
+        result = api_call.delete(endpoint="v1/test", retries=0, toast=False)
+
+        # Should return error dict (not raise)
+        assert "error" in result
+
+        # Should have shown error to user
+        assert mock_error.called
+
+    def test_post_handles_http_500_gracefully(self, app_server, monkeypatch):
+        """Test that post() handles HTTP 500 errors gracefully without raising."""
+        assert app_server is not None
+
+        from client.utils import api_call
+        import streamlit as st
+
+        # Mock the requests.post to raise HTTPError with 500 status
+        mock_response = MagicMock()
+        mock_response.status_code = 500
+        mock_response.text = "Internal Server Error"
+        mock_response.json.return_value = {"detail": "Internal Server Error"}
+        mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
+            response=mock_response
+        )
+
+        mock_post = MagicMock(return_value=mock_response)
+        monkeypatch.setattr(requests, "post", mock_post)
+
+        # Mock st.error to capture error display
+        mock_error = MagicMock()
+        monkeypatch.setattr(st, "error", mock_error)
+
+        # Call post() - should NOT raise, should return error dict and show error
+        result = api_call.post(endpoint="v1/test", retries=0)
+
+        # Should return error dict (not raise)
+        assert "error" in result
+
+        # Should have shown error to user
+        assert mock_error.called
+
+    def test_patch_handles_http_500_gracefully(self, app_server, monkeypatch):
+        """Test that patch() handles HTTP 500 errors gracefully without raising."""
+        assert app_server is not None
+
+        from client.utils import api_call
+        import streamlit as st
+
+        # Mock the requests.patch to raise HTTPError with 500 status
+        mock_response = MagicMock()
+        mock_response.status_code = 500
+        mock_response.text = "Internal Server Error"
+        mock_response.json.return_value = {"detail": "Internal Server Error"}
+        mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
+            response=mock_response
+        )
+
+        mock_patch = MagicMock(return_value=mock_response)
+        monkeypatch.setattr(requests, "patch", mock_patch)
+
+        # Mock st.error to capture error display
+        mock_error = MagicMock()
+        monkeypatch.setattr(st, "error", mock_error)
+
+        # Call patch() - should NOT raise, should return error dict and show error
+        result = api_call.patch(endpoint="v1/test", retries=0, toast=False)
+
+        # Should return error dict (not raise)
+        assert "error" in result
+
+        # Should have shown error to user
+        assert mock_error.called
+
+    def test_get_handles_connection_error_gracefully(self, app_server, monkeypatch):
+        """Test that get() handles connection errors gracefully after retries exhausted."""
+        assert app_server is not None
+
+        from client.utils import api_call
+        import streamlit as st
+
+        # Mock requests.get to raise ConnectionError
+        mock_get = MagicMock(side_effect=requests.exceptions.ConnectionError("Connection refused"))
+        monkeypatch.setattr(requests, "get", mock_get)
+
+        # Mock st.error to capture error display
+        mock_error = MagicMock()
+        monkeypatch.setattr(st, "error", mock_error)
+
+        # Call get() with no retries - should NOT raise, should return error dict and show error
+        result = api_call.get(endpoint="v1/test", retries=0)
+
+        # Should return error dict (not raise)
+        assert "error" in result
+
+        # Should have shown error to user
+        assert mock_error.called
+
+    def test_delete_handles_connection_error_gracefully(self, app_server, monkeypatch):
+        """Test that delete() handles connection errors gracefully after retries exhausted."""
+        assert app_server is not None
+
+        from client.utils import api_call
+        import streamlit as st
+
+        # Mock requests.delete to raise ConnectionError
+        mock_delete = MagicMock(side_effect=requests.exceptions.ConnectionError("Connection refused"))
+        monkeypatch.setattr(requests, "delete", mock_delete)
+
+        # Mock st.error to capture error display
+        mock_error = MagicMock()
+        monkeypatch.setattr(st, "error", mock_error)
+
+        # Call delete() with no retries - should NOT raise, should return error dict and show error
+        result = api_call.delete(endpoint="v1/test", retries=0, toast=False)
+
+        # Should return error dict (not raise)
+        assert "error" in result
+
+        # Should have shown error to user
+        assert mock_error.called
+
+
+#############################################################################
+# Test ApiError Class
+#############################################################################
+class TestApiError:
+    """Test ApiError exception class."""
+
+    def test_api_error_with_string_message(self, app_server):
+        """Test ApiError with string message."""
+        assert app_server is not None
+
+        from client.utils.api_call import ApiError
+
+        error = ApiError("Test error message")
+        assert str(error) == "Test error message"
+        assert error.message == "Test error message"
+
+    def test_api_error_with_dict_message(self, app_server):
+        """Test ApiError with dict message containing detail."""
+        assert app_server is not None
+
+        from client.utils.api_call import ApiError
+
+        error = ApiError({"detail": "Detailed error message"})
+        assert str(error) == "Detailed error message"
+        assert error.message == "Detailed error message"
+
+    def test_api_error_with_dict_no_detail(self, app_server):
+        """Test ApiError with dict message without detail key."""
+        assert app_server is not None
+
+        from client.utils.api_call import ApiError
+
+        error = ApiError({"error": "Some error"})
+        # Should convert dict to string
+        assert "error" in str(error)

From 6fb3874d2e92ea8db3def82ba547cc61ecec09c6 Mon Sep 17 00:00:00 2001
From: gotsysdba
Date: Wed, 3 Dec 2025 08:36:36 +0000
Subject: [PATCH 16/20] Fixed tool selection bug

---
 src/client/utils/tool_options.py              |   2 +
 .../client/utils/test_tool_options_unit.py    | 252 ++++++++++++++++++
 2 files changed, 254 insertions(+)
 create mode 100644 tests/unit/client/utils/test_tool_options_unit.py

diff --git a/src/client/utils/tool_options.py b/src/client/utils/tool_options.py
index fb190ccb..769ee983 100644
--- a/src/client/utils/tool_options.py
+++ b/src/client/utils/tool_options.py
@@ -57,6 +57,8 @@ def _disable_tool(tool: str, reason: str = None) -> None:
 
     tool_box = [key for key, val in state.tool_box.items() if val["enabled"]]
     current_tool = state.client_settings["tools_enabled"][0]
+    if current_tool not in tool_box:
+        state.client_settings["tools_enabled"] = ["LLM Only"]
     tool_index = tool_box.index(current_tool) if current_tool in tool_box else 0
     st.sidebar.selectbox(
         "Tool Selection",
diff --git a/tests/unit/client/utils/test_tool_options_unit.py b/tests/unit/client/utils/test_tool_options_unit.py
new file mode 100644
index 00000000..0e3779cc
--- /dev/null
+++ b/tests/unit/client/utils/test_tool_options_unit.py
@@ -0,0 +1,252 @@
+# pylint: disable=protected-access,import-error,import-outside-toplevel
+"""
+Copyright (c) 2024, 2025, Oracle and/or its affiliates.
+Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
+"""
+# spell-checker: disable
+
+from unittest.mock import MagicMock
+
+from streamlit import session_state as state
+
+
+#############################################################################
+# Test tools_sidebar Function
+#############################################################################
+class TestToolsSidebar:
+    """Test tools_sidebar function"""
+
+    def test_selected_tool_becomes_unavailable_resets_to_llm_only(self, app_server, monkeypatch):
+        """Test that when a previously selected tool becomes unavailable, it resets to LLM Only.
+
+        This tests the bug fix where a user selects Vector Search, then the database
+        disconnects, and the tool_box no longer contains Vector Search. Without the fix,
+        tool_box.index(current_tool) would raise ValueError.
+        """
+        assert app_server is not None
+
+        from client.utils import st_common, tool_options
+        import streamlit as st
+
+        # Setup: User had previously selected Vector Search
+        state.client_settings = {
+            "tools_enabled": ["Vector Search"],
+            "database": {"alias": "DEFAULT"},
+        }
+
+        # Mock: Database is not configured (makes Vector Search and NL2SQL unavailable)
+        monkeypatch.setattr(st_common, "is_db_configured", lambda: False)
+
+        # Mock Streamlit UI components
+        mock_warning = MagicMock()
+        mock_selectbox = MagicMock()
+        mock_sidebar = MagicMock()
+        mock_sidebar.selectbox = mock_selectbox
+
+        monkeypatch.setattr(st, "warning", mock_warning)
+        monkeypatch.setattr(st, "sidebar", mock_sidebar)
+
+        # Call tools_sidebar - this should reset to "LLM Only" instead of crashing
+        tool_options.tools_sidebar()
+
+        # Verify the settings were reset to LLM Only
+        assert state.client_settings["tools_enabled"] == ["LLM Only"]
+
+        # Verify selectbox was called with only LLM Only available
+        mock_selectbox.assert_called_once()
+        call_args = mock_selectbox.call_args
+        tool_box_arg = call_args[0][1]  # Second positional arg is the options list
+        assert tool_box_arg == ["LLM Only"]
+        assert call_args[1]["index"] == 0
+
+    def test_nl2sql_selected_becomes_unavailable_resets_to_llm_only(self, app_server, monkeypatch):
+        """Test that when NL2SQL was selected and database disconnects, it resets to LLM Only."""
+        assert app_server is not None
+
+        from client.utils import st_common, tool_options
+        import streamlit as st
+
+        # Setup: User had previously selected NL2SQL
+        state.client_settings = {
+            "tools_enabled": ["NL2SQL"],
+            "database": {"alias": "DEFAULT"},
+        }
+
+        # Mock: Database is not configured
+        monkeypatch.setattr(st_common, "is_db_configured", lambda: False)
+
+        # Mock Streamlit UI components
+        mock_sidebar = MagicMock()
+        monkeypatch.setattr(st, "warning", MagicMock())
+        monkeypatch.setattr(st, "sidebar", mock_sidebar)
+
+        # Call tools_sidebar
+        tool_options.tools_sidebar()
+
+        # Verify the settings were reset to LLM Only
+        assert state.client_settings["tools_enabled"] == ["LLM Only"]
+
+    def test_vector_search_disabled_no_embedding_models(self, app_server, monkeypatch):
+        """Test Vector Search is disabled when no embedding models are configured."""
+        assert app_server is not None
+
+        from client.utils import st_common, tool_options
+        import streamlit as st
+
+        # Setup: User selected Vector Search, but embedding models are disabled
+        state.client_settings = {
+            "tools_enabled": ["Vector Search"],
+            "database": {"alias": "DEFAULT"},
+        }
+        state.database_configs = [{"name": "DEFAULT", "vector_stores": [{"model": "embed-model"}]}]
+
+        # Mock: Database is configured but no embedding models enabled
+        monkeypatch.setattr(st_common, "is_db_configured", lambda: True)
+        monkeypatch.setattr(st_common, "enabled_models_lookup", lambda x: {})
+        monkeypatch.setattr(st_common, "state_configs_lookup", lambda *args: {
+            "DEFAULT": {"vector_stores": [{"model": "embed-model"}]}
+        })
+
+        # Mock Streamlit UI components
+        mock_sidebar = MagicMock()
+        monkeypatch.setattr(st, "warning", MagicMock())
+        monkeypatch.setattr(st, "sidebar", mock_sidebar)
+
+        # Call tools_sidebar
+        tool_options.tools_sidebar()
+
+        # Verify the settings were reset to LLM Only (Vector Search disabled)
+        assert state.client_settings["tools_enabled"] == ["LLM Only"]
+
+    def test_vector_search_disabled_no_matching_vector_stores(self, app_server, monkeypatch):
+        """Test Vector Search is disabled when vector stores don't match enabled embedding models."""
+        assert app_server is not None
+
+        from client.utils import st_common, tool_options
+        import streamlit as st
+
+        # Setup: User selected Vector Search
+        state.client_settings = {
+            "tools_enabled": ["Vector Search"],
+            "database": {"alias": "DEFAULT"},
+        }
+
+        # Mock: Database has vector stores but they use a different embedding model
+        monkeypatch.setattr(st_common, "is_db_configured", lambda: True)
+        monkeypatch.setattr(st_common, "enabled_models_lookup", lambda x: {"openai/text-embed-3": {}})
+        monkeypatch.setattr(st_common, "state_configs_lookup", lambda *args: {
+            "DEFAULT": {"vector_stores": [{"model": "cohere/embed-v3"}]}  # Different model
+        })
+
+        # Mock Streamlit UI components
+        mock_sidebar = MagicMock()
+        monkeypatch.setattr(st, "warning", MagicMock())
+        monkeypatch.setattr(st, "sidebar", mock_sidebar)
+
+        # Call tools_sidebar
+        tool_options.tools_sidebar()
+
+        # Verify the settings were reset to LLM Only
+        assert state.client_settings["tools_enabled"] == ["LLM Only"]
+
+    def test_all_tools_enabled_when_configured(self, app_server, monkeypatch):
+        """Test all tools remain enabled when properly configured."""
+        assert app_server is not None
+
+        from client.utils import st_common, tool_options
+        import streamlit as st
+
+        # Setup: User has Vector Search selected
+        state.client_settings = {
+            "tools_enabled": ["Vector Search"],
+            "database": {"alias": "DEFAULT"},
+        }
+
+        # Mock: Everything is properly configured
+        monkeypatch.setattr(st_common, "is_db_configured", lambda: True)
+        monkeypatch.setattr(st_common, "enabled_models_lookup", lambda x: {"openai/text-embed-3": {}})
+        monkeypatch.setattr(st_common, "state_configs_lookup", lambda *args: {
+            "DEFAULT": {"vector_stores": [{"model": "openai/text-embed-3"}]}
+        })
+
+        # Mock Streamlit UI components
+        mock_sidebar = MagicMock()
+        monkeypatch.setattr(st, "warning", MagicMock())
+        monkeypatch.setattr(st, "sidebar", mock_sidebar)
+
+        # Call tools_sidebar
+        tool_options.tools_sidebar()
+
+        # Verify the settings remain as Vector Search (not reset)
+        assert state.client_settings["tools_enabled"] == ["Vector Search"]
+
+        # Verify selectbox was called with all tools available
+        mock_sidebar.selectbox.assert_called_once()
+        call_args = mock_sidebar.selectbox.call_args
+        tool_box_arg = call_args[0][1]
+        assert "LLM Only" in tool_box_arg
+        assert "Vector Search" in tool_box_arg
+        assert "NL2SQL" in tool_box_arg
+
+    def test_llm_only_always_available(self, app_server, monkeypatch):
+        """Test that LLM Only is always in the tool box regardless of configuration."""
+        assert app_server is not None
+
+        from client.utils import st_common, tool_options
+        import streamlit as st
+
+        # Setup: LLM Only selected
+        state.client_settings = {
+            "tools_enabled": ["LLM Only"],
+            "database": {"alias": "DEFAULT"},
+        }
+
+        # Mock: Database not configured (disables other tools)
+        monkeypatch.setattr(st_common, "is_db_configured", lambda: False)
+
+        # Mock Streamlit UI components
+        mock_sidebar = MagicMock()
+        monkeypatch.setattr(st, "warning", MagicMock())
+        monkeypatch.setattr(st, "sidebar", mock_sidebar)
+
+        # Call tools_sidebar
+        tool_options.tools_sidebar()
+
+        # Verify LLM Only remains selected (no reset needed)
+        assert state.client_settings["tools_enabled"] == ["LLM Only"]
+
+        # Verify selectbox has LLM Only available
+        call_args = mock_sidebar.selectbox.call_args
+        tool_box_arg = call_args[0][1]
+        assert "LLM Only" in tool_box_arg
+
+    def test_vector_search_disabled_no_vector_stores(self, app_server, monkeypatch):
+        """Test Vector Search is disabled when database has no vector stores."""
+        assert app_server is not None
+
+        from client.utils import st_common, tool_options
+        import streamlit as st
+
+        # Setup: User selected Vector Search
+        state.client_settings = {
+            "tools_enabled": ["Vector Search"],
+            "database": {"alias": "DEFAULT"},
+        }
+
+        # Mock: Database configured but has no vector stores
+        monkeypatch.setattr(st_common, "is_db_configured", lambda: True)
+        monkeypatch.setattr(st_common, "enabled_models_lookup", lambda x: {"openai/text-embed-3": {}})
+        monkeypatch.setattr(st_common, "state_configs_lookup", lambda *args: {
+            "DEFAULT": {"vector_stores": []}  # Empty vector stores
+        })
+
+        # Mock Streamlit UI components
+        mock_sidebar = MagicMock()
+        monkeypatch.setattr(st, "warning", MagicMock())
+        monkeypatch.setattr(st, "sidebar", mock_sidebar)
+
+        # Call tools_sidebar
+        tool_options.tools_sidebar()
+
+        # Verify the settings were reset to LLM Only
+        assert state.client_settings["tools_enabled"] == ["LLM Only"]

From f9fdbc7b87d8b0e5f16781745cbf28bec23a5e0e Mon Sep 17 00:00:00 2001
From: gotsysdba
Date: Wed, 3 Dec 2025 10:11:50 +0000
Subject: [PATCH 17/20] update api_call

---
 .pylintrc                                     |   2 +-
 src/client/utils/api_call.py                  |  18 ++--
 tests/unit/client/utils/test_api_call_unit.py |  87 ++++++++++---------
 3 files changed, 56 insertions(+), 51 deletions(-)

diff --git a/.pylintrc b/.pylintrc
index e3218d31..69e62442 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -52,7 +52,7 @@ ignore=CVS,.venv
 # ignore-list. The regex matches against paths and can be in Posix or Windows
 # format. Because '\\' represents the directory delimiter on Windows systems,
 # it can't be used as an escape character.
-ignore-paths=.*[/\\]wip[/\\].*,src/client/mcp,docs/themes/relearn,docs/public,docs/static/demoware
+ignore-paths=.*[/\\]wip[/\\].*,src/client/mcp,docs/themes/relearn,docs/public,docs/static/demoware,src/server/agents
 
 # Files or directories matching the regular expression patterns are skipped.
The default value ignores diff --git a/src/client/utils/api_call.py b/src/client/utils/api_call.py index be33a115..3678d6ad 100644 --- a/src/client/utils/api_call.py +++ b/src/client/utils/api_call.py @@ -53,10 +53,10 @@ def _handle_http_error(ex: requests.exceptions.HTTPError): return failure -def _error_response(message: str) -> dict: - """Display error to user and return error dict.""" +def _error_response(message: str) -> None: + """Display error to user and raise ApiError.""" st.error(f"API Error: {message}") - return {"error": message} + raise ApiError(message) def send_request( @@ -108,20 +108,18 @@ def send_request( except requests.exceptions.HTTPError as ex: logger.error("HTTP Error: %s", ex) - result = _error_response(_handle_http_error(ex)) - break + _error_response(_handle_http_error(ex)) except requests.exceptions.ConnectionError as ex: logger.error("Attempt %d: Connection Error: %s", attempt + 1, ex) if attempt < retries: time.sleep(backoff_factor * (2**attempt)) continue - result = _error_response(f"Connection failed after {retries + 1} attempts") + _error_response(f"Connection failed after {retries + 1} attempts") except (requests.exceptions.RequestException, json.JSONDecodeError, ValueError) as ex: logger.error("Request/JSON Error: %s", ex) - result = _error_response(f"Request failed: {str(ex)}") - break + _error_response(f"Request failed: {str(ex)}") return result if result is not None else _error_response("An unexpected error occurred.") @@ -170,7 +168,7 @@ def patch( retries=retries, backoff_factor=backoff_factor, ) - if toast and "error" not in result: + if toast: st.toast("Update Successful.", icon="✅") time.sleep(1) return result @@ -179,7 +177,7 @@ def patch( def delete(endpoint: str, timeout: int = 60, retries: int = 5, backoff_factor: float = 1.5, toast=True) -> dict: """DELETE Requests""" result = send_request("DELETE", endpoint, timeout=timeout, retries=retries, backoff_factor=backoff_factor) - if toast and "error" not in result: + if toast: st.toast(result.get("message", "Deleted."), icon="✅") time.sleep(1) return result diff --git a/tests/unit/client/utils/test_api_call_unit.py b/tests/unit/client/utils/test_api_call_unit.py index 8a750ce2..708008d5 100644 --- a/tests/unit/client/utils/test_api_call_unit.py +++ b/tests/unit/client/utils/test_api_call_unit.py @@ -3,23 +3,24 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -Unit tests for api_call module - focusing on graceful error handling when +Unit tests for api_call module - focusing on error handling when API server is disconnected or returns errors. 
""" # spell-checker: disable from unittest.mock import MagicMock +import pytest import requests ############################################################################# -# Test Graceful Error Handling on Server Errors +# Test Error Handling Raises ApiError ############################################################################# -class TestGracefulErrorHandling: - """Test that API call functions handle server errors gracefully.""" +class TestErrorHandlingRaisesApiError: + """Test that API call functions raise ApiError on server errors.""" - def test_get_handles_http_500_gracefully(self, app_server, monkeypatch): - """Test that get() handles HTTP 500 errors gracefully without raising.""" + def test_get_raises_api_error_on_http_500(self, app_server, monkeypatch): + """Test that get() raises ApiError on HTTP 500 errors.""" assert app_server is not None from client.utils import api_call @@ -41,17 +42,18 @@ def test_get_handles_http_500_gracefully(self, app_server, monkeypatch): mock_error = MagicMock() monkeypatch.setattr(st, "error", mock_error) - # Call get() - should NOT raise, should return error dict and show error - result = api_call.get(endpoint="v1/test", retries=0) + # Call get() - should raise ApiError + with pytest.raises(api_call.ApiError) as exc_info: + api_call.get(endpoint="v1/test", retries=0) - # Should return error dict (not raise) - assert "error" in result + # Should have the error message + assert "Internal Server Error" in str(exc_info.value) # Should have shown error to user assert mock_error.called - def test_delete_handles_http_500_gracefully(self, app_server, monkeypatch): - """Test that delete() handles HTTP 500 errors gracefully without raising.""" + def test_delete_raises_api_error_on_http_500(self, app_server, monkeypatch): + """Test that delete() raises ApiError on HTTP 500 errors.""" assert app_server is not None from client.utils import api_call @@ -73,17 +75,18 @@ def test_delete_handles_http_500_gracefully(self, app_server, monkeypatch): mock_error = MagicMock() monkeypatch.setattr(st, "error", mock_error) - # Call delete() - should NOT raise, should return error dict and show error - result = api_call.delete(endpoint="v1/test", retries=0, toast=False) + # Call delete() - should raise ApiError + with pytest.raises(api_call.ApiError) as exc_info: + api_call.delete(endpoint="v1/test", retries=0, toast=False) - # Should return error dict (not raise) - assert "error" in result + # Should have the error message + assert "Internal Server Error" in str(exc_info.value) # Should have shown error to user assert mock_error.called - def test_post_handles_http_500_gracefully(self, app_server, monkeypatch): - """Test that post() handles HTTP 500 errors gracefully without raising.""" + def test_post_raises_api_error_on_http_500(self, app_server, monkeypatch): + """Test that post() raises ApiError on HTTP 500 errors.""" assert app_server is not None from client.utils import api_call @@ -105,17 +108,18 @@ def test_post_handles_http_500_gracefully(self, app_server, monkeypatch): mock_error = MagicMock() monkeypatch.setattr(st, "error", mock_error) - # Call post() - should NOT raise, should return error dict and show error - result = api_call.post(endpoint="v1/test", retries=0) + # Call post() - should raise ApiError + with pytest.raises(api_call.ApiError) as exc_info: + api_call.post(endpoint="v1/test", retries=0) - # Should return error dict (not raise) - assert "error" in result + # Should have the error message + assert "Internal Server Error" in 
str(exc_info.value)

         # Should have shown error to user
         assert mock_error.called

-    def test_patch_handles_http_500_gracefully(self, app_server, monkeypatch):
-        """Test that patch() handles HTTP 500 errors gracefully without raising."""
+    def test_patch_raises_api_error_on_http_500(self, app_server, monkeypatch):
+        """Test that patch() raises ApiError on HTTP 500 errors."""
         assert app_server is not None

         from client.utils import api_call

@@ -137,17 +141,18 @@
         mock_error = MagicMock()
         monkeypatch.setattr(st, "error", mock_error)

-        # Call patch() - should NOT raise, should return error dict and show error
-        result = api_call.patch(endpoint="v1/test", retries=0, toast=False)
+        # Call patch() - should raise ApiError
+        with pytest.raises(api_call.ApiError) as exc_info:
+            api_call.patch(endpoint="v1/test", retries=0, toast=False)

-        # Should return error dict (not raise)
-        assert "error" in result
+        # Should have the error message
+        assert "Internal Server Error" in str(exc_info.value)

         # Should have shown error to user
         assert mock_error.called

-    def test_get_handles_connection_error_gracefully(self, app_server, monkeypatch):
-        """Test that get() handles connection errors gracefully after retries exhausted."""
+    def test_get_raises_api_error_on_connection_error(self, app_server, monkeypatch):
+        """Test that get() raises ApiError on connection errors after retries exhausted."""
         assert app_server is not None

         from client.utils import api_call

@@ -161,17 +166,18 @@
         mock_error = MagicMock()
         monkeypatch.setattr(st, "error", mock_error)

-        # Call get() with no retries - should NOT raise, should return error dict and show error
-        result = api_call.get(endpoint="v1/test", retries=0)
+        # Call get() with no retries - should raise ApiError
+        with pytest.raises(api_call.ApiError) as exc_info:
+            api_call.get(endpoint="v1/test", retries=0)

-        # Should return error dict (not raise)
-        assert "error" in result
+        # Should have connection failure message
+        assert "Connection failed" in str(exc_info.value)

         # Should have shown error to user
         assert mock_error.called

-    def test_delete_handles_connection_error_gracefully(self, app_server, monkeypatch):
-        """Test that delete() handles connection errors gracefully after retries exhausted."""
+    def test_delete_raises_api_error_on_connection_error(self, app_server, monkeypatch):
+        """Test that delete() raises ApiError on connection errors after retries exhausted."""
         assert app_server is not None

         from client.utils import api_call

@@ -185,11 +191,12 @@
         mock_error = MagicMock()
         monkeypatch.setattr(st, "error", mock_error)

-        # Call delete() with no retries - should NOT raise, should return error dict and show error
-        result = api_call.delete(endpoint="v1/test", retries=0, toast=False)
+        # Call delete() with no retries - should raise ApiError
+        with pytest.raises(api_call.ApiError) as exc_info:
+            api_call.delete(endpoint="v1/test", retries=0, toast=False)

-        # Should return error dict (not raise)
-        assert "error" in result
+        # Should have connection failure message
+        assert "Connection failed" in str(exc_info.value)

         # Should have shown error to user
         assert mock_error.called

From b0ba81336bc6f65173bd2b05372c624856c276c8 Mon Sep 17 00:00:00 2001
From: gotsysdba
Date: Wed, 3 Dec 2025 13:13:35 +0000
Subject: [PATCH 18/20] ignores

---
tests/integration/client/content/test_api_server.py | 2 +-
 tests/integration/common/test_functions.py          | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/integration/client/content/test_api_server.py b/tests/integration/client/content/test_api_server.py
index 85cb7b11..c080d280 100644
--- a/tests/integration/client/content/test_api_server.py
+++ b/tests/integration/client/content/test_api_server.py
@@ -1,9 +1,9 @@
-# pylint: disable=protected-access,import-error,import-outside-toplevel
 """
 Copyright (c) 2024, 2025, Oracle and/or its affiliates.
 Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
 """
 # spell-checker: disable
+# pylint: disable=protected-access,import-error,import-outside-toplevel


 #############################################################################

diff --git a/tests/integration/common/test_functions.py b/tests/integration/common/test_functions.py
index 05c2a10b..93b8054e 100644
--- a/tests/integration/common/test_functions.py
+++ b/tests/integration/common/test_functions.py
@@ -7,6 +7,8 @@
 Tests functions that interact with external systems (URLs, databases).
 These tests may require network access or database connectivity.
 """
+# spell-checker: disable
+# pylint: disable=protected-access,import-error,import-outside-toplevel

 import os
 import tempfile

From 310ed3daf7d540fd0a2214d3ef1f2289690bb702 Mon Sep 17 00:00:00 2001
From: gotsysdba
Date: Thu, 4 Dec 2025 10:22:26 +0000
Subject: [PATCH 19/20] Change testing URL

---
 tests/integration/common/test_functions.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/integration/common/test_functions.py b/tests/integration/common/test_functions.py
index 93b8054e..8ccf913f 100644
--- a/tests/integration/common/test_functions.py
+++ b/tests/integration/common/test_functions.py
@@ -25,8 +25,8 @@ class TestIsUrlAccessibleIntegration:
     @pytest.mark.integration
     def test_real_accessible_url(self):
         """is_url_accessible should return True for known accessible URLs."""
-        # Using httpbin.org which is a testing service
-        result, msg = functions.is_url_accessible("https://httpbin.org/status/200")
+        # Using example.com - an IANA-maintained domain reserved for testing/documentation
+        result, msg = functions.is_url_accessible("https://example.com")
         assert result is True
         assert msg is None

From 931e73187d06d80773b2c0a61cad5ecd4a680d87 Mon Sep 17 00:00:00 2001
From: gotsysdba
Date: Thu, 4 Dec 2025 10:26:23 +0000
Subject: [PATCH 20/20] Update CodeQL to v4

---
 .github/workflows/image_smoke.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/image_smoke.yml b/.github/workflows/image_smoke.yml
index 077b1afb..55d06494 100644
--- a/.github/workflows/image_smoke.yml
+++ b/.github/workflows/image_smoke.yml
@@ -92,7 +92,7 @@ jobs:
       # Upload security results to GitHub Security tab
       - name: Upload Trivy Results to GitHub Security
         if: matrix.build.name == 'aio'
-        uses: github/codeql-action/upload-sarif@v3
+        uses: github/codeql-action/upload-sarif@v4
         with:
          sarif_file: trivy-results-aio.sarif
          category: trivy-aio
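
Note on the PATCH 17 contract change: the api_call helpers now raise ApiError instead of returning an {"error": ...} dict, which is why patch() and delete() dropped their '"error" not in result' guards and the unit tests switched to pytest.raises. A minimal caller sketch of the new contract follows. It is illustrative only: it assumes ApiError is the exception class already defined in client.utils.api_call (referenced but not defined in these hunks), and "v1/settings" plus load_settings are hypothetical names, not part of the patch series.

    # Sketch only - not from the patch series. Assumes ApiError is exported
    # by client.utils.api_call; "v1/settings" is a hypothetical endpoint.
    from client.utils import api_call

    def load_settings() -> dict:
        try:
            # On HTTP or connection failure, send_request() surfaces the
            # message via st.error() and then raises ApiError.
            return api_call.get(endpoint="v1/settings", retries=2)
        except api_call.ApiError:
            # The error is already shown in the UI; the caller just picks
            # a fallback instead of probing the old sentinel dict.
            return {}

Raising keeps failure handling explicit at each call site, while _error_response() still reports the message to the Streamlit UI before the raise.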