diff --git a/.gitignore b/.gitignore index 10ef763..7b7971b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .DS_Store +uv.lock # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/byaldi/RAGModel.py b/byaldi/RAGModel.py index 32b66bf..002cd6c 100644 --- a/byaldi/RAGModel.py +++ b/byaldi/RAGModel.py @@ -2,9 +2,9 @@ from typing import Any, Dict, List, Optional, Union from PIL import Image +from transformers import BitsAndBytesConfig from byaldi.colpali import ColPaliModel - from byaldi.objects import Result # Optional langchain integration @@ -45,6 +45,7 @@ def from_pretrained( index_root: str = ".byaldi", device: str = "cuda", verbose: int = 1, + quantization_config: BitsAndBytesConfig | None = None, ): """Load a ColPali model from a pre-trained checkpoint. @@ -61,6 +62,7 @@ def from_pretrained( index_root=index_root, device=device, verbose=verbose, + quantization_config=quantization_config, ) return instance @@ -71,6 +73,7 @@ def from_index( index_root: str = ".byaldi", device: str = "cuda", verbose: int = 1, + quantization_config: BitsAndBytesConfig | None = None, ): """Load an Index and the associated ColPali model from an existing document index. @@ -84,7 +87,11 @@ def from_index( instance = cls() index_path = Path(index_path) instance.model = ColPaliModel.from_index( - index_path, index_root=index_root, device=device, verbose=verbose + index_path, + index_root=index_root, + device=device, + verbose=verbose, + quantization_config=quantization_config, ) return instance diff --git a/byaldi/colpali.py b/byaldi/colpali.py index cc11dcb..e753a13 100644 --- a/byaldi/colpali.py +++ b/byaldi/colpali.py @@ -10,6 +10,7 @@ from colpali_engine.models import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor from pdf2image import convert_from_path from PIL import Image +from transformers import BitsAndBytesConfig from byaldi.objects import Result @@ -27,6 +28,7 @@ def __init__( load_from_index: bool = False, index_root: str = ".byaldi", device: Optional[Union[str, torch.device]] = None, + quantization_config: BitsAndBytesConfig | None = None, **kwargs, ): if isinstance(pretrained_model_name_or_path, Path): @@ -76,6 +78,7 @@ def __init__( else None ), token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), + quantization_config=quantization_config, ) elif "colqwen2" in pretrained_model_name_or_path.lower(): self.model = ColQwen2.from_pretrained( @@ -88,6 +91,7 @@ def __init__( else None ), token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), + quantization_config=quantization_config, ) self.model = self.model.eval() @@ -204,6 +208,7 @@ def from_pretrained( verbose: int = 1, device: Optional[Union[str, torch.device]] = None, index_root: str = ".byaldi", + quantization_config: BitsAndBytesConfig | None = None, **kwargs, ): return cls( @@ -213,6 +218,7 @@ def from_pretrained( load_from_index=False, index_root=index_root, device=device, + quantization_config=quantization_config, **kwargs, ) diff --git a/examples/quick_overview.ipynb b/examples/quick_overview.ipynb index eeeca7c..ad37ea2 100644 --- a/examples/quick_overview.ipynb +++ b/examples/quick_overview.ipynb @@ -2,73 +2,58 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Verbosity is set to 1 (active). Pass verbose=0 to make quieter.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5514749b070c4a679a7c4b40fc0396fe", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading checkpoint shards: 0%| | 0/2 [00:00] 2.11M --.-KB/s in 0.06s \n", - "\n", - "2024-11-13 07:58:42 (33.3 MB/s) - ‘1706.03762’ saved [2215244/2215244]\n", - "\n", - "mkdir: cannot create directory ‘docs’: File exists\n" - ] - } - ], + "outputs": [], "source": [ "# Let's get everyone's favourite paper in here\n", "!wget https://arxiv.org/pdf/1706.03762\n", @@ -79,76 +64,14 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "overwrite is on. Deleting existing index attention_index to build a new one.\n", - "Indexing file: docs/financial_report.pdf\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Added page 1 of document 0 to index.\n", - "Added page 2 of document 0 to index.\n", - "Added page 3 of document 0 to index.\n", - "Added page 4 of document 0 to index.\n", - "Added page 5 of document 0 to index.\n", - "Added page 6 of document 0 to index.\n", - "Index exported to .byaldi/attention_index\n", - "Indexing file: docs/product_c.png\n", - "Added page 1 of document 1 to index.\n", - "Index exported to .byaldi/attention_index\n", - "Indexing file: docs/attention.pdf\n", - "Added page 1 of document 2 to index.\n", - "Added page 2 of document 2 to index.\n", - "Added page 3 of document 2 to index.\n", - "Added page 4 of document 2 to index.\n", - "Added page 5 of document 2 to index.\n", - "Added page 6 of document 2 to index.\n", - "Added page 7 of document 2 to index.\n", - "Added page 8 of document 2 to index.\n", - "Added page 9 of document 2 to index.\n", - "Added page 10 of document 2 to index.\n", - "Added page 11 of document 2 to index.\n", - "Added page 12 of document 2 to index.\n", - "Added page 13 of document 2 to index.\n", - "Added page 14 of document 2 to index.\n", - "Added page 15 of document 2 to index.\n", - "Index exported to .byaldi/attention_index\n", - "Indexing file: docs/attention_with_a_mustache.pdf\n", - "Added page 1 of document 3 to index.\n", - "Added page 2 of document 3 to index.\n", - "Added page 3 of document 3 to index.\n", - "Added page 4 of document 3 to index.\n", - "Added page 5 of document 3 to index.\n", - "Added page 6 of document 3 to index.\n", - "Added page 7 of document 3 to index.\n", - "Added page 8 of document 3 to index.\n", - "Added page 9 of document 3 to index.\n", - "Added page 10 of document 3 to index.\n", - "Added page 11 of document 3 to index.\n", - "Added page 12 of document 3 to index.\n", - "Added page 13 of document 3 to index.\n", - "Added page 14 of document 3 to index.\n", - "Added page 15 of document 3 to index.\n", - "Index exported to .byaldi/attention_index\n", - "Indexing file: docs/attention_table.png\n", - "Added page 1 of document 4 to index.\n", - "Index exported to .byaldi/attention_index\n", - "Index exported to .byaldi/attention_index\n", - "Search results for 'what's the BLEU score of this new strange method?':\n", - "Doc ID: 2, Page: 1, Score: 14.9375\n", - "Doc ID: 3, Page: 1, Score: 14.9375\n", - "Doc ID: 3, Page: 8, Score: 14.6875\n", - "Doc ID: 2, Page: 8, Score: 14.6875\n", - "Doc ID: 4, Page: 1, Score: 14.5625\n", - "Test completed successfully!\n" + "Added page 12 of document 3 to index.\n" ] } ], @@ -178,17 +101,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "62.5 ms ± 1.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" - ] - } - ], + "outputs": [], "source": [ "%%timeit\n", "model.search(query, k=3)" @@ -196,56 +111,21 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Verbosity is set to 1 (active). Pass verbose=0 to make quieter.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e827b1252bb843bbad57550986c88e4f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading checkpoint shards: 0%| | 0/2 [00:00=0.44.1", "colpali-engine>=0.3.4,<0.4.0", "ml-dtypes", "mteb==1.6.35",