diff --git a/libs/elasticsearch/README.md b/libs/elasticsearch/README.md index f56cbad..4d7df5d 100644 --- a/libs/elasticsearch/README.md +++ b/libs/elasticsearch/README.md @@ -119,7 +119,7 @@ A caching layer for LLMs that uses Elasticsearch. Simple example: ```python -from langchain.globals import set_llm_cache +from langchain_core.globals import set_llm_cache from langchain_elasticsearch import ElasticsearchCache @@ -151,7 +151,7 @@ The new cache class can be applied also to a pre-existing cache index: import json from typing import Any, Dict, List -from langchain.globals import set_llm_cache +from langchain_core.globals import set_llm_cache from langchain_core.caches import RETURN_VAL_TYPE from langchain_elasticsearch import ElasticsearchCache diff --git a/libs/elasticsearch/pyproject.toml b/libs/elasticsearch/pyproject.toml index 1b3efe2..44ebe57 100644 --- a/libs/elasticsearch/pyproject.toml +++ b/libs/elasticsearch/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-elasticsearch" -version = "0.4.0" +version = "0.5.0" description = "An integration package connecting Elasticsearch and LangChain" authors = [] readme = "README.md" @@ -12,8 +12,8 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.10,<4.0" -langchain-core = "^0.3.0" -elasticsearch = {version = ">=8.16.0,<9.0.0", extras = ["vectorstore_mmr"]} +langchain-core = ">=0.3.0,<2.0.0" +elasticsearch = {version = ">=8.16.0,<9.0.0", extras = ["vectorstore_mmr"]} # Pin <9.0 to match ES 8.x server [tool.poetry.group.test] optional = true @@ -24,8 +24,9 @@ freezegun = "^1.2.2" pytest-mock = "^3.10.0" syrupy = "^4.0.2" pytest-watcher = "^0.3.4" -pytest-asyncio = "^0.21.1" -langchain = ">=0.3.10,<1.0.0" +pytest-asyncio = "^0.23.0" +langchain = ">=1.0.0,<2.0.0" +langchain-classic = ">=1.0.0,<2.0.0" aiohttp = "^3.8.3" [tool.poetry.group.codespell] @@ -74,18 +75,7 @@ requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] -# --strict-markers will raise errors on unknown marks. -# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks -# -# https://docs.pytest.org/en/7.1.x/reference/reference.html -# --strict-config any warnings encountered while parsing the `pytest` -# section of the configuration file raise errors. -# -# https://github.com/tophat/syrupy -# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite. addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5" -# Registering custom markers. -# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers markers = [ "requires: mark tests as requiring a specific library", "asyncio: mark tests as requiring asyncio", diff --git a/libs/elasticsearch/tests/integration_tests/_async/test_cache.py b/libs/elasticsearch/tests/integration_tests/_async/test_cache.py index 8feda80..b466879 100644 --- a/libs/elasticsearch/tests/integration_tests/_async/test_cache.py +++ b/libs/elasticsearch/tests/integration_tests/_async/test_cache.py @@ -1,9 +1,9 @@ -from typing import AsyncGenerator, Dict, Union +import json +from typing import Any, AsyncGenerator, Dict, List, Union import pytest from elasticsearch.helpers import BulkIndexError -from langchain.embeddings.cache import _value_serializer -from langchain.globals import set_llm_cache +from langchain_core.globals import set_llm_cache from langchain_core.language_models import BaseChatModel from langchain_elasticsearch import ( @@ -14,6 +14,49 @@ from ._test_utilities import clear_test_indices, create_es_client, read_env +def _value_serializer(value: List[float]) -> bytes: + """Serialize embedding values to bytes (replaces private langchain function).""" + return json.dumps(value).encode() + + +@pytest.fixture(autouse=True) +async def _close_async_caches( + monkeypatch: pytest.MonkeyPatch, +) -> AsyncGenerator[None, None]: + """Ensure cache clients close cleanly to avoid aiohttp warnings.""" + created_clients: List = [] + + original_cache_init = AsyncElasticsearchCache.__init__ + original_store_init = AsyncElasticsearchEmbeddingsCache.__init__ + + def wrapped_cache_init(self, *args: Any, **kwargs: Any) -> None: + original_cache_init(self, *args, **kwargs) + created_clients.append(self._es_client) + + def wrapped_store_init(self, *args: Any, **kwargs: Any) -> None: + original_store_init(self, *args, **kwargs) + created_clients.append(self._es_client) + + monkeypatch.setattr(AsyncElasticsearchCache, "__init__", wrapped_cache_init) + monkeypatch.setattr( + AsyncElasticsearchEmbeddingsCache, "__init__", wrapped_store_init + ) + try: + yield + finally: + for client in created_clients: + close = getattr(client, "close", None) + if close: + try: + await close() + except Exception: + pass + monkeypatch.setattr(AsyncElasticsearchCache, "__init__", original_cache_init) + monkeypatch.setattr( + AsyncElasticsearchEmbeddingsCache, "__init__", original_store_init + ) + + @pytest.fixture async def es_env_fx() -> Union[dict, AsyncGenerator]: params = read_env() diff --git a/libs/elasticsearch/tests/integration_tests/_async/test_chat_history.py b/libs/elasticsearch/tests/integration_tests/_async/test_chat_history.py index 91e5dbe..ff22681 100644 --- a/libs/elasticsearch/tests/integration_tests/_async/test_chat_history.py +++ b/libs/elasticsearch/tests/integration_tests/_async/test_chat_history.py @@ -2,8 +2,15 @@ from typing import AsyncIterator import pytest -from langchain.memory import ConversationBufferMemory +from elasticsearch import AsyncElasticsearch from langchain_core.messages import AIMessage, HumanMessage, message_to_dict +from langchain_classic.memory import ConversationBufferMemory + +pytestmark = [ + pytest.mark.filterwarnings( + "ignore:Please see the migration guide.*:langchain_core._api.deprecation.LangChainDeprecationWarning" + ) +] from langchain_elasticsearch.chat_history import AsyncElasticsearchChatMessageHistory @@ -23,11 +30,11 @@ class TestElasticsearch: @pytest.fixture - async def elasticsearch_connection(self) -> AsyncIterator[dict]: + async def elasticsearch_connection(self) -> AsyncIterator[AsyncElasticsearch]: params = read_env() es = create_es_client(params) - yield params + yield es await clear_test_indices(es) await es.close() @@ -38,12 +45,14 @@ def index_name(self) -> str: return f"test_{uuid.uuid4().hex}" async def test_memory_with_message_store( - self, elasticsearch_connection: dict, index_name: str + self, elasticsearch_connection: AsyncElasticsearch, index_name: str ) -> None: """Test the memory with a message store.""" # setup Elasticsearch as a message store message_history = AsyncElasticsearchChatMessageHistory( - **elasticsearch_connection, index=index_name, session_id="test-session" + es_connection=elasticsearch_connection, + index=index_name, + session_id="test-session", ) memory = ConversationBufferMemory( diff --git a/libs/elasticsearch/tests/integration_tests/_async/test_retrievers.py b/libs/elasticsearch/tests/integration_tests/_async/test_retrievers.py index f54c8e0..5c6d10f 100644 --- a/libs/elasticsearch/tests/integration_tests/_async/test_retrievers.py +++ b/libs/elasticsearch/tests/integration_tests/_async/test_retrievers.py @@ -69,7 +69,7 @@ async def test_user_agent_header( ), f"The string '{user_agent}' does not match the expected pattern." await index_test_data(es_client, index_name, "text") - await retriever.aget_relevant_documents("foo") + await retriever.ainvoke("foo") search_request = es_client.transport.requests[-1] # type: ignore[attr-defined] user_agent = search_request["headers"]["User-Agent"] @@ -133,7 +133,7 @@ def body_func(query: str) -> Dict: ) await index_test_data(es_client, index_name, text_field) - result = await retriever.aget_relevant_documents("foo") + result = await retriever.ainvoke("foo") assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"} assert {r.metadata["_id"] for r in result} == {"3", "1", "5"} @@ -171,7 +171,7 @@ def body_func(query: str) -> Dict: await index_test_data(es_client, index_name_1, text_field_1) await index_test_data(es_client, index_name_2, text_field_2) - result = await retriever.aget_relevant_documents("foo") + result = await retriever.ainvoke("foo") # matches from both indices assert sorted([(r.page_content, r.metadata["_index"]) for r in result]) == [ @@ -206,7 +206,7 @@ def id_as_content(hit: Mapping[str, Any]) -> Document: ) await index_test_data(es_client, index_name, text_field) - result = await retriever.aget_relevant_documents("foo") + result = await retriever.ainvoke("foo") assert [r.page_content for r in result] == ["3", "1", "5"] assert [r.metadata for r in result] == [meta, meta, meta] diff --git a/libs/elasticsearch/tests/integration_tests/_async/test_vectorstores.py b/libs/elasticsearch/tests/integration_tests/_async/test_vectorstores.py index e778e17..ed24fde 100644 --- a/libs/elasticsearch/tests/integration_tests/_async/test_vectorstores.py +++ b/libs/elasticsearch/tests/integration_tests/_async/test_vectorstores.py @@ -14,6 +14,11 @@ from ._test_utilities import clear_test_indices, create_es_client, read_env logging.basicConfig(level=logging.DEBUG) +pytestmark = [ + pytest.mark.filterwarnings( + "ignore:Deprecated field \\[rank\\] used, replaced by \\[retriever\\]:elasticsearch.ElasticsearchWarning" + ) +] """ cd tests/integration_tests @@ -27,6 +32,28 @@ class TestElasticsearch: + @pytest.fixture(autouse=True) + async def _close_async_stores(self, monkeypatch: pytest.MonkeyPatch) -> AsyncIterator[None]: + created: list[AsyncElasticsearchStore] = [] + original_init = AsyncElasticsearchStore.__init__ + + def wrapped_init(self, *args: Any, **kwargs: Any) -> None: # type: ignore[misc] + original_init(self, *args, **kwargs) + created.append(self) + + monkeypatch.setattr(AsyncElasticsearchStore, "__init__", wrapped_init) + try: + yield + finally: + for store in created: + aclose = getattr(store, "aclose", None) + if aclose: + try: + await aclose() + except Exception: + pass + monkeypatch.setattr(AsyncElasticsearchStore, "__init__", original_init) + @pytest.fixture async def es_params(self) -> AsyncIterator[dict]: params = read_env() @@ -104,7 +131,7 @@ async def test_search_with_relevance_threshold( search_type="similarity_score_threshold", search_kwargs={"score_threshold": similarity_of_second_ranked}, ) - output = await retriever.aget_relevant_documents(query=query_string) + output = await retriever.ainvoke(query_string) assert output == [ top3[0][0], @@ -145,7 +172,7 @@ async def test_search_by_vector_with_relevance_threshold( search_type="similarity_score_threshold", search_kwargs={"score_threshold": similarity_of_second_ranked}, ) - output = await retriever.aget_relevant_documents(query=query_string) + output = await retriever.ainvoke(query_string) assert output == [ top3[0][0], @@ -1081,7 +1108,7 @@ async def test_elasticsearch_with_relevance_threshold( search_type="similarity_score_threshold", search_kwargs={"score_threshold": similarity_of_second_ranked}, ) - output = await retriever.aget_relevant_documents(query=query_string) + output = await retriever.ainvoke(query_string) assert output == [ top3[0][0], diff --git a/libs/elasticsearch/tests/integration_tests/_sync/test_cache.py b/libs/elasticsearch/tests/integration_tests/_sync/test_cache.py index c393cfd..d788163 100644 --- a/libs/elasticsearch/tests/integration_tests/_sync/test_cache.py +++ b/libs/elasticsearch/tests/integration_tests/_sync/test_cache.py @@ -1,9 +1,9 @@ -from typing import Dict, Generator, Union +import json +from typing import Dict, Generator, List, Union import pytest from elasticsearch.helpers import BulkIndexError -from langchain.embeddings.cache import _value_serializer -from langchain.globals import set_llm_cache +from langchain_core.globals import set_llm_cache from langchain_core.language_models import BaseChatModel from langchain_elasticsearch import ( @@ -11,9 +11,15 @@ ElasticsearchEmbeddingsCache, ) + from ._test_utilities import clear_test_indices, create_es_client, read_env +def _value_serializer(value: List[float]) -> bytes: + """Serialize embedding values to bytes (replaces private langchain function).""" + return json.dumps(value).encode() + + @pytest.fixture def es_env_fx() -> Union[dict, Generator]: params = read_env() diff --git a/libs/elasticsearch/tests/integration_tests/_sync/test_chat_history.py b/libs/elasticsearch/tests/integration_tests/_sync/test_chat_history.py index ec9793b..d26f17a 100644 --- a/libs/elasticsearch/tests/integration_tests/_sync/test_chat_history.py +++ b/libs/elasticsearch/tests/integration_tests/_sync/test_chat_history.py @@ -2,8 +2,14 @@ from typing import Iterator import pytest -from langchain.memory import ConversationBufferMemory from langchain_core.messages import AIMessage, HumanMessage, message_to_dict +from langchain_classic.memory import ConversationBufferMemory + +pytestmark = [ + pytest.mark.filterwarnings( + "ignore:Please see the migration guide.*:langchain_core._api.deprecation.LangChainDeprecationWarning" + ) +] from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory diff --git a/libs/elasticsearch/tests/integration_tests/_sync/test_retrievers.py b/libs/elasticsearch/tests/integration_tests/_sync/test_retrievers.py index 457b1da..3d068b5 100644 --- a/libs/elasticsearch/tests/integration_tests/_sync/test_retrievers.py +++ b/libs/elasticsearch/tests/integration_tests/_sync/test_retrievers.py @@ -65,7 +65,7 @@ def test_user_agent_header(self, es_client: Elasticsearch, index_name: str) -> N ), f"The string '{user_agent}' does not match the expected pattern." index_test_data(es_client, index_name, "text") - retriever.get_relevant_documents("foo") + retriever.invoke("foo") search_request = es_client.transport.requests[-1] # type: ignore[attr-defined] user_agent = search_request["headers"]["User-Agent"] @@ -127,7 +127,7 @@ def body_func(query: str) -> Dict: ) index_test_data(es_client, index_name, text_field) - result = retriever.get_relevant_documents("foo") + result = retriever.invoke("foo") assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"} assert {r.metadata["_id"] for r in result} == {"3", "1", "5"} @@ -165,7 +165,7 @@ def body_func(query: str) -> Dict: index_test_data(es_client, index_name_1, text_field_1) index_test_data(es_client, index_name_2, text_field_2) - result = retriever.get_relevant_documents("foo") + result = retriever.invoke("foo") # matches from both indices assert sorted([(r.page_content, r.metadata["_index"]) for r in result]) == [ @@ -198,7 +198,7 @@ def id_as_content(hit: Mapping[str, Any]) -> Document: ) index_test_data(es_client, index_name, text_field) - result = retriever.get_relevant_documents("foo") + result = retriever.invoke("foo") assert [r.page_content for r in result] == ["3", "1", "5"] assert [r.metadata for r in result] == [meta, meta, meta] diff --git a/libs/elasticsearch/tests/integration_tests/_sync/test_vectorstores.py b/libs/elasticsearch/tests/integration_tests/_sync/test_vectorstores.py index 0ef9c8a..430f5c1 100644 --- a/libs/elasticsearch/tests/integration_tests/_sync/test_vectorstores.py +++ b/libs/elasticsearch/tests/integration_tests/_sync/test_vectorstores.py @@ -14,6 +14,11 @@ from ._test_utilities import clear_test_indices, create_es_client, read_env logging.basicConfig(level=logging.DEBUG) +pytestmark = [ + pytest.mark.filterwarnings( + "ignore:Deprecated field \\[rank\\] used, replaced by \\[retriever\\]:elasticsearch.ElasticsearchWarning" + ) +] """ cd tests/integration_tests @@ -104,7 +109,7 @@ def test_search_with_relevance_threshold( search_type="similarity_score_threshold", search_kwargs={"score_threshold": similarity_of_second_ranked}, ) - output = retriever.get_relevant_documents(query=query_string) + output = retriever.invoke(query_string) assert output == [ top3[0][0], @@ -145,7 +150,7 @@ def test_search_by_vector_with_relevance_threshold( search_type="similarity_score_threshold", search_kwargs={"score_threshold": similarity_of_second_ranked}, ) - output = retriever.get_relevant_documents(query=query_string) + output = retriever.invoke(query_string) assert output == [ top3[0][0], @@ -1061,7 +1066,7 @@ def test_elasticsearch_with_relevance_threshold( search_type="similarity_score_threshold", search_kwargs={"score_threshold": similarity_of_second_ranked}, ) - output = retriever.get_relevant_documents(query=query_string) + output = retriever.invoke(query_string) assert output == [ top3[0][0], diff --git a/libs/elasticsearch/tests/unit_tests/_async/test_cache.py b/libs/elasticsearch/tests/unit_tests/_async/test_cache.py index fa83ceb..c201354 100644 --- a/libs/elasticsearch/tests/unit_tests/_async/test_cache.py +++ b/libs/elasticsearch/tests/unit_tests/_async/test_cache.py @@ -1,5 +1,6 @@ +import json from datetime import datetime -from typing import Any, Dict +from typing import Any, Dict, List from unittest import mock from unittest.mock import ANY, MagicMock, patch @@ -7,7 +8,6 @@ from _pytest.fixtures import FixtureRequest from elastic_transport import ApiResponseMeta, HttpHeaders, NodeConfig from elasticsearch import NotFoundError -from langchain.embeddings.cache import _value_serializer from langchain_core.load import dumps from langchain_core.outputs import Generation @@ -17,6 +17,11 @@ ) +def _value_serializer(value: List[float]) -> bytes: + """Serialize embedding values to bytes (replaces private langchain function).""" + return json.dumps(value).encode() + + def serialize_encode_vector(vector: Any) -> str: return AsyncElasticsearchEmbeddingsCache.encode_vector(_value_serializer(vector)) diff --git a/libs/elasticsearch/tests/unit_tests/_sync/test_cache.py b/libs/elasticsearch/tests/unit_tests/_sync/test_cache.py index 03ca7b5..7942943 100644 --- a/libs/elasticsearch/tests/unit_tests/_sync/test_cache.py +++ b/libs/elasticsearch/tests/unit_tests/_sync/test_cache.py @@ -1,5 +1,6 @@ +import json from datetime import datetime -from typing import Any, Dict +from typing import Any, Dict, List from unittest import mock from unittest.mock import ANY, MagicMock, patch @@ -7,7 +8,6 @@ from _pytest.fixtures import FixtureRequest from elastic_transport import ApiResponseMeta, HttpHeaders, NodeConfig from elasticsearch import NotFoundError -from langchain.embeddings.cache import _value_serializer from langchain_core.load import dumps from langchain_core.outputs import Generation @@ -17,6 +17,11 @@ ) +def _value_serializer(value: List[float]) -> bytes: + """Serialize embedding values to bytes (replaces private langchain function).""" + return json.dumps(value).encode() + + def serialize_encode_vector(vector: Any) -> str: return ElasticsearchEmbeddingsCache.encode_vector(_value_serializer(vector))