From ae040a042c05712806e3f02f7a780ea6da192e51 Mon Sep 17 00:00:00 2001
From: ppinchuk
Date: Thu, 29 Jan 2026 15:14:10 -0700
Subject: [PATCH 01/11] New function

---
 elm/base.py | 37 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 36 insertions(+), 1 deletion(-)

diff --git a/elm/base.py b/elm/base.py
index ef987c3e..dc1a1349 100644
--- a/elm/base.py
+++ b/elm/base.py
@@ -148,7 +148,42 @@ async def call_api(url, headers, request_json):
 
         return out
 
-    async def call_api_async(self, url, headers, all_request_jsons,
+    @staticmethod
+    async def call_client_embedding(client, request_json):
+        """Call OpenAI embedding API using client.
+
+        Parameters
+        ----------
+        client : openai.AzureOpenAI | openai.OpenAI
+            OpenAI client to use for embedding calls.
+        request_json : mapping
+            Mapping of request json for embedding call (to be passed
+            to ``client.embeddings.create()``).
+
+        Returns
+        -------
+        str | dict
+            Embeddings response as a JSON string. Will be a dict with
+            an 'error' key if there was an error while processing the
+            API call.
+        """
+        out = None
+        kwargs = dict(request_json)
+
+        try:
+            response = client.embeddings.create(**kwargs)
+            out = response.model_dump_json(indent=2)
+        except Exception as e:
+            logger.debug(f'Error in OpenAI API call from '
+                         f'`client.embeddings.create(**kwargs)` with '
+                         f'kwargs: {kwargs}')
+            logger.exception('Error in OpenAI API call! Turn on debug logging '
+                             'to see full query that caused error.')
+            out = {'error': str(e)}
+
+        return out
+
+    async def call_api_async(self, all_request_jsons,
                              ignore_error=None, rate_limit=40e3):
         """Use GPT to clean raw pdf text in parallel calls to the OpenAI API.

From c88daa263b8c8c57c915bb059d631babfd4b1b2b Mon Sep 17 00:00:00 2001
From: ppinchuk
Date: Thu, 29 Jan 2026 15:15:03 -0700
Subject: [PATCH 02/11] Add `ClientEmbeddingsApiQueue`

---
 elm/base.py | 53 +++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 49 insertions(+), 4 deletions(-)

diff --git a/elm/base.py b/elm/base.py
index dc1a1349..e8a2d6d7 100644
--- a/elm/base.py
+++ b/elm/base.py
@@ -520,10 +520,10 @@ def submit_jobs(self):
             elif tokens < avail_tokens:
                 token_count += tokens
-                task = asyncio.create_task(ApiBase.call_api(self.url,
-                                                            self.headers,
-                                                            request),
-                                           name=self.job_names[ijob])
+                task = asyncio.create_task(
+                    self._get_call_api_coro(request),
+                    name=self.job_names[ijob])
+
                 self.api_jobs[ijob] = task
                 self.tries[ijob] += 1
                 self._tsub = time.time()
@@ -541,6 +541,10 @@ def submit_jobs(self):
                     token_count = 0
                     break
 
+    def _get_call_api_coro(self, request):
+        """Convenience function to get the appropriate API call coroutine"""
+        return ApiBase.call_api(self.url, self.headers, request)
+
     async def collect_jobs(self):
         """Collect asyncronous API calls and API outputs. Store outputs in
         the `out` attribute."""
@@ -617,3 +621,44 @@ async def run(self):
             time.sleep(5)
 
         return self.out
+
+
+class ClientEmbeddingsApiQueue(ApiQueue):
+    """Class to manage the parallel API embedding submissions using a client"""
+
+
+    def __init__(self, client, request_jsons, ignore_error=None,
+                 rate_limit=40e3, max_retries=10):
+        """
+
+        Parameters
+        ----------
+        client : openai.AzureOpenAI | openai.OpenAI
+            OpenAI client object to use for API calls.
+        request_jsons : list
+            List of API data input, one entry typically looks like this
+            for an embedding call:
+            {"input": "text to embed...",
+             "model": "text-embedding-ada-002"}
+            Each entry is passed through to
+            ``client.embeddings.create()``.
+        ignore_error : None | callable
+            Optional callable to parse API error string. If the callable
+            returns True, the error will be ignored, the API call will not be
+            tried again, and the output will be an empty string.
+        rate_limit : float
+            OpenAI API rate limit (tokens / minute). Note that the
+            gpt-3.5-turbo limit is 90k as of 4/2023, but we're using a large
+            factor of safety (~1/2) because we can only count the tokens on the
+            input side and assume the output is about the same count.
+        max_retries : int
+            Number of times to retry an API call with an error response.
+        """
+        super().__init__(url=None, headers=None, request_jsons=request_jsons,
+                         ignore_error=ignore_error,
+                         rate_limit=rate_limit, max_retries=max_retries)
+        self.client = client
+
+    def _get_call_api_coro(self, request):
+        """Convenience function to get the appropriate API call coroutine"""
+        return ApiBase.call_client_embedding(self.client, request)

From 97b01aeea8756091f60f3be6a4ad28ced18547f1 Mon Sep 17 00:00:00 2001
From: ppinchuk
Date: Thu, 29 Jan 2026 15:15:31 -0700
Subject: [PATCH 03/11] Update `get_embedding` logic and introduce flag

---
 elm/base.py | 26 +++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/elm/base.py b/elm/base.py
index e8a2d6d7..6695f90c 100644
--- a/elm/base.py
+++ b/elm/base.py
@@ -26,6 +26,9 @@ class ApiBase(ABC):
     EMBEDDING_MODEL = 'text-embedding-ada-002'
     """Default model to do text embeddings."""
 
+    USE_CLIENT_EMBEDDINGS = False
+    """Option to use AzureOpenAI client for embedding calls."""
+
     EMBEDDING_URL = 'https://api.openai.com/v1/embeddings'
     """OpenAI embedding API URL"""
 
@@ -355,8 +358,7 @@ async def generic_async_query(self, queries, model_role=None,
 
         return out
 
-    @classmethod
-    def get_embedding(cls, text):
+    def get_embedding(self, text):
         """Get the 1D array (list) embedding of a text string.
 
         Parameters
@@ -369,9 +371,23 @@ def get_embedding(cls, text):
         embedding : list
             List of float that represents the numerical embedding of the text
         """
-        kwargs = dict(url=cls.EMBEDDING_URL,
-                      headers=cls.HEADERS,
-                      json={'model': cls.EMBEDDING_MODEL,
+        if self.USE_CLIENT_EMBEDDINGS:
+            kwargs = dict(input=text, model=self.EMBEDDING_MODEL)
+            response = self._client.embeddings.create(**kwargs)
+
+            try:
+                embedding = response.data[0].embedding
+            except Exception as exc:
+                msg = ('Embedding request failed: {}'
+                       .format(exc))
+                logger.error(msg)
+                raise RuntimeError(msg) from exc
+
+            return embedding
+
+        kwargs = dict(url=self.EMBEDDING_URL,
+                      headers=self.HEADERS,
+                      json={'model': self.EMBEDDING_MODEL,
                            'input': text})
 
         out = requests.post(**kwargs)

From 0d3cd0a88cdaaa3c152ed7cd13148439b578e821 Mon Sep 17 00:00:00 2001
From: ppinchuk
Date: Thu, 29 Jan 2026 15:15:48 -0700
Subject: [PATCH 04/11] Use new flag for `call_api_async`

---
 elm/base.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/elm/base.py b/elm/base.py
index 6695f90c..c86a2520 100644
--- a/elm/base.py
+++ b/elm/base.py
@@ -195,14 +195,6 @@ async def call_api_async(self, all_request_jsons,
 
         Parameters
         ----------
-        url : str
-            OpenAI API url, typically either:
-                https://api.openai.com/v1/embeddings
-                https://api.openai.com/v1/chat/completions
-        headers : dict
-            OpenAI API headers, typically:
-                {"Content-Type": "application/json",
-                 "Authorization": f"Bearer {openai.api_key}"}
         all_request_jsons : list
             List of API data input, one entry typically looks like this for
             chat completion:
@@ -226,9 +218,17 @@ async def call_api_async(self, all_request_jsons,
             List of API outputs where each list entry is a GPT answer from the
             corresponding message in the all_request_jsons input.
         """
-        self.api_queue = ApiQueue(url, headers, all_request_jsons,
-                                  ignore_error=ignore_error,
-                                  rate_limit=rate_limit)
+        if self.USE_CLIENT_EMBEDDINGS:
+            self.api_queue = ClientEmbeddingsApiQueue(self._client,
+                                                      all_request_jsons,
+                                                      ignore_error,
+                                                      rate_limit=rate_limit)
+        else:
+            self.api_queue = ApiQueue(self.EMBEDDING_URL, self.HEADERS,
+                                      all_request_jsons,
+                                      ignore_error=ignore_error,
+                                      rate_limit=rate_limit)
+
         out = await self.api_queue.run()
 
         return out

From 8705f730ae747531f5265ba35606aaeb712a6861 Mon Sep 17 00:00:00 2001
From: ppinchuk
Date: Thu, 29 Jan 2026 15:20:49 -0700
Subject: [PATCH 05/11] Revert function implementation

---
 elm/base.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/elm/base.py b/elm/base.py
index c86a2520..6726e574 100644
--- a/elm/base.py
+++ b/elm/base.py
@@ -186,7 +186,7 @@ async def call_client_embedding(client, request_json):
 
         return out
 
-    async def call_api_async(self, all_request_jsons,
+    async def call_api_async(self, url, headers, all_request_jsons,
                              ignore_error=None, rate_limit=40e3):
         """Use GPT to clean raw pdf text in parallel calls to the OpenAI API.
@@ -195,6 +195,14 @@ async def call_api_async(self, all_request_jsons,
 
         Parameters
         ----------
+        url : str
+            OpenAI API url, typically either:
+                https://api.openai.com/v1/embeddings
+                https://api.openai.com/v1/chat/completions
+        headers : dict
+            OpenAI API headers, typically:
+                {"Content-Type": "application/json",
+                 "Authorization": f"Bearer {openai.api_key}"}
         all_request_jsons : list
             List of API data input, one entry typically looks like this for
             chat completion:
@@ -218,17 +226,9 @@ async def call_api_async(self, all_request_jsons,
             List of API outputs where each list entry is a GPT answer from the
             corresponding message in the all_request_jsons input.
         """
-        if self.USE_CLIENT_EMBEDDINGS:
-            self.api_queue = ClientEmbeddingsApiQueue(self._client,
-                                                      all_request_jsons,
-                                                      ignore_error,
-                                                      rate_limit=rate_limit)
-        else:
-            self.api_queue = ApiQueue(self.EMBEDDING_URL, self.HEADERS,
-                                      all_request_jsons,
-                                      ignore_error=ignore_error,
-                                      rate_limit=rate_limit)
-
+        self.api_queue = ApiQueue(url, headers, all_request_jsons,
+                                  ignore_error=ignore_error,
+                                  rate_limit=rate_limit)
         out = await self.api_queue.run()
 
         return out

From 8f900e28164a27de93eeb2155660baea161fa103 Mon Sep 17 00:00:00 2001
From: ppinchuk
Date: Thu, 29 Jan 2026 15:21:01 -0700
Subject: [PATCH 06/11] Add `call_embedding_async` function for embedding
 class

---
 elm/embed.py | 67 ++++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 60 insertions(+), 7 deletions(-)

diff --git a/elm/embed.py b/elm/embed.py
index b1470779..a909f015 100644
--- a/elm/embed.py
+++ b/elm/embed.py
@@ -3,11 +3,12 @@
 ELM text embedding
 """
 import openai
+import json
 import re
 import os
 import logging
 
-from elm.base import ApiBase
+from elm.base import ApiBase, ApiQueue, ClientEmbeddingsApiQueue
 from elm.chunk import Chunker
 
 
@@ -20,7 +21,7 @@ class ChunkAndEmbed(ApiBase):
     DEFAULT_MODEL = 'text-embedding-ada-002'
     """Default model to do embeddings."""
 
-    def __init__(self, text, model=None, **chunk_kwargs):
+    def __init__(self, text, model=None, client=None, **chunk_kwargs):
         """
         Parameters
         ----------
@@ -30,6 +31,10 @@ def __init__(self, text, model=None, **chunk_kwargs):
         model : None | str
             Optional specification of OpenAI model to use. Default is
             cls.DEFAULT_MODEL
+        client : openai.AzureOpenAI | openai.OpenAI | None
+            Optional OpenAI client to use for embedding calls. If
+            ``None``, a client is set up using environment variables.
+            By default, ``None``.
         chunk_kwargs : dict | None
             kwargs for initialization of :class:`elm.chunk.Chunker`
         """
@@ -37,6 +42,8 @@ def __init__(self, text, model=None, **chunk_kwargs):
         super().__init__(model)
 
         self.text = text
+        if client is not None:
+            self._client = client
 
         if os.path.isfile(text):
             logger.info('Loading text file: {}'.format(text))
@@ -142,17 +149,17 @@ async def run_async(self, rate_limit=175e3):
 
         for chunk in self.text_chunks:
             req = {"input": chunk, "model": self.model}
 
-            if 'azure' in str(openai.api_type).lower():
+            if 'embedding' not in str(self.model).lower():
                 req['engine'] = self.model
 
             all_request_jsons.append(req)
 
-        embeddings = await self.call_api_async(self.EMBEDDING_URL,
-                                               self.HEADERS,
-                                               all_request_jsons,
-                                               rate_limit=rate_limit)
+        embeddings = await self.call_embedding_async(all_request_jsons,
+                                                     rate_limit=rate_limit)
 
         for i, chunk in enumerate(embeddings):
+            if self.USE_CLIENT_EMBEDDINGS and isinstance(chunk, str):
+                chunk = json.loads(chunk)
             try:
                 embeddings[i] = chunk['data'][0]['embedding']
             except Exception:
@@ -164,3 +171,49 @@ async def run_async(self, rate_limit=175e3):
         logger.info('Finished all embeddings.')
 
         return embeddings
+
+    async def call_embedding_async(self, all_request_jsons,
+                                   ignore_error=None, rate_limit=40e3):
+        """Use GPT to clean raw pdf text in parallel calls to the OpenAI API.
+
+        NOTE: you need to call this using the await command in ipython or
+        jupyter, e.g.: `out = await ChunkAndEmbed.call_embedding_async()`
+
+        Parameters
+        ----------
+        all_request_jsons : list
+            List of API data input, one entry typically looks like this
+            for an embedding call:
+            {"input": "text to embed...",
+             "model": "text-embedding-ada-002"}
+            Each entry is submitted to the API as a single embedding
+            request.
+        ignore_error : None | callable
+            Optional callable to parse API error string. If the callable
+            returns True, the error will be ignored, the API call will not be
+            tried again, and the output will be an empty string.
+        rate_limit : float
+            OpenAI API rate limit (tokens / minute). Note that the
+            gpt-3.5-turbo limit is 90k as of 4/2023, but we're using a large
+            factor of safety (~1/2) because we can only count the tokens on the
+            input side and assume the output is about the same count.
+
+        Returns
+        -------
+        out : list
+            List of API outputs where each list entry is an embedding
+            response for the corresponding entry in all_request_jsons.
+        """
+        if self.USE_CLIENT_EMBEDDINGS:
+            self.api_queue = ClientEmbeddingsApiQueue(self._client,
+                                                      all_request_jsons,
+                                                      ignore_error,
+                                                      rate_limit=rate_limit)
+        else:
+            self.api_queue = ApiQueue(self.EMBEDDING_URL, self.HEADERS,
+                                      all_request_jsons,
+                                      ignore_error=ignore_error,
+                                      rate_limit=rate_limit)
+
+        out = await self.api_queue.run()
+        return out
\ No newline at end of file

From 80aebc81a7503b7a16101ffc4cc83f04181efe99 Mon Sep 17 00:00:00 2001
From: ppinchuk
Date: Thu, 29 Jan 2026 15:21:19 -0700
Subject: [PATCH 07/11] Minor update to tests

---
 tests/test_wizard.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_wizard.py b/tests/test_wizard.py
index cf64353b..c69c3e31 100644
--- a/tests/test_wizard.py
+++ b/tests/test_wizard.py
@@ -26,8 +26,8 @@ class MockObject:
 class MockClass:
     """Dummy class to mock various api calls"""
 
-    @staticmethod
-    def get_embedding(*args, **kwargs):  # pylint: disable=unused-argument
+    # pylint: disable=unused-argument
+    def get_embedding(self, *args, **kwargs):
         """Mock for ChunkAndEmbed.call_api()"""
         embedding = np.random.uniform(0, 1, 10)
         return embedding

From 5dd8ba5e8403ad460ee6d73d2bdadcd2ab359c75 Mon Sep 17 00:00:00 2001
From: ppinchuk
Date: Thu, 29 Jan 2026 15:21:41 -0700
Subject: [PATCH 08/11] Bump version

---
 elm/version.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/elm/version.py b/elm/version.py
index f9963abf..2627c20a 100644
--- a/elm/version.py
+++ b/elm/version.py
@@ -2,4 +2,4 @@
 ELM version number
 """
 
-__version__ = "0.0.35"
+__version__ = "0.0.36"

From 43781a1f967cf8b034db1fe3e552c11312b9d0a8 Mon Sep 17 00:00:00 2001
From: ppinchuk
Date: Thu, 29 Jan 2026 15:22:42 -0700
Subject: [PATCH 09/11] Minor docstring update

---
 elm/embed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/elm/embed.py b/elm/embed.py
index a909f015..59509820 100644
--- a/elm/embed.py
+++ b/elm/embed.py
@@ -174,7 +174,7 @@ async def run_async(self, rate_limit=175e3):
 
     async def call_embedding_async(self, all_request_jsons,
                                    ignore_error=None, rate_limit=40e3):
-        """Use GPT to clean raw pdf text in parallel calls to the OpenAI API.
+        """Use an OpenAI API client to generate embeddings for text.
 
         NOTE: you need to call this using the await command in ipython or
         jupyter, e.g.: `out = await ChunkAndEmbed.call_embedding_async()`

From b10734931488362d59f32d98e641761615db7cca Mon Sep 17 00:00:00 2001
From: ppinchuk
Date: Thu, 29 Jan 2026 15:32:12 -0700
Subject: [PATCH 10/11] Linter fixes

---
 elm/base.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/elm/base.py b/elm/base.py
index 6726e574..f75ac3b2 100644
--- a/elm/base.py
+++ b/elm/base.py
@@ -642,7 +642,6 @@ async def run(self):
 
 class ClientEmbeddingsApiQueue(ApiQueue):
     """Class to manage the parallel API embedding submissions using a client"""
 
-
     def __init__(self, client, request_jsons, ignore_error=None,
                  rate_limit=40e3, max_retries=10):

From e04b17e4e61fa070ef648d881cdf351867984fbf Mon Sep 17 00:00:00 2001
From: ppinchuk
Date: Thu, 29 Jan 2026 15:32:18 -0700
Subject: [PATCH 11/11] Linter fixes

---
 elm/embed.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/elm/embed.py b/elm/embed.py
index 59509820..b0bc189e 100644
--- a/elm/embed.py
+++ b/elm/embed.py
@@ -80,7 +80,7 @@ def clean_tables(text):
 
         return '\n'.join(lines)
 
-    def run(self, rate_limit=175e3):
+    def run(self, rate_limit=175e3):  # pylint: disable=unused-argument
         """Run text embedding in serial
 
         Parameters
@@ -216,4 +216,4 @@ async def call_embedding_async(self, all_request_jsons,
                                       rate_limit=rate_limit)
 
         out = await self.api_queue.run()
-        return out
\ No newline at end of file
+        return out
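
Usage sketch (reviewer note): the snippet below shows how the pieces added in
this series are expected to fit together end to end -- the class-level flag
from PATCH 03, the `client` keyword argument from PATCH 06, and the
`ClientEmbeddingsApiQueue` from PATCH 02. The Azure API version, environment
variable names, and input file path are illustrative placeholders, not values
defined by these patches.

    import asyncio
    import os

    import openai

    from elm.embed import ChunkAndEmbed

    # Hypothetical Azure setup -- swap in your own deployment details.
    client = openai.AzureOpenAI(
        api_key=os.environ["AZURE_OPENAI_API_KEY"],
        api_version="2024-02-01",
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    )

    # Route embedding calls through the client object instead of raw
    # POSTs to EMBEDDING_URL.
    ChunkAndEmbed.USE_CLIENT_EMBEDDINGS = True

    # Chunk a text file and embed the chunks in parallel; with the flag
    # set, `run_async` dispatches through `ClientEmbeddingsApiQueue`.
    cae = ChunkAndEmbed("my_document.txt", model="text-embedding-ada-002",
                        client=client)
    embeddings = asyncio.run(cae.run_async(rate_limit=175e3))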