Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.DS_Store
uv.lock
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
11 changes: 9 additions & 2 deletions byaldi/RAGModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import Any, Dict, List, Optional, Union

from PIL import Image
from transformers import BitsAndBytesConfig

from byaldi.colpali import ColPaliModel

from byaldi.objects import Result

# Optional langchain integration
Expand Down Expand Up @@ -45,6 +45,7 @@ def from_pretrained(
index_root: str = ".byaldi",
device: str = "cuda",
verbose: int = 1,
quantization_config: BitsAndBytesConfig | None = None,
):
"""Load a ColPali model from a pre-trained checkpoint.

Expand All @@ -61,6 +62,7 @@ def from_pretrained(
index_root=index_root,
device=device,
verbose=verbose,
quantization_config=quantization_config,
)
return instance

Expand All @@ -71,6 +73,7 @@ def from_index(
index_root: str = ".byaldi",
device: str = "cuda",
verbose: int = 1,
quantization_config: BitsAndBytesConfig | None = None,
):
"""Load an Index and the associated ColPali model from an existing document index.

Expand All @@ -84,7 +87,11 @@ def from_index(
instance = cls()
index_path = Path(index_path)
instance.model = ColPaliModel.from_index(
- index_path, index_root=index_root, device=device, verbose=verbose
+ index_path,
+ index_root=index_root,
+ device=device,
+ verbose=verbose,
+ quantization_config=quantization_config,
)

return instance
Expand Down
6 changes: 6 additions & 0 deletions byaldi/colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from colpali_engine.models import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor
from pdf2image import convert_from_path
from PIL import Image
from transformers import BitsAndBytesConfig

from byaldi.objects import Result

Expand All @@ -27,6 +28,7 @@ def __init__(
load_from_index: bool = False,
index_root: str = ".byaldi",
device: Optional[Union[str, torch.device]] = None,
quantization_config: BitsAndBytesConfig | None = None,
**kwargs,
):
if isinstance(pretrained_model_name_or_path, Path):
Expand Down Expand Up @@ -76,6 +78,7 @@ def __init__(
else None
),
token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
quantization_config=quantization_config,
)
elif "colqwen2" in pretrained_model_name_or_path.lower():
self.model = ColQwen2.from_pretrained(
Expand All @@ -88,6 +91,7 @@ def __init__(
else None
),
token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
quantization_config=quantization_config,
)
self.model = self.model.eval()

Expand Down Expand Up @@ -204,6 +208,7 @@ def from_pretrained(
verbose: int = 1,
device: Optional[Union[str, torch.device]] = None,
index_root: str = ".byaldi",
quantization_config: BitsAndBytesConfig | None = None,
**kwargs,
):
return cls(
Expand All @@ -213,6 +218,7 @@ def from_pretrained(
load_from_index=False,
index_root=index_root,
device=device,
quantization_config=quantization_config,
**kwargs,
)

Expand Down
Loading