Skip to content

Commit 63d9b4b

Browse files
authored
Delegate embedding generation to Qdrant (#36)
* feat: delegate embedding generation to Qdrant
* feat: upsert media documents for Qdrant embeddings
* test: cover loader document embeddings
1 parent ad74850 commit 63d9b4b

File tree

5 files changed

+92
-580
lines changed

5 files changed

+92
-580
lines changed

AGENTS.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# AGENTS
22

33
## Architecture
4-
- `mcp_plex/loader.py` ingests Plex, TMDb, and IMDb metadata, builds dense and sparse embeddings, and stores items in a Qdrant collection.
4+
- `mcp_plex/loader.py` ingests Plex, TMDb, and IMDb metadata, relies on Qdrant to generate dense and sparse embeddings, and stores items in a Qdrant collection.
55
- `mcp_plex/server.py` exposes retrieval and search tools via FastMCP backed by Qdrant.
66
- `mcp_plex/types.py` defines the Pydantic models used across the project.
77
- When making architectural design decisions, add a short note here describing the decision and its rationale.
8+
- Embedding generation was moved from local FastEmbed models to Qdrant's document API to reduce local dependencies and centralize vector creation.
89
- Actor names are stored as a top-level payload field and indexed in Qdrant to enable actor and year-based filtering.
910
- Dense and sparse embedding model names are configurable via `DENSE_MODEL` and
1011
`SPARSE_MODEL` environment variables or the corresponding CLI options.
@@ -38,6 +39,7 @@ The project should handle natural-language searches and recommendations such as:
3839
- Use realistic (or as realistic as possible) data in tests; avoid meaningless placeholder values.
3940
- Always test both positive and negative logical paths.
4041
- Do **not** use `# pragma: no cover`; add tests to exercise code paths instead.
42+
- All changes should include tests that demonstrate the new or modified behavior.
4143

4244
## Efficiency and Search
4345
- Use `rg` (ripgrep) for recursive search.

mcp_plex/loader.py

Lines changed: 34 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import click
1111
import httpx
12-
from fastembed import SparseTextEmbedding, TextEmbedding
1312
from qdrant_client import models
1413
from qdrant_client.async_qdrant_client import AsyncQdrantClient
1514

@@ -311,8 +310,8 @@ async def run(
311310
server = PlexServer(plex_url, plex_token)
312311
items = await _load_from_plex(server, tmdb_api_key)
313312

314-
# Embed and store in Qdrant
315-
texts: List[str] = []
313+
# Assemble points with server-side embeddings
314+
points: List[models.PointStruct] = []
316315
for item in items:
317316
parts = [
318317
item.plex.title,
@@ -325,13 +324,33 @@ async def run(
325324
]
326325
if item.tmdb and hasattr(item.tmdb, "reviews"):
327326
parts.extend(r.get("content", "") for r in getattr(item.tmdb, "reviews", []))
328-
texts.append("\n".join(p for p in parts if p))
329-
330-
dense_model = TextEmbedding(dense_model_name)
331-
sparse_model = SparseTextEmbedding(sparse_model_name)
332-
333-
dense_vectors = list(dense_model.embed(texts))
334-
sparse_vectors = list(sparse_model.passage_embed(texts))
327+
text = "\n".join(p for p in parts if p)
328+
payload = {
329+
"data": item.model_dump(),
330+
"title": item.plex.title,
331+
"type": item.plex.type,
332+
}
333+
if item.plex.actors:
334+
payload["actors"] = [p.tag for p in item.plex.actors]
335+
if item.plex.year is not None:
336+
payload["year"] = item.plex.year
337+
if item.plex.added_at is not None:
338+
payload["added_at"] = item.plex.added_at
339+
point_id: int | str = (
340+
int(item.plex.rating_key)
341+
if item.plex.rating_key.isdigit()
342+
else item.plex.rating_key
343+
)
344+
points.append(
345+
models.PointStruct(
346+
id=point_id,
347+
vector={
348+
"dense": models.Document(text=text, model=dense_model_name),
349+
"sparse": models.Document(text=text, model=sparse_model_name),
350+
},
351+
payload=payload,
352+
)
353+
)
335354

336355
if qdrant_url is None and qdrant_host is None:
337356
qdrant_url = ":memory:"
@@ -344,31 +363,14 @@ async def run(
344363
https=qdrant_https,
345364
prefer_grpc=qdrant_prefer_grpc,
346365
)
366+
dense_size, dense_distance = client._get_model_params(dense_model_name)
347367
collection_name = "media-items"
348-
vectors_config = {
349-
"dense": models.VectorParams(
350-
size=dense_model.embedding_size, distance=models.Distance.COSINE
351-
)
352-
}
353-
sparse_vectors_config = {"sparse": models.SparseVectorParams()}
354-
355368
created_collection = False
356-
if await client.collection_exists(collection_name):
357-
info = await client.get_collection(collection_name)
358-
existing_size = info.config.params.vectors["dense"].size # type: ignore[index]
359-
if existing_size != dense_model.embedding_size:
360-
await client.delete_collection(collection_name)
361-
await client.create_collection(
362-
collection_name=collection_name,
363-
vectors_config=vectors_config,
364-
sparse_vectors_config=sparse_vectors_config,
365-
)
366-
created_collection = True
367-
else:
369+
if not await client.collection_exists(collection_name):
368370
await client.create_collection(
369371
collection_name=collection_name,
370-
vectors_config=vectors_config,
371-
sparse_vectors_config=sparse_vectors_config,
372+
vectors_config={"dense": models.VectorParams(size=dense_size, distance=dense_distance)},
373+
sparse_vectors_config={"sparse": models.SparseVectorParams()},
372374
)
373375
created_collection = True
374376

@@ -419,34 +421,8 @@ async def run(
419421
field_schema=models.PayloadSchemaType.INTEGER,
420422
)
421423

422-
points = []
423-
for idx, (item, dense, sparse) in enumerate(zip(items, dense_vectors, sparse_vectors)):
424-
sv = models.SparseVector(
425-
indices=sparse.indices.tolist(), values=sparse.values.tolist()
426-
)
427-
payload = {
428-
"data": item.model_dump(),
429-
"title": item.plex.title,
430-
"type": item.plex.type,
431-
}
432-
if item.plex.actors:
433-
payload["actors"] = [p.tag for p in item.plex.actors]
434-
if item.plex.year is not None:
435-
payload["year"] = item.plex.year
436-
if item.plex.added_at is not None:
437-
payload["added_at"] = item.plex.added_at
438-
points.append(
439-
models.Record(
440-
id=int(item.plex.rating_key)
441-
if item.plex.rating_key.isdigit()
442-
else item.plex.rating_key,
443-
payload=payload,
444-
vector={"dense": dense, "sparse": sv},
445-
)
446-
)
447-
448424
if points:
449-
await client.upsert(collection_name="media-items", points=points)
425+
await client.upsert(collection_name=collection_name, points=points)
450426

451427
json.dump([item.model_dump() for item in items], fp=sys.stdout, indent=2)
452428
sys.stdout.write("\n")

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "mcp-plex"
7-
version = "0.26.1"
7+
version = "0.26.3"
88

99
description = "Plex-Oriented Model Context Protocol Server"
10-
requires-python = ">=3.11,<4"
10+
requires-python = ">=3.11,<3.13"
1111
dependencies = [
1212
"fastmcp>=2.11.2",
1313
"pydantic>=2.11.7",
1414
"plexapi>=4.17.0",
15-
"qdrant-client[fastembed-gpu]>=1.12.1",
15+
"qdrant-client[fastembed-gpu]>=1.15.1",
1616
"rapidfuzz>=3.13.0",
1717
"scikit-learn>=1.7.1",
1818
"httpx>=0.27.0",

tests/test_loader_integration.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,23 @@
44
from pathlib import Path
55

66
from qdrant_client.async_qdrant_client import AsyncQdrantClient
7+
from qdrant_client import models
78

89
from mcp_plex import loader
910

1011

1112
class CaptureClient(AsyncQdrantClient):
1213
instance: "CaptureClient" | None = None
14+
captured_points: list[models.PointStruct] = []
1315

1416
def __init__(self, *args, **kwargs):
1517
super().__init__(*args, **kwargs)
1618
CaptureClient.instance = self
1719

20+
async def upsert(self, collection_name: str, points, **kwargs):
21+
CaptureClient.captured_points = points
22+
return await super().upsert(collection_name=collection_name, points=points, **kwargs)
23+
1824

1925
async def _run_loader(sample_dir: Path) -> None:
2026
await loader.run(
@@ -36,5 +42,14 @@ def test_run_writes_points(monkeypatch):
3642
points, _ = asyncio.run(client.scroll("media-items", limit=10, with_payload=True))
3743
assert len(points) == 2
3844
assert all("title" in p.payload and "type" in p.payload for p in points)
45+
captured = CaptureClient.captured_points
46+
assert len(captured) == 2
47+
assert all(isinstance(p.vector["dense"], models.Document) for p in captured)
48+
assert all(p.vector["dense"].model == "BAAI/bge-small-en-v1.5" for p in captured)
49+
assert all(isinstance(p.vector["sparse"], models.Document) for p in captured)
50+
assert all(
51+
p.vector["sparse"].model == "Qdrant/bm42-all-minilm-l6-v2-attentions"
52+
for p in captured
53+
)
3954

4055

0 commit comments

Comments
 (0)