pave/stores/txtai_store.py (113 changes: 98 additions, 15 deletions)

@@ -4,14 +4,22 @@
from __future__ import annotations
import os, json, operator
from datetime import datetime
from typing import Dict, Iterable, List, Any
from typing import Any, Dict, Iterable, List, Optional
from threading import Lock
from contextlib import contextmanager
from txtai.embeddings import Embeddings
from pave.stores.base import BaseStore, Record
from pave.config import CFG as c, LOG as log

_LOCKS : dict[str, Lock] = {}
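# Blank out characters with SQL meaning (statement separators, quoting,
# escapes) and drop NUL bytes before values are embedded in txtai SQL.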
_SQL_TRANS = str.maketrans({
";": " ",
'"': " ",
"`": " ",
"\\": " ",
"\x00": "",
})

def get_lock(key: str) -> Lock:
if key not in _LOCKS:
_LOCKS[key] = Lock()
@@ -236,20 +244,22 @@ def index_records(self, tenant: str, collection: str, docid: str,

md["docid"] = docid
try:
meta_json = json.dumps(md, ensure_ascii=False)
md = json.loads(meta_json)
except:
md = {}
safe_meta = self._sanit_meta_dict(md)
meta_json = json.dumps(safe_meta, ensure_ascii=False)
except Exception:
safe_meta = {}
meta_json = ""

rid = str(rid)
txt = str(txt)
if not rid.startswith(f"{docid}::"):
rid = f"{docid}::{rid}"

meta_side[rid] = md
md_for_index = {k: v for k, v in safe_meta.items() if k != "text"}

meta_side[rid] = safe_meta
record_ids.append(rid)
prepared.append((rid, {"text":txt, **md}, meta_json))
prepared.append((rid, {"text": txt, **md_for_index}, meta_json))

self._save_chunk_text(tenant, collection, rid, txt)
assert txt == (self._load_chunk_text(tenant, collection, rid) or "")
@@ -280,10 +290,15 @@ def _matches_filters(m: Dict[str, Any],
if not filters:
return True

def match(have: Any, cond: str) -> bool:
def match(have: Any, cond: Any) -> bool:
if have is None:
return False
s = str(cond)
if isinstance(have, (list, tuple, set)):
return any(match(item, cond) for item in have)
if isinstance(cond, str):
s = TxtaiStore._sanit_sql(cond)
else:
s = str(cond)
hv = str(have)
# Numeric/date ops
for op in (">=", "<=", "!=", ">", "<"):
@@ -313,7 +328,7 @@ def match(have: Any, cond: str) -> bool:
return hv == s

for k, vals in filters.items():
if not any(match(m.get(k), v) for v in vals):
if not any(match(TxtaiStore._lookup_meta(m, k), v) for v in vals):
return False
return True

@@ -325,6 +340,9 @@ def _split_filters(filters: dict[str, Any] | None) -> tuple[dict, dict]:

pre_f, pos_f = {}, {}
for key, vals in (filters or {}).items():
safe_key = TxtaiStore._sanit_field(key)
if not safe_key:
continue
if not isinstance(vals, list):
vals = [vals]
exacts, extended = [], []
@@ -338,12 +356,68 @@ def _split_filters(filters: dict[str, Any] | None) -> tuple[dict, dict]:
else:
exacts.append(v)
if exacts:
pre_f[key] = exacts
pre_f[safe_key] = exacts
if extended:
pos_f[key] = extended
pos_f[safe_key] = extended
log.debug(f"after split: PRE {pre_f} POS {pos_f}")
return pre_f, pos_f

@staticmethod
def _lookup_meta(meta: Dict[str, Any] | None, key: str) -> Any:
if not meta:
return None
if key in meta:
return meta.get(key)
for raw_key, value in meta.items():
if TxtaiStore._sanit_field(raw_key) == key:
return value
return None

@staticmethod
def _sanit_meta_value(value: Any) -> Any:
if isinstance(value, dict):
return TxtaiStore._sanit_meta_dict(value)
if isinstance(value, (list, tuple, set)):
return [TxtaiStore._sanit_meta_value(v) for v in value]
if isinstance(value, (int, float, bool)) or value is None:
return value
return TxtaiStore._sanit_sql(value)

@staticmethod
def _sanit_meta_dict(meta: Dict[str, Any] | None) -> Dict[str, Any]:
safe: Dict[str, Any] = {}
if not isinstance(meta, dict):
return safe
for raw_key, raw_value in meta.items():
safe_key = TxtaiStore._sanit_field(raw_key)
if not safe_key or safe_key == "text":
continue
safe[safe_key] = TxtaiStore._sanit_meta_value(raw_value)
return safe

@staticmethod
def _sanit_sql(value: Any, *, max_len: Optional[int] = None) -> str:
if value is None:
return ""
text = str(value).translate(_SQL_TRANS)
for token in ("--", "/*", "*/"):
if token in text:
text = text.split(token, 1)[0]
text = text.strip()
if max_len is not None and max_len > 0 and len(text) > max_len:
text = text[:max_len]
return text.replace("'", "''")

@staticmethod
def _sanit_field(name: Any) -> str:
if not isinstance(name, str):
name = str(name)
safe = []
for ch in name:
if ch.isalnum() or ch in {"_", "-"}:
safe.append(ch)
return "".join(safe)

@staticmethod
def _build_sql(query: str, k: int, filters: dict[str, Any], columns: list[str],
with_similarity: bool = True, avoid_duplicates = True) -> str:
@@ -356,14 +430,23 @@ def _build_sql(query: str, k: int, filters: dict[str, Any], columns: list[str],

wheres = []
if with_similarity and query:
q_safe = query.replace("'", "''")
max_len_cfg = c.get("vector_store.txtai.max_query_chars", 512)
try:
max_len = int(max_len_cfg)
except (TypeError, ValueError):
max_len = 512
limit = max_len if max_len > 0 else None
q_safe = TxtaiStore._sanit_sql(query, max_len=limit)
wheres.append(f"similar('{q_safe}')")

for key, vals in filters.items():
safe_key = TxtaiStore._sanit_field(key)
if not safe_key:
continue
ors = []
for v in vals:
safe_v = str(v).replace("'", "''")
ors.append(f"[{key}] = '{safe_v}'")
safe_v = TxtaiStore._sanit_sql(v)
ors.append(f"[{safe_key}] = '{safe_v}'")
or_safe = " OR ".join(ors)
wheres.append(f"({or_safe})")

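For review context, a minimal sketch of what the new sanitizers produce, assuming the helpers behave exactly as defined in the hunks above (the exact whitespace in the escaped output is illustrative):

from pave.stores.txtai_store import TxtaiStore

# Statement punctuation is blanked, the trailing comment is cut off, and the
# single quote is neutralised by doubling.
term = TxtaiStore._sanit_sql("foo'; DROP TABLE users; -- comment")
assert ";" not in term and "--" not in term
assert "foo''" in term and "DROP TABLE users" in term  # plain words survive

# Field names are reduced to alphanumerics, "_" and "-".
assert TxtaiStore._sanit_field("lang]; DROP") == "langDROP"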
tests/conftest.py (16 changes: 16 additions, 0 deletions)

@@ -1,6 +1,22 @@
# (C) 2025 Rodrigo Rodrigues da Silva <rodrigopitanga@posteo.net>
# SPDX-License-Identifier: GPL-3.0-or-later

import sys
import types

if "txtai.embeddings" not in sys.modules:
txtai_stub = types.ModuleType("txtai")
embeddings_stub = types.ModuleType("txtai.embeddings")

class _StubEmbeddings: # pragma: no cover - stub for optional dependency
def __init__(self, *args, **kwargs):
pass

embeddings_stub.Embeddings = _StubEmbeddings
txtai_stub.embeddings = embeddings_stub
sys.modules.setdefault("txtai", txtai_stub)
sys.modules.setdefault("txtai.embeddings", embeddings_stub)

import pytest
from fastapi.testclient import TestClient
from pave.config import get_cfg, reload_cfg
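The stub is installed only when txtai.embeddings has not been imported yet, so a session that already loaded the real package keeps it. A sketch of the effect inside a test module (illustrative, not part of the PR):

import sys

# Once conftest.py has run, the import in pave/stores/txtai_store.py
# (`from txtai.embeddings import Embeddings`) resolves to _StubEmbeddings.
assert "txtai.embeddings" in sys.modules
from pave.stores.txtai_store import TxtaiStore  # works without txtai installed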
tests/test_txtai_store_sql_safety.py (116 changes: 116 additions, 0 deletions)

@@ -0,0 +1,116 @@
# (C) 2025 Rodrigo Rodrigues da Silva <rodrigopitanga@posteo.net>
# SPDX-License-Identifier: GPL-3.0-or-later

import json

import pytest

from pave.stores import txtai_store as store_mod
from pave.stores.txtai_store import TxtaiStore
from pave.config import get_cfg
from utils import FakeEmbeddings


@pytest.fixture(autouse=True)
def _fake_embeddings(monkeypatch):
monkeypatch.setattr(store_mod, "Embeddings", FakeEmbeddings, raising=True)


@pytest.fixture()
def store():
return TxtaiStore()


def _extract_similarity_term(sql: str) -> str:
marker = "similar('"
if marker not in sql:
raise AssertionError(f"similar() clause missing in SQL: {sql!r}")
rest = sql.split(marker, 1)[1]
return rest.split("')", 1)[0]


def test_build_sql_sanitizes_similarity_term(store):
raw_query = "foo'; DROP TABLE users; -- comment"
sql = store._build_sql(raw_query, 5, {}, ["id", "text"])
term = _extract_similarity_term(sql)

# injection primitives are stripped or neutralised
assert ";" not in term
assert "--" not in term
# original alpha characters remain so search still works
assert "foo" in term


def test_build_sql_sanitizes_filter_values(store):
filters = {"lang": ["en'; DELETE FROM x;"], "tags": ['alpha"beta']}
sql = store._build_sql("foo", 5, filters, ["id", "text"])

# filter clause should not leak dangerous characters
assert ";" not in sql
assert '"' not in sql
assert "--" not in sql


def test_build_sql_normalises_filter_keys(store):
filters = {"lang]; DROP": ["en"], 123: ["x"]}
sql = store._build_sql("foo", 5, filters, ["id"])
assert "[langDROP]" in sql
assert "[123]" in sql


def test_build_sql_applies_query_length_limit(store):
cfg = get_cfg()
snapshot = cfg.snapshot()
try:
cfg.set("vector_store.txtai.max_query_chars", 8)
sql = store._build_sql("abcdefghijklmno", 5, {}, ["id"])
term = _extract_similarity_term(sql)

# collapse the doubled quotes to measure the original payload length
collapsed = term.replace("''", "'")
assert len(collapsed) == 8
finally:
cfg.replace(data=snapshot)


def test_search_handles_special_characters(store):
tenant, collection = "tenant", "coll"
store.load_or_init(tenant, collection)

records = [("r1", "hello world", {"lang": "en"})]
store.index_records(tenant, collection, "doc", records)

hits = store.search(tenant, collection, "world; -- comment", k=5)
assert hits
assert hits[0]["id"].endswith("::r1")


def test_round_trip_with_weird_metadata_field(store):
tenant, collection = "tenant", "coll"
store.load_or_init(tenant, collection)

weird_key = "meta;`DROP"
weird_value = "val'u"
records = [("r2", "strange world", {weird_key: weird_value})]
store.index_records(tenant, collection, "doc2", records)

filters = {weird_key: weird_value}
hits = store.search(tenant, collection, "strange", k=5, filters=filters)

assert hits
assert hits[0]["id"].endswith("::r2")

emb = store._emb[(tenant, collection)]
safe_key = TxtaiStore._sanit_field(weird_key)
assert emb.last_sql and f"[{safe_key}]" in emb.last_sql

rid = hits[0]["id"]
stored_meta = store._load_meta(tenant, collection).get(rid) or {}
assert safe_key in stored_meta
assert stored_meta[safe_key] == TxtaiStore._sanit_sql(weird_value)

doc = emb._docs[rid]
assert doc["meta"].get(safe_key) == TxtaiStore._sanit_sql(weird_value)
serialized = json.loads(doc["meta_json"]) if doc.get("meta_json") else {}
assert serialized.get(safe_key) == TxtaiStore._sanit_sql(weird_value)
assert hits[0]["meta"].get(safe_key) == TxtaiStore._sanit_sql(weird_value)