33
44import copy
55import json
6- import re
76import uuid
87from typing import Any , Callable , Iterable , Optional , Sequence
98
@@ -175,7 +174,8 @@ async def create(
175174 stmt = "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = :table_name AND table_schema = :schema_name"
176175 async with engine ._pool .connect () as conn :
177176 result = await conn .execute (
178- text (stmt ), {"table_name" : table_name , "schema_name" : schema_name }
177+ text (stmt ),
178+ {"table_name" : table_name , "schema_name" : schema_name },
179179 )
180180 result_map = result .mappings ()
181181 results = result_map .fetchall ()
@@ -535,7 +535,7 @@ async def __query_collection(
535535 embedding : list [float ],
536536 * ,
537537 k : Optional [int ] = None ,
538- filter : Optional [dict ] | Optional [ str ] = None ,
538+ filter : Optional [dict ] = None ,
539539 ** kwargs : Any ,
540540 ) -> Sequence [RowMapping ]:
541541 """Perform similarity search query on database."""
@@ -553,16 +553,22 @@ async def __query_collection(
553553
554554 column_names = ", " .join (f'"{ col } "' for col in columns )
555555
556+ safe_filter = None
557+ filter_dict = None
556558 if filter and isinstance (filter , dict ):
557- filter = self ._create_filter_clause (filter )
558- filter = f"WHERE { filter } " if filter else ""
559+ safe_filter , filter_dict = self ._create_filter_clause (filter )
560+ param_filter = f"WHERE { safe_filter } " if safe_filter else ""
559561 inline_embed_func = getattr (self .embedding_service , "embed_query_inline" , None )
560562 if not embedding and callable (inline_embed_func ) and "query" in kwargs :
561563 query_embedding = self .embedding_service .embed_query_inline (kwargs ["query" ]) # type: ignore
562564 else :
563565 query_embedding = f"{ [float (dimension ) for dimension in embedding ]} "
564- stmt = f'SELECT { column_names } , { search_function } ("{ self .embedding_column } ", :query_embedding) as distance FROM "{ self .schema_name } "."{ self .table_name } " { filter } ORDER BY "{ self .embedding_column } " { operator } :query_embedding LIMIT :k;'
566+ stmt = f"""SELECT { column_names } , { search_function } ("{ self .embedding_column } ", :query_embedding) as distance
567+ FROM "{ self .schema_name } "."{ self .table_name } " { param_filter } ORDER BY "{ self .embedding_column } " { operator } :query_embedding LIMIT :k;
568+ """
565569 param_dict = {"query_embedding" : query_embedding , "k" : k }
570+ if filter_dict :
571+ param_dict .update (filter_dict )
566572 if self .index_query_options :
567573 async with self .engine .connect () as conn :
568574 # Set each query option individually
@@ -583,7 +589,7 @@ async def asimilarity_search(
583589 self ,
584590 query : str ,
585591 k : Optional [int ] = None ,
586- filter : Optional [dict ] | Optional [ str ] = None ,
592+ filter : Optional [dict ] = None ,
587593 ** kwargs : Any ,
588594 ) -> list [Document ]:
589595 """Return docs selected by similarity search on query."""
@@ -614,7 +620,7 @@ async def asimilarity_search_with_score(
614620 self ,
615621 query : str ,
616622 k : Optional [int ] = None ,
617- filter : Optional [dict ] | Optional [ str ] = None ,
623+ filter : Optional [dict ] = None ,
618624 ** kwargs : Any ,
619625 ) -> list [tuple [Document , float ]]:
620626 """Return docs and distance scores selected by similarity search on query."""
@@ -635,7 +641,7 @@ async def asimilarity_search_by_vector(
635641 self ,
636642 embedding : list [float ],
637643 k : Optional [int ] = None ,
638- filter : Optional [dict ] | Optional [ str ] = None ,
644+ filter : Optional [dict ] = None ,
639645 ** kwargs : Any ,
640646 ) -> list [Document ]:
641647 """Return docs selected by vector similarity search."""
@@ -649,7 +655,7 @@ async def asimilarity_search_with_score_by_vector(
649655 self ,
650656 embedding : list [float ],
651657 k : Optional [int ] = None ,
652- filter : Optional [dict ] | Optional [ str ] = None ,
658+ filter : Optional [dict ] = None ,
653659 ** kwargs : Any ,
654660 ) -> list [tuple [Document , float ]]:
655661 """Return docs and distance scores selected by vector similarity search."""
@@ -685,7 +691,7 @@ async def amax_marginal_relevance_search(
685691 k : Optional [int ] = None ,
686692 fetch_k : Optional [int ] = None ,
687693 lambda_mult : Optional [float ] = None ,
688- filter : Optional [dict ] | Optional [ str ] = None ,
694+ filter : Optional [dict ] = None ,
689695 ** kwargs : Any ,
690696 ) -> list [Document ]:
691697 """Return docs selected using the maximal marginal relevance."""
@@ -706,7 +712,7 @@ async def amax_marginal_relevance_search_by_vector(
706712 k : Optional [int ] = None ,
707713 fetch_k : Optional [int ] = None ,
708714 lambda_mult : Optional [float ] = None ,
709- filter : Optional [dict ] | Optional [ str ] = None ,
715+ filter : Optional [dict ] = None ,
710716 ** kwargs : Any ,
711717 ) -> list [Document ]:
712718 """Return docs selected using the maximal marginal relevance."""
@@ -729,7 +735,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
729735 k : Optional [int ] = None ,
730736 fetch_k : Optional [int ] = None ,
731737 lambda_mult : Optional [float ] = None ,
732- filter : Optional [dict ] | Optional [ str ] = None ,
738+ filter : Optional [dict ] = None ,
733739 ** kwargs : Any ,
734740 ) -> list [tuple [Document , float ]]:
735741 """Return docs and distance scores selected using the maximal marginal relevance."""
@@ -834,7 +840,7 @@ async def is_valid_index(
834840 ) -> bool :
835841 """Check if index exists in the table."""
836842 index_name = index_name or self .table_name + DEFAULT_INDEX_NAME_SUFFIX
837- query = f """
843+ query = """
838844 SELECT tablename, indexname
839845 FROM pg_indexes
840846 WHERE tablename = :table_name AND schemaname = :schema_name AND indexname = :index_name;
@@ -898,7 +904,7 @@ def _handle_field_filter(
898904 * ,
899905 field : str ,
900906 value : Any ,
901- ) -> str :
907+ ) -> tuple [ str , dict ] :
902908 """Create a filter for a specific field.
903909
904910 Args:
@@ -951,15 +957,17 @@ def _handle_field_filter(
951957 if operator in COMPARISONS_TO_NATIVE :
952958 # Then we implement an equality filter
953959 # native is trusted input
954- if isinstance (filter_value , str ):
955- filter_value = f"'{ filter_value } '"
956960 native = COMPARISONS_TO_NATIVE [operator ]
957- return f"({ field } { native } { filter_value } )"
961+ id = str (uuid .uuid4 ()).split ("-" )[0 ]
962+ return f"{ field } { native } :{ field } _{ id } " , {f"{ field } _{ id } " : filter_value }
958963 elif operator == "$between" :
959964 # Use AND with two comparisons
960965 low , high = filter_value
961966
962- return f"({ field } BETWEEN { low } AND { high } )"
967+ return f"({ field } BETWEEN :{ field } _low AND :{ field } _high)" , {
968+ f"{ field } _low" : low ,
969+ f"{ field } _high" : high ,
970+ }
963971 elif operator in {"$in" , "$nin" , "$like" , "$ilike" }:
964972 # We'll do force coercion to text
965973 if operator in {"$in" , "$nin" }:
@@ -975,15 +983,15 @@ def _handle_field_filter(
975983 )
976984
977985 if operator in {"$in" }:
978- values = str (tuple (val for val in filter_value ))
979- return f"({ field } IN { values } )"
986+ return f"{ field } = ANY(:{ field } _in)" , {f"{ field } _in" : filter_value }
980987 elif operator in {"$nin" }:
981- values = str (tuple (val for val in filter_value ))
982- return f"({ field } NOT IN { values } )"
988+ return f"{ field } <> ALL (:{ field } _nin)" , {f"{ field } _nin" : filter_value }
983989 elif operator in {"$like" }:
984- return f"({ field } LIKE ' { filter_value } ')"
990+ return f"({ field } LIKE : { field } _like)" , { f" { field } _like" : filter_value }
985991 elif operator in {"$ilike" }:
986- return f"({ field } ILIKE '{ filter_value } ')"
992+ return f"({ field } ILIKE :{ field } _ilike)" , {
993+ f"{ field } _ilike" : filter_value
994+ }
987995 else :
988996 raise NotImplementedError ()
989997 elif operator == "$exists" :
@@ -994,13 +1002,13 @@ def _handle_field_filter(
9941002 )
9951003 else :
9961004 if filter_value :
997- return f"({ field } IS NOT NULL)"
1005+ return f"({ field } IS NOT NULL)" , {}
9981006 else :
999- return f"({ field } IS NULL)"
1007+ return f"({ field } IS NULL)" , {}
10001008 else :
10011009 raise NotImplementedError ()
10021010
1003- def _create_filter_clause (self , filters : Any ) -> str :
1011+ def _create_filter_clause (self , filters : Any ) -> tuple [ str , dict ] :
10041012 """Create LangChain filter representation to matching SQL where clauses
10051013
10061014 Args:
@@ -1037,7 +1045,11 @@ def _create_filter_clause(self, filters: Any) -> str:
10371045 op = key [1 :].upper () # Extract the operator
10381046 filter_clause = [self ._create_filter_clause (el ) for el in value ]
10391047 if len (filter_clause ) > 1 :
1040- return f"({ f' { op } ' .join (filter_clause )} )"
1048+ all_clauses = [clause [0 ] for clause in filter_clause ]
1049+ params = {}
1050+ for clause in filter_clause :
1051+ params .update (clause [1 ])
1052+ return f"({ f' { op } ' .join (all_clauses )} )" , params
10411053 elif len (filter_clause ) == 1 :
10421054 return filter_clause [0 ]
10431055 else :
@@ -1050,11 +1062,15 @@ def _create_filter_clause(self, filters: Any) -> str:
10501062 not_conditions = [
10511063 self ._create_filter_clause (item ) for item in value
10521064 ]
1053- not_stmts = [f"NOT { condition } " for condition in not_conditions ]
1054- return f"({ ' AND ' .join (not_stmts )} )"
1065+ all_clauses = [clause [0 ] for clause in not_conditions ]
1066+ params = {}
1067+ for clause in not_conditions :
1068+ params .update (clause [1 ])
1069+ not_stmts = [f"NOT { condition } " for condition in all_clauses ]
1070+ return f"({ ' AND ' .join (not_stmts )} )" , params
10551071 elif isinstance (value , dict ):
1056- not_ = self ._create_filter_clause (value )
1057- return f"(NOT { not_ } )"
1072+ not_ , params = self ._create_filter_clause (value )
1073+ return f"(NOT { not_ } )" , params
10581074 else :
10591075 raise ValueError (
10601076 f"Invalid filter condition. Expected a dictionary "
@@ -1077,7 +1093,11 @@ def _create_filter_clause(self, filters: Any) -> str:
10771093 self ._handle_field_filter (field = k , value = v ) for k , v in filters .items ()
10781094 ]
10791095 if len (and_ ) > 1 :
1080- return f"({ ' AND ' .join (and_ )} )"
1096+ all_clauses = [clause [0 ] for clause in and_ ]
1097+ params = {}
1098+ for clause in and_ :
1099+ params .update (clause [1 ])
1100+ return f"({ ' AND ' .join (all_clauses )} )" , params
10811101 elif len (and_ ) == 1 :
10821102 return and_ [0 ]
10831103 else :
@@ -1086,7 +1106,7 @@ def _create_filter_clause(self, filters: Any) -> str:
10861106 "but got an empty dictionary"
10871107 )
10881108 else :
1089- return ""
1109+ return "" , {}
10901110
10911111 def get_by_ids (self , ids : Sequence [str ]) -> list [Document ]:
10921112 raise NotImplementedError (
@@ -1168,7 +1188,7 @@ def similarity_search(
11681188 self ,
11691189 query : str ,
11701190 k : Optional [int ] = None ,
1171- filter : Optional [dict ] | Optional [ str ] = None ,
1191+ filter : Optional [dict ] = None ,
11721192 ** kwargs : Any ,
11731193 ) -> list [Document ]:
11741194 raise NotImplementedError (
@@ -1179,7 +1199,7 @@ def similarity_search_with_score(
11791199 self ,
11801200 query : str ,
11811201 k : Optional [int ] = None ,
1182- filter : Optional [dict ] | Optional [ str ] = None ,
1202+ filter : Optional [dict ] = None ,
11831203 ** kwargs : Any ,
11841204 ) -> list [tuple [Document , float ]]:
11851205 raise NotImplementedError (
@@ -1190,7 +1210,7 @@ def similarity_search_by_vector(
11901210 self ,
11911211 embedding : list [float ],
11921212 k : Optional [int ] = None ,
1193- filter : Optional [dict ] | Optional [ str ] = None ,
1213+ filter : Optional [dict ] = None ,
11941214 ** kwargs : Any ,
11951215 ) -> list [Document ]:
11961216 raise NotImplementedError (
@@ -1201,7 +1221,7 @@ def similarity_search_with_score_by_vector(
12011221 self ,
12021222 embedding : list [float ],
12031223 k : Optional [int ] = None ,
1204- filter : Optional [dict ] | Optional [ str ] = None ,
1224+ filter : Optional [dict ] = None ,
12051225 ** kwargs : Any ,
12061226 ) -> list [tuple [Document , float ]]:
12071227 raise NotImplementedError (
@@ -1214,7 +1234,7 @@ def max_marginal_relevance_search(
12141234 k : Optional [int ] = None ,
12151235 fetch_k : Optional [int ] = None ,
12161236 lambda_mult : Optional [float ] = None ,
1217- filter : Optional [dict ] | Optional [ str ] = None ,
1237+ filter : Optional [dict ] = None ,
12181238 ** kwargs : Any ,
12191239 ) -> list [Document ]:
12201240 raise NotImplementedError (
@@ -1227,7 +1247,7 @@ def max_marginal_relevance_search_by_vector(
12271247 k : Optional [int ] = None ,
12281248 fetch_k : Optional [int ] = None ,
12291249 lambda_mult : Optional [float ] = None ,
1230- filter : Optional [dict ] | Optional [ str ] = None ,
1250+ filter : Optional [dict ] = None ,
12311251 ** kwargs : Any ,
12321252 ) -> list [Document ]:
12331253 raise NotImplementedError (
@@ -1240,7 +1260,7 @@ def max_marginal_relevance_search_with_score_by_vector(
12401260 k : Optional [int ] = None ,
12411261 fetch_k : Optional [int ] = None ,
12421262 lambda_mult : Optional [float ] = None ,
1243- filter : Optional [dict ] | Optional [ str ] = None ,
1263+ filter : Optional [dict ] = None ,
12441264 ** kwargs : Any ,
12451265 ) -> list [tuple [Document , float ]]:
12461266 raise NotImplementedError (
0 commit comments