Skip to content

Commit 36650dd

Browse files
authored
Collections: Async completion (#243)
* routes and deps * small fix * logging * adding org id and project in crud tests * migration file * test cases * test cases * formatting * test case fix * migration and pascalcase * pr review fixes * vector store fix
1 parent d8e5a27 commit 36650dd

File tree

13 files changed

+501
-85
lines changed

13 files changed

+501
-85
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""add/alter columns in collections table
2+
3+
Revision ID: 3389c67fdcb4
4+
Revises: 8757b005d681
5+
Create Date: 2025-06-20 18:08:16.585843
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
from sqlalchemy.dialects import postgresql
12+
13+
14+
# revision identifiers, used by Alembic.
15+
revision = "3389c67fdcb4"
16+
down_revision = "8757b005d681"
17+
branch_labels = None
18+
depends_on = None
19+
20+
collection_status_enum = postgresql.ENUM(
21+
"processing",
22+
"successful",
23+
"failed",
24+
name="collectionstatus",
25+
create_type=False, # we create manually to avoid duplicate issues
26+
)
27+
28+
29+
def upgrade():
30+
collection_status_enum.create(op.get_bind(), checkfirst=True)
31+
op.add_column(
32+
"collection", sa.Column("organization_id", sa.Integer(), nullable=False)
33+
)
34+
op.add_column("collection", sa.Column("project_id", sa.Integer(), nullable=True))
35+
op.add_column(
36+
"collection",
37+
sa.Column(
38+
"status",
39+
collection_status_enum,
40+
nullable=False,
41+
server_default="processing",
42+
),
43+
)
44+
op.add_column("collection", sa.Column("updated_at", sa.DateTime(), nullable=False))
45+
op.alter_column(
46+
"collection", "llm_service_id", existing_type=sa.VARCHAR(), nullable=True
47+
)
48+
op.alter_column(
49+
"collection", "llm_service_name", existing_type=sa.VARCHAR(), nullable=True
50+
)
51+
op.create_foreign_key(
52+
None,
53+
"collection",
54+
"organization",
55+
["organization_id"],
56+
["id"],
57+
ondelete="CASCADE",
58+
)
59+
op.create_foreign_key(
60+
None, "collection", "project", ["project_id"], ["id"], ondelete="CASCADE"
61+
)
62+
63+
64+
def downgrade():
65+
op.drop_constraint(None, "collection", type_="foreignkey")
66+
op.drop_constraint(None, "collection", type_="foreignkey")
67+
op.alter_column(
68+
"collection", "llm_service_name", existing_type=sa.VARCHAR(), nullable=False
69+
)
70+
op.alter_column(
71+
"collection", "llm_service_id", existing_type=sa.VARCHAR(), nullable=False
72+
)
73+
op.drop_column("collection", "updated_at")
74+
op.drop_column("collection", "status")
75+
op.drop_column("collection", "project_id")
76+
op.drop_column("collection", "organization_id")

backend/app/api/deps.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,30 @@ def get_current_user_org(
107107
CurrentUserOrg = Annotated[UserOrganization, Depends(get_current_user_org)]
108108

109109

110+
def get_current_user_org_project(
111+
current_user: CurrentUser, session: SessionDep, request: Request
112+
) -> UserProjectOrg:
113+
api_key = request.headers.get("X-API-KEY")
114+
organization_id = None
115+
project_id = None
116+
117+
if api_key:
118+
api_key_record = get_api_key_by_value(session, api_key)
119+
if api_key_record:
120+
validate_organization(session, api_key_record.organization_id)
121+
organization_id = api_key_record.organization_id
122+
project_id = api_key_record.project_id
123+
124+
return UserProjectOrg(
125+
**current_user.model_dump(),
126+
organization_id=organization_id,
127+
project_id=project_id,
128+
)
129+
130+
131+
CurrentUserOrgProject = Annotated[UserProjectOrg, Depends(get_current_user_org_project)]
132+
133+
110134
def get_current_active_superuser(current_user: CurrentUser) -> User:
111135
if not current_user.is_superuser:
112136
raise HTTPException(

backend/app/api/routes/collections.py

Lines changed: 76 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import inspect
22
import logging
3+
import time
34
import warnings
45
from uuid import UUID, uuid4
56
from typing import Any, List, Optional
@@ -11,13 +12,14 @@
1112
from pydantic import BaseModel, Field, HttpUrl
1213
from sqlalchemy.exc import NoResultFound, MultipleResultsFound, SQLAlchemyError
1314

14-
from app.api.deps import CurrentUser, SessionDep
15+
from app.api.deps import CurrentUser, SessionDep, CurrentUserOrgProject
1516
from app.core.cloud import AmazonCloudStorage
1617
from app.core.config import settings
1718
from app.core.util import now, raise_from_unknown, post_callback
1819
from app.crud import DocumentCrud, CollectionCrud, DocumentCollectionCrud
1920
from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud
2021
from app.models import Collection, Document
22+
from app.models.collection import CollectionStatus
2123
from app.utils import APIResponse, load_description
2224

2325
router = APIRouter(prefix="/collections", tags=["collections"])
@@ -173,61 +175,77 @@ def do_create_collection(
173175
request: CreationRequest,
174176
payload: ResponsePayload,
175177
):
178+
start_time = time.time()
176179
client = OpenAI(api_key=settings.OPENAI_API_KEY)
177-
if request.callback_url is None:
178-
callback = SilentCallback(payload)
179-
else:
180-
callback = WebHookCallback(request.callback_url, payload)
181-
182-
#
183-
# Create the assistant and vector store
184-
#
185-
186-
vector_store_crud = OpenAIVectorStoreCrud(client)
187-
try:
188-
vector_store = vector_store_crud.create()
189-
except OpenAIError as err:
190-
callback.fail(str(err))
191-
return
180+
callback = (
181+
SilentCallback(payload)
182+
if request.callback_url is None
183+
else WebHookCallback(request.callback_url, payload)
184+
)
192185

193186
storage = AmazonCloudStorage(current_user)
194187
document_crud = DocumentCrud(session, current_user.id)
195188
assistant_crud = OpenAIAssistantCrud(client)
189+
vector_store_crud = OpenAIVectorStoreCrud(client)
190+
collection_crud = CollectionCrud(session, current_user.id)
196191

197-
docs = request(document_crud)
198-
kwargs = dict(request.extract_super_type(AssistantOptions))
199192
try:
200-
updates = vector_store_crud.update(vector_store.id, storage, docs)
201-
documents = list(updates)
202-
assistant = assistant_crud.create(vector_store.id, **kwargs)
203-
except Exception as err: # blanket to handle SQL and OpenAI errors
204-
logging.error(f"File Search setup error: {err} ({type(err).__name__})")
205-
vector_store_crud.delete(vector_store.id)
206-
callback.fail(str(err))
207-
return
193+
vector_store = vector_store_crud.create()
208194

209-
#
210-
# Store the results
211-
#
195+
docs = list(request(document_crud))
196+
flat_docs = [doc for sublist in docs for doc in sublist]
212197

213-
collection_crud = CollectionCrud(session, current_user.id)
214-
collection = Collection(
215-
id=UUID(payload.key),
216-
llm_service_id=assistant.id,
217-
llm_service_name=request.model,
218-
)
219-
try:
220-
collection_crud.create(collection, documents)
221-
except SQLAlchemyError as err:
222-
_backout(assistant_crud, assistant.id)
223-
callback.fail(str(err))
224-
return
198+
file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname}
199+
file_sizes_kb = [
200+
storage.get_file_size_kb(doc.object_store_url) for doc in flat_docs
201+
]
202+
203+
logging.info(
204+
f"[VectorStore Update] Uploading {len(flat_docs)} documents to vector store {vector_store.id}"
205+
)
206+
list(vector_store_crud.update(vector_store.id, storage, docs))
207+
logging.info(f"[VectorStore Upload] Upload completed")
208+
209+
assistant_options = dict(request.extract_super_type(AssistantOptions))
210+
logging.info(
211+
f"[Assistant Create] Creating assistant with options: {assistant_options}"
212+
)
213+
assistant = assistant_crud.create(vector_store.id, **assistant_options)
214+
logging.info(f"[Assistant Create] Assistant created: {assistant.id}")
215+
216+
collection = collection_crud.read_one(UUID(payload.key))
217+
collection.llm_service_id = assistant.id
218+
collection.llm_service_name = request.model
219+
collection.status = CollectionStatus.successful
220+
collection.updated_at = now()
221+
222+
if flat_docs:
223+
logging.info(
224+
f"[DocumentCollection] Linking {len(flat_docs)} documents to collection {collection.id}"
225+
)
226+
DocumentCollectionCrud(session).create(collection, flat_docs)
227+
228+
collection_crud._update(collection)
225229

226-
#
227-
# Send back successful response
228-
#
230+
elapsed = time.time() - start_time
231+
logging.info(
232+
f"Collection created: {collection.id} | Time: {elapsed:.2f}s | "
233+
f"Files: {len(flat_docs)} | Sizes: {file_sizes_kb} KB | Types: {list(file_exts)}"
234+
)
235+
callback.success(collection.model_dump(mode="json"))
229236

230-
callback.success(collection.model_dump(mode="json"))
237+
except Exception as err:
238+
logging.error(f"[Collection Creation Failed] {err} ({type(err).__name__})")
239+
if "assistant" in locals():
240+
_backout(assistant_crud, assistant.id)
241+
try:
242+
collection = collection_crud.read_one(UUID(payload.key))
243+
collection.status = CollectionStatus.failed
244+
collection.updated_at = now()
245+
collection_crud._update(collection)
246+
except Exception as suberr:
247+
logging.warning(f"[Collection Status Update Failed] {suberr}")
248+
callback.fail(str(err))
231249

232250

233251
@router.post(
@@ -236,14 +254,26 @@ def do_create_collection(
236254
)
237255
def create_collection(
238256
session: SessionDep,
239-
current_user: CurrentUser,
257+
current_user: CurrentUserOrgProject,
240258
request: CreationRequest,
241259
background_tasks: BackgroundTasks,
242260
):
243261
this = inspect.currentframe()
244262
route = router.url_path_for(this.f_code.co_name)
245263
payload = ResponsePayload("processing", route)
246264

265+
collection = Collection(
266+
id=UUID(payload.key),
267+
owner_id=current_user.id,
268+
organization_id=current_user.organization_id,
269+
project_id=current_user.project_id,
270+
status=CollectionStatus.processing,
271+
)
272+
273+
collection_crud = CollectionCrud(session, current_user.id)
274+
collection_crud.create(collection)
275+
276+
# 2. Launch background task
247277
background_tasks.add_task(
248278
do_create_collection,
249279
session,

backend/app/core/cloud/storage.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22

3-
# import logging
43
import functools as ft
54
from pathlib import Path
65
from dataclasses import dataclass, asdict
@@ -125,6 +124,13 @@ def stream(self, url: str) -> StreamingBody:
125124
except ClientError as err:
126125
raise CloudStorageError(f'AWS Error: "{err}" ({url})') from err
127126

127+
def get_file_size_kb(self, url: str) -> float:
128+
name = SimpleStorageName.from_url(url)
129+
kwargs = asdict(name)
130+
response = self.aws.client.head_object(**kwargs)
131+
size_bytes = response["ContentLength"]
132+
return round(size_bytes / 1024, 2)
133+
128134
def delete(self, url: str) -> None:
129135
name = SimpleStorageName.from_url(url)
130136
kwargs = asdict(name)

backend/app/crud/collection.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import functools as ft
22
from uuid import UUID
33
from typing import Optional
4-
4+
import logging
55
from sqlmodel import Session, func, select, and_
66

77
from app.models import Document, Collection, DocumentCollection
88
from app.core.util import now
9+
from app.models.collection import CollectionStatus
910

1011
from .document_collection import DocumentCollectionCrud
1112

@@ -43,13 +44,24 @@ def _exists(self, collection: Collection):
4344

4445
return bool(present)
4546

46-
def create(self, collection: Collection, documents: list[Document]):
47-
if self._exists(collection):
48-
raise FileExistsError("Collection already present")
49-
50-
collection = self._update(collection)
51-
dc_crud = DocumentCollectionCrud(self.session)
52-
dc_crud.create(collection, documents)
47+
def create(
48+
self,
49+
collection: Collection,
50+
documents: Optional[list[Document]] = None,
51+
):
52+
try:
53+
existing = self.read_one(collection.id)
54+
if existing.status == CollectionStatus.failed:
55+
self._update(collection)
56+
else:
57+
raise FileExistsError("Collection already present")
58+
except:
59+
self.session.add(collection)
60+
self.session.commit()
61+
62+
if documents:
63+
dc_crud = DocumentCollectionCrud(self.session)
64+
dc_crud.create(collection, documents)
5365

5466
return collection
5567

0 commit comments

Comments
 (0)