diff --git a/docs/tutorials/tlm_structured_outputs/index.ipynb b/docs/tutorials/tlm_structured_outputs/index.ipynb index 64bbee6..f1fd633 100644 --- a/docs/tutorials/tlm_structured_outputs/index.ipynb +++ b/docs/tutorials/tlm_structured_outputs/index.ipynb @@ -5,7 +5,7 @@ "id": "a8f278ba-441a-470a-99c1-54cbd6f8cee5", "metadata": {}, "source": [ - "# Using TLM via the OpenAI library to score the trustworthiness of: structured outputs, function calling, messages, and more" + "# Scoring Structured Output Responses with TLM" ] }, { @@ -13,74 +13,61 @@ "id": "79715fec-e568-407b-b2f6-11aba262c9ed", "metadata": {}, "source": [ - "This tutorial demonstrates how to score the trustworthiness of OpenAI model responses directly through the [OpenAI library](https://github.com/openai/openai-python). Existing OpenAI users: you can obtain real-time state-of-the-art trust scores for every OpenAI response, without changing your code or installing extra packages.\n", - "\n", - "Using TLM via the [OpenAI library](https://github.com/openai/openai-python) enables you to leverage OpenAI's advanced features (structured outputs, function calling, ...) and automatically flag errors/hallucinations made by OpenAI.\n" + "This tutorial demonstrates how to score the trustworthiness of structred outputs (JSON etc) using TLM.\n" ] }, { "cell_type": "markdown", - "id": "d9dc04d5", + "id": "ce124e52-f9c4-4110-b844-3fe05a84adb8", "metadata": {}, "source": [ - "![Getting TLM trustworthiness scores from using OpenAI API](tlm-openai-api.png)" + "## Setup" ] }, { "cell_type": "markdown", - "id": "0d38dc41", + "id": "9c07e776-adf9-4246-8a98-7d19d2fcb422", "metadata": {}, "source": [ - "In this tutorial, we use OpenAI's structured outputs feature to perform multi-label classification (i.e. document tagging) with trustworthiness scores from TLM. The same method can be used to score the trustworthiness of any type of response from OpenAI (not just structured outputs).\n", + "This tutorial requires an API key for an LLM provider. Some possibilities include: `OPENAI_API_KEY`, GEMINI_API_KEY, DEEPSEEK_API_KEY, AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY, etc.\n", "\n", - "Before starting this tutorial, we recommed you first complete our basic tutorial on [Using TLM with the Chat Completions API](../tlm_chat_completion)." - ] - }, - { - "cell_type": "markdown", - "id": "ce124e52-f9c4-4110-b844-3fe05a84adb8", - "metadata": {}, - "source": [ - "## Setup" + "The TLM Python client can be installed using pip:" ] }, { - "cell_type": "markdown", - "id": "9c07e776-adf9-4246-8a98-7d19d2fcb422", + "cell_type": "code", + "execution_count": null, + "id": "c4c356e4-6b02-4864-a5ff-926d991de162", "metadata": {}, + "outputs": [], "source": [ - "This tutorial requires a TLM API key. Get one [here](https://tlm.cleanlab.ai/).\n", - "\n", - "The Python packages required for this tutorial can be installed using pip:" + "%pip install --upgrade trustworthy-llm" ] }, { "cell_type": "code", - "execution_count": 1, - "id": "c4c356e4-6b02-4864-a5ff-926d991de162", + "execution_count": null, + "id": "67de14a5", "metadata": {}, "outputs": [], "source": [ - "%pip install --upgrade openai tqdm" + "# Set your API key\n", + "import os\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"\" # or other LLM provider API key" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "b9bc6a28-1873-4b02-8624-16faffca9108", "metadata": {}, "outputs": [], "source": [ - "import pandas as pd\n", - "from enum import Enum\n", - "from pydantic import BaseModel\n", - "from concurrent.futures import ThreadPoolExecutor\n", - "from openai import OpenAI\n", - "import ast\n", - "import time\n", - "from tqdm import tqdm\n", + "from pydantic import create_model\n", + "from typing import Optional\n", "\n", - "pd.set_option(\"display.max_colwidth\", None)" + "from tlm import TLM" ] }, { @@ -88,7 +75,7 @@ "id": "6e18a67a-8575-43bb-9e53-74385d6fbc6a", "metadata": {}, "source": [ - "## Fetch example Dataset" + "## PII Extraction Use Case" ] }, { @@ -96,112 +83,38 @@ "id": "a65018bf-464e-4c8c-abe2-7e72c517918d", "metadata": {}, "source": [ - "This tutorial uses a modified version of the [Alexa intent detection dataset](https://huggingface.co/datasets/AmazonScience/massive). \n", + "This tutorial showcases a PII (Personally Identifiable Information) extraction example.\n", "\n", - "Each text sample contains several statements that could correspond to multiple intents (for example controlling devices, asking for information etc). The label corresponding to each example specifies what the intent of that statement is, where there could be more than one intent corresponding to each sample. Let's take a look at the dataset below:\n", + "Each text sample contains various types of personal information embedded within natural language text. The task is to extract different categories of PII from the text. Each example contains multiple types of PII that need to be identified and classified into specific categories including names (FIRSTNAME, LASTNAME), dates (DATE), and account numbers (ACCOUNTNUMBER).\n", "\n", - "In this tutorial, we will only run the LLM inference on 50 randomly sampled examples of this dataset as a demonstration." + "Let’s take a look at a few samples below:" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "b43273fc-c518-46fe-b6dd-f34215552fd3", "metadata": {}, - "outputs": [], - "source": [ - "!wget -nc https://cleanlab-public.s3.us-east-1.amazonaws.com/Datasets/massive_multilabel_classification.csv" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "a73bec43-dee1-473a-a926-318b1a9c11d6", - "metadata": { - "scrolled": true - }, "outputs": [ { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
textlabels
0lets have a chat[general_quirky]
1what are meeting scheduled for today[calendar_query]
2erase all the events. resume my audio book from karl pilkington. tell me the profession of celebrity[calendar_remove, play_audiobook, qa_factoid]
3thirty minute reminder on meeting for tuesday[calendar_set]
4i have a nine am meeting on wednesday send me a reminder[calendar_set]
\n", - "
" - ], - "text/plain": [ - " text \\\n", - "0 lets have a chat \n", - "1 what are meeting scheduled for today \n", - "2 erase all the events. resume my audio book from karl pilkington. tell me the profession of celebrity \n", - "3 thirty minute reminder on meeting for tuesday \n", - "4 i have a nine am meeting on wednesday send me a reminder \n", - "\n", - " labels \n", - "0 [general_quirky] \n", - "1 [calendar_query] \n", - "2 [calendar_remove, play_audiobook, qa_factoid] \n", - "3 [calendar_set] \n", - "4 [calendar_set] " - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "Melvin, the password of your study support account has been changed to Xnjv7nCydECf for security purposes. Please update it promptly.\n" + ] } ], "source": [ - "data = pd.read_csv(\"massive_multilabel_classification.csv\")\n", - "data[\"labels\"] = data[\"labels\"].apply(ast.literal_eval)\n", - "data = data.sample(50, random_state=123).reset_index(drop=True)\n", - "data.head()" + "input_texts = [\n", + " \"Melvin, the password of your study support account has been changed to Xnjv7nCydECf for security purposes. Please update it promptly.\",\n", + " \"In relation to the filed litigation, we hereby request full disclosure of all data logs associated with 242.218.157.166 and cd9f:d9e5:1ceb:fd39:b2d7:f3fd:c9cd:c27b tied to the account of Fleta London Emard. This involves her employment account Investment Account with Gutkowski Inc.\",\n", + " \"We would like to do a follow-up meeting with Sierra Green regarding her recent surgery. The proposed date is August 13, 2013 at our clinic in West Nash.\",\n", + " \"Melvin, the password of your study support account has been changed to Xnjv7nCydECf for security purposes. Please update it promptly.\",\n", + " \"Is your business tax-ready? Our team in Novato is here to help you navigate through Martinique's complex tax rules. Contact us at 56544500.\",\n", + " \"To: Maximillian Noah Moore, we forgot to update your record with phone IMEI: 30-265288-033265-8. Could you please provide it in your earliest convenience to keep your records updated.\",\n", + "]\n", + "\n", + "print(input_texts[0])" ] }, { @@ -225,7 +138,10 @@ "id": "94989f63-ade7-4a01-95bd-bea936f6cfce", "metadata": {}, "source": [ - "First, we need to get a list of all possible classes from the given dataset:" + "We know that the 4 PII fields that we want to extract are: `['FIRSTNAME', 'LASTNAME', 'DATE', 'ACCOUNTNUMBER']`\n", + "\n", + "Using that, we can create a Pydantic model to represent our PII extraction schema. Each field is optional and can be None if that entity type is not found in the text:\n", + "\n" ] }, { @@ -233,43 +149,12 @@ "execution_count": 5, "id": "763c6dcd-6a18-4833-bf94-898f1a000836", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array(['general_quirky', 'calendar_query', 'calendar_remove',\n", - " 'play_audiobook', 'qa_factoid'], dtype=object)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "multilabel_classes = data[\"labels\"].explode().unique()\n", - "multilabel_classes[:5]" - ] - }, - { - "cell_type": "markdown", - "id": "fdbe510d-8e60-4544-809f-8d817423e5ca", - "metadata": {}, - "source": [ - "Then, we can create a object that inherits from pydantic's `BaseModel` to represent the multi-label classification schema, ensuring that each predicted label is validated against the predefined list of possible classes:" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "a14ba5ee-44dd-4727-93a2-e547487219c1", - "metadata": { - "scrolled": true - }, "outputs": [], "source": [ - "class MultiLabelClassification(BaseModel):\n", - " classes: list[Enum(\"MultilabelClasses\", {name: name for name in multilabel_classes})]" + "pii_entities = [\"FIRSTNAME\", \"LASTNAME\", \"DATE\", \"ACCOUNTNUMBER\"]\n", + "fields = {name: (Optional[str], None) for name in pii_entities}\n", + "\n", + "PII = create_model(\"PII\", **fields)" ] }, { @@ -277,72 +162,30 @@ "id": "f3f87254-771e-4674-aad9-1a83f08a0343", "metadata": {}, "source": [ - "### Prompt OpenAI " - ] - }, - { - "cell_type": "markdown", - "id": "a38fc4d5-f768-430e-94f4-8eba529591ef", - "metadata": {}, - "source": [ - "Then, we can instantiate the OpenAI client, pointing the `base_url` to TLM, which allows us to also get the trustworthiness score associated with each response." + "### Prompt TLM for responses and trust scores" ] }, { "cell_type": "code", - "execution_count": 7, - "id": "549220e5-0530-4342-af89-538a7965e577", + "execution_count": null, + "id": "63e35b01", "metadata": {}, "outputs": [], "source": [ - "# Get your Cleanlab API key from: https://tlm.cleanlab.ai/\n", - "client = OpenAI(api_key=\"\", base_url=\"https://api.cleanlab.ai/api/v1/openai_trustworthy_llm/\")" - ] - }, - { - "cell_type": "markdown", - "id": "5f1bbaa6-691c-4192-94ce-64c465abfb00", - "metadata": {}, - "source": [ - "Here is an example of how we can prompt OpenAI with one sample text:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "5efa2ff4-1b9b-4562-b74d-c431d8db4d52", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'lets have a chat'" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sample_text = data[\"text\"][0]\n", - "sample_text" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "7e8792d0-fa34-4824-8dc8-27bda83d9258", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "completion = client.beta.chat.completions.parse(\n", - " model=\"gpt-4o\",\n", - " messages=[{\"role\": \"user\", \"content\": f\"Classify the following text: {sample_text}\"}],\n", - " response_format=MultiLabelClassification,\n", - ")" + "tlm = TLM()\n", + "\n", + "sample_text = input_texts[0]\n", + "openai_kwargs = {\n", + " \"model\": \"gpt-4.1-mini\",\n", + " \"messages\": [\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": f\"Extract PII information from the following text, return null if the entity is not found: {sample_text}\",\n", + " }\n", + " ],\n", + " \"response_format\": PII,\n", + "}\n", + "tlm_result = tlm.create(**openai_kwargs)" ] }, { @@ -350,12 +193,12 @@ "id": "007eeb76-9324-44a4-b08c-a2994d6e9e53", "metadata": {}, "source": [ - "The returned object matches what OpenAI would ordinarily return, except it has an additional `tlm_metadata` field from TLM with extra information like the trustworthiness score. This way you can use TLM as a drop-in replacement for OpenAI in any application (and will still get back the same responses you'd get directly from OpenAI). Let's parse the predictions and trustworthiness score from the returned response:" + "The returned object matches what any LLM would ordinarily return, except it has an additional `trustworthiness_score` field from TLM with extra information like the trustworthiness score. " ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "e6d555f0-a01e-4e9f-9628-3a7d093fca6a", "metadata": {}, "outputs": [ @@ -363,444 +206,22 @@ "name": "stdout", "output_type": "stream", "text": [ - "Predicted Classes: ['general_quirky']\n", - "Trustworthiness Score: 0.8512080365166845\n" + "LLM response: FIRSTNAME='Melvin' LASTNAME=None DATE=None ACCOUNTNUMBER=None\n", + "Trustworthiness score: 0.98902\n" ] } ], "source": [ - "parsed_predictions = [prediction.value for prediction in completion.choices[0].message.parsed.classes]\n", - "trustworthiness_score = completion.tlm_metadata[\"trustworthiness_score\"]\n", - "\n", - "print(f\"Predicted Classes: {parsed_predictions}\")\n", - "print(f\"Trustworthiness Score: {trustworthiness_score}\")" - ] - }, - { - "cell_type": "markdown", - "id": "abcb4ff8-2982-4542-9693-9960583e5c8e", - "metadata": {}, - "source": [ - "### Batch Prompt on a Dataset" - ] - }, - { - "cell_type": "markdown", - "id": "453680cb-1bc0-4381-8b43-f405de8366af", - "metadata": {}, - "source": [ - "Here, we define a quick helper function that allows us to process multiple texts in parallel, which will speed up prompting the LLM on an entire dataset. The helper functions also parses and collects the predictions and trustworthiness score in a DataFrame for easy downstream analysis." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "f4433f2d-a4db-4909-b523-731ac6227b5a", - "metadata": {}, - "outputs": [], - "source": [ - "def classify_text(text):\n", - " completion = client.beta.chat.completions.parse(\n", - " model=\"gpt-4o\",\n", - " messages=[{\"role\": \"user\", \"content\": f\"Classify the following text: {text}\"}],\n", - " response_format=MultiLabelClassification,\n", - " )\n", - "\n", - " return {\n", - " \"predictions\": [pred.value for pred in completion.choices[0].message.parsed.classes],\n", - " \"trustworthiness_score\": completion.tlm_metadata[\"trustworthiness_score\"],\n", - " }\n", - "\n", - "\n", - "def classify_texts_batch(texts, batch_size=20, max_threads=8, sleep_time=10):\n", - " results = []\n", - " for i in tqdm(range(0, len(texts), batch_size)):\n", - " batch = texts[i : i + batch_size]\n", - "\n", - " with ThreadPoolExecutor(max_threads) as executor:\n", - " futures = [executor.submit(classify_text, text) for text in batch]\n", - " batch_results = [f.result() for f in futures]\n", - "\n", - " results.extend(batch_results)\n", - "\n", - " # sleep to prevent hitting rate limits\n", - " if i + batch_size < len(texts):\n", - " time.sleep(sleep_time)\n", - "\n", - " return pd.DataFrame(results)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "786fe8b7-dada-4ec6-81cd-5e1643c12b59", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
predictionstrustworthiness_score
0[general_quirky]0.851207
1[calendar_query]0.988874
2[calendar_remove, play_audiobook, qa_factoid]0.989885
3[alarm_query]0.338316
4[calendar_set, calendar_query]0.687683
\n", - "
" - ], - "text/plain": [ - " predictions trustworthiness_score\n", - "0 [general_quirky] 0.851207\n", - "1 [calendar_query] 0.988874\n", - "2 [calendar_remove, play_audiobook, qa_factoid] 0.989885\n", - "3 [alarm_query] 0.338316\n", - "4 [calendar_set, calendar_query] 0.687683" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "results = classify_texts_batch(data[\"text\"])\n", - "results.head()" - ] - }, - { - "cell_type": "markdown", - "id": "72b6aacf-2768-45ca-bd10-4eaff37693d8", - "metadata": {}, - "source": [ - "## Examine Results" - ] - }, - { - "cell_type": "markdown", - "id": "5f5ce86f-3085-4c90-aca4-964b1a3fe867", - "metadata": {}, - "source": [ - "We have now obtained the predictions and trustworthiness score for each given text. Let's examine the results in more detail." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "a0c9198d-d26c-422c-b11f-6adec3fed116", - "metadata": {}, - "outputs": [], - "source": [ - "combined_results = pd.concat([data, results], axis=1)\n", - "combined_results = combined_results.rename(columns={\"labels\": \"ground_truth_labels\"})" - ] - }, - { - "cell_type": "markdown", - "id": "3732335f-9bb2-47f0-83da-936b2c66ca09", - "metadata": {}, - "source": [ - "### High Trustworthiness Scores" - ] - }, - { - "cell_type": "markdown", - "id": "4ccb3a4f-2c56-427a-88c5-a8345830cfcd", - "metadata": {}, - "source": [ - "The responses with the highest trustworthiness scores represent texts where TLM is the most confident that it has predicted the correct intents.\n", - "\n", - "We can see below that the predictions for the samples below match the ground truth labels and are correctly classified." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "946e78ca-980d-4078-8564-01496fdfdf06", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
textground_truth_labelspredictionstrustworthiness_score
7what alarms did i set[alarm_query][alarm_query]0.989979
20turn the lights off[iot_hue_lightoff][iot_hue_lightoff]0.989947
17send an email to margaret. shut down the sound[email_sendemail, audio_volume_mute][email_sendemail, audio_volume_mute]0.989936
\n", - "
" - ], - "text/plain": [ - " text \\\n", - "7 what alarms did i set \n", - "20 turn the lights off \n", - "17 send an email to margaret. shut down the sound \n", - "\n", - " ground_truth_labels \\\n", - "7 [alarm_query] \n", - "20 [iot_hue_lightoff] \n", - "17 [email_sendemail, audio_volume_mute] \n", - "\n", - " predictions trustworthiness_score \n", - "7 [alarm_query] 0.989979 \n", - "20 [iot_hue_lightoff] 0.989947 \n", - "17 [email_sendemail, audio_volume_mute] 0.989936 " - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "combined_results.sort_values(\"trustworthiness_score\", ascending=False).head(3)" - ] - }, - { - "cell_type": "markdown", - "id": "4639fa71-b7f2-480b-b00b-40fec010338f", - "metadata": {}, - "source": [ - "### Low Trustworthiness Scores" - ] - }, - { - "cell_type": "markdown", - "id": "25664160-d35f-421b-99b7-ec8e2bd4d200", - "metadata": {}, - "source": [ - "The responses with the lowest trustworthiness scores indicate outputs we are least confident are good.\n", - "\n", - "Results with low trustworthiness scores would benefit most from manual review, especially if we need almost all outputs across the dataset to be correct.\n", - "\n", - "For examples with the lowest trustworthiness scores in our dataset shown below, you can see that the predictions tend to be incorrect or could use further review." - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "e5f546f4-c2a0-4501-b94f-109c0e0a237c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
textground_truth_labelspredictionstrustworthiness_score
42i will need warm socks in winter in morning[weather_query][general_quirky, datetime_query]0.264497
3thirty minute reminder on meeting for tuesday[calendar_set][alarm_query]0.338316
41features of google pixel. what is the deepest point on earth[general_quirky, qa_factoid][qa_factoid, recommendation_events]0.460527
\n", - "
" - ], - "text/plain": [ - " text \\\n", - "42 i will need warm socks in winter in morning \n", - "3 thirty minute reminder on meeting for tuesday \n", - "41 features of google pixel. what is the deepest point on earth \n", - "\n", - " ground_truth_labels predictions \\\n", - "42 [weather_query] [general_quirky, datetime_query] \n", - "3 [calendar_set] [alarm_query] \n", - "41 [general_quirky, qa_factoid] [qa_factoid, recommendation_events] \n", - "\n", - " trustworthiness_score \n", - "42 0.264497 \n", - "3 0.338316 \n", - "41 0.460527 " - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "combined_results.sort_values(\"trustworthiness_score\").head(3)" - ] - }, - { - "cell_type": "markdown", - "id": "922df4fa-409c-40f3-8b8a-73ba2f5ee313", - "metadata": {}, - "source": [ - "## Using Different Quality Presets" - ] - }, - { - "cell_type": "markdown", - "id": "a396b996-da06-4102-9ef9-8a74f7cf24cb", - "metadata": {}, - "source": "You can use TLM with different [quality presets](../tlm_advanced#quality-presets) by specifying the `quality_preset` using the `extra_body` argument. \n\nFor example, in this example below we specify `extra_body={\"quality_preset\": \"low\"}` to use TLM on `low` quality preset (for lower latency). If unspecified, the default quality preset used is `medium`.\n\nRead more about quality presets [here](../../api/tlm/)." - }, - { - "cell_type": "code", - "execution_count": null, - "id": "561a4cc9-e835-40e8-adc5-abcecb09db9c", - "metadata": {}, - "outputs": [], - "source": [ - "sample_text = data[\"text\"][0]\n", - "\n", - "completion = client.beta.chat.completions.parse(\n", - " model=\"gpt-4o\",\n", - " messages=[{\"role\": \"user\", \"content\": f\"Classify the following text: {sample_text}\"}],\n", - " response_format=MultiLabelClassification,\n", - " extra_body={\"quality_preset\": \"low\"},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "bd7d5046", - "metadata": {}, - "source": "## Conclusion\n\nThis tutorial demonstrated how to use TLM for outputting a single structured field (list of document tags).\n" - }, - { - "cell_type": "markdown", - "id": "a396b996-da06-4102-9ef9-8a74f7cf24cd", - "metadata": {}, - "source": [ - "Remember: you can use TLM via the [OpenAI library](https://github.com/openai/openai-python) to score the trustworthiness of *any* type of OpenAI output (not just structured outputs).\n", - "Beyond structured outputs, we recommend using TLM via the [OpenAI library](https://github.com/openai/openai-python) for LLM applications involving: function calling, system prompts and multiple user/assistant messages (conversational dialogues), as well as other advanced features offered via OpenAI's API.\n", - "\n", - "For questions about the OpenAI API, refer to their documentation linked from [their library](https://github.com/openai/openai-python)." + "print(\"LLM response: \", tlm_result[\"response\"].choices[0].message.parsed)\n", + "print(\"Trustworthiness score: \", tlm_result[\"trustworthiness_score\"])" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "tlm-dev-jupyter", "language": "python", - "name": "python3" + "name": "tlm-dev-jupyter" }, "language_info": { "codemirror_mode": { @@ -812,7 +233,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.10" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/tlm/api.py b/tlm/api.py index 1e92ef2..c391717 100644 --- a/tlm/api.py +++ b/tlm/api.py @@ -3,6 +3,7 @@ import asyncio import sys from openai.types.chat import ChatCompletion +from openai.lib._parsing._completions import type_to_response_format_param from tlm.config.base import BaseConfig from tlm.config.schema import Config @@ -162,6 +163,10 @@ async def _async_inference( ) model = openai_kwargs.get("model") config = BaseConfig.from_input(self.config, workflow_type, model) + + if openai_kwargs.get("response_format"): + openai_kwargs["response_format"] = type_to_response_format_param(openai_kwargs["response_format"]) + return await tlm_inference( completion_params=openai_kwargs, response=response, diff --git a/tlm/components/completions/self_reflection_completion_generator.py b/tlm/components/completions/self_reflection_completion_generator.py index 07b2ca3..4a9bfc6 100644 --- a/tlm/components/completions/self_reflection_completion_generator.py +++ b/tlm/components/completions/self_reflection_completion_generator.py @@ -49,6 +49,7 @@ async def execute(self) -> None: template_kwargs=template_kwargs, temperature=0.0, response_format_model=template.construct_response_format(answer), + reference_answer=answer, ) ) for template in self.completion_templates diff --git a/tlm/components/scores/self_reflection_score_computation.py b/tlm/components/scores/self_reflection_score_computation.py index a7de2a7..f96b61b 100644 --- a/tlm/components/scores/self_reflection_score_computation.py +++ b/tlm/components/scores/self_reflection_score_computation.py @@ -21,6 +21,10 @@ async def execute(self) -> None: for completion in self_reflection_completions_flat if (metadata := completion.per_field_metadata) is not None ] - composite_reflection_metadata = compute_field_metadata(reflection_metadata) + + scoring_data = [ + completion for completion in self_reflection_completions_flat if completion.per_field_metadata is not None + ] + composite_reflection_metadata = compute_field_metadata(reflection_metadata, scoring_data=scoring_data) self.execution_context.add("self_reflection_metadata_per_field", composite_reflection_metadata) diff --git a/tlm/config/presets.py b/tlm/config/presets.py index 6d04d3e..9102741 100644 --- a/tlm/config/presets.py +++ b/tlm/config/presets.py @@ -135,3 +135,7 @@ def from_inference_params( WorkflowType.DEFAULT: {NUM_CONSISTENCY_COMPLETIONS: 0, NUM_SELF_REFLECTION_COMPLETIONS: 0}, }, } + +# these values were benchmarked on 10/2025, there was no significant difference when using values from 0.8-0.95 +STRUCTURED_OUTPUT_CORRECT_FIELD_SCORE: float = 0.9 +STRUCTURED_OUTPUT_INCORRECT_FIELD_SCORE: float = 0.1 diff --git a/tlm/templates/per_field_scoring_models.py b/tlm/templates/per_field_scoring_models.py index 2c833af..18767d0 100644 --- a/tlm/templates/per_field_scoring_models.py +++ b/tlm/templates/per_field_scoring_models.py @@ -13,3 +13,7 @@ class PerFieldCertaintyEvaluation(PerFieldScoreEvaluationBase): class PerFieldCorrectnessEvaluation(PerFieldScoreEvaluationBase): confidence: Literal["Certain", "Mostly Certain", "Somewhat Certain", "Uncertain", "Likely Incorrect"] + + +# base class type for incorrect field evaluation +class IncorrectFieldEvaluationBase(BaseModel): ... diff --git a/tlm/templates/reflection_completion_templates.py b/tlm/templates/reflection_completion_templates.py index 4682f27..cf43b96 100644 --- a/tlm/templates/reflection_completion_templates.py +++ b/tlm/templates/reflection_completion_templates.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod -from typing import Callable, ClassVar - -from pydantic import BaseModel +from typing import Callable, ClassVar, Literal +import json +from pydantic import BaseModel, Field +from tlm.types.base import SOReflectionScoreConfigType from tlm.config.presets import ReasoningEffort, WorkflowType from tlm.templates.keywords import ( ANSWER_PLACEHOLDER, @@ -33,7 +34,11 @@ ) from tlm.types import AnswerChoiceToken, ExtractedResponseField, RegexPattern, CompletionTemplate from tlm.utils.response_format_utils import construct_per_field_response_format_model -from tlm.templates.per_field_scoring_models import PerFieldCorrectnessEvaluation, PerFieldCertaintyEvaluation +from tlm.templates.per_field_scoring_models import ( + PerFieldCorrectnessEvaluation, + PerFieldCertaintyEvaluation, + IncorrectFieldEvaluationBase, +) class ReflectionCompletionTemplate(CompletionTemplate, ABC): @@ -77,6 +82,7 @@ class ReflectionSOPerFieldScoreTemplate(ReflectionCompletionTemplate): per_field_score_response_format: ClassVar[type[BaseModel]] def __init_subclass__(cls, **kwargs): + cls.so_reflection_score_config_type = SOReflectionScoreConfigType.PER_FIELD super().__init_subclass__(**kwargs) if not hasattr(cls, "per_field_score_response_format"): raise TypeError(f"{cls.__name__} must define 'per_field_score_response_format' class variable") @@ -86,6 +92,13 @@ def construct_response_format(cls, response_json: str) -> type[BaseModel] | None return construct_per_field_response_format_model(response_json, cls.per_field_score_response_format) +class ReflectionSOIncorrectFieldsTemplate(ReflectionCompletionTemplate): + def __init_subclass__(cls, **kwargs): + cls.so_reflection_score_config_type = SOReflectionScoreConfigType.INCORRECT_FIELDS + cls.per_field_score_key = "score" # use a default key name for the per-field score + super().__init_subclass__(**kwargs) + + class ReflectionCertaintyTemplate(ReflectionCompletionTemplate): _SHARED_PROMPT: ClassVar[str] = f"""You are an evaluator that verifies whether AI responses are factually correct. Below is a User Question and the Answer proposed by an untrustworthy AI Assistant. @@ -742,6 +755,142 @@ def create(cls, reasoning_effort: ReasoningEffort, **kwargs) -> ReflectionComple ) +class SelfReflectionSOFieldAccuracyConfig(ReflectionSOIncorrectFieldsTemplate): + _PROMPT: ClassVar[str] = f"""You are an evaluator that identifies factual inaccuracies in AI responses. +Below is a User Request and the Response provided by an untrustworthy AI Assistant. +Your task is to find and list any fields in the Response that are factually incorrect, incomplete, or unreliable. + + +{QUESTION_PLACEHOLDER} + + + +{ANSWER_PLACEHOLDER} + + + +## Instructions + +Carefully evaluate the factual accuracy of each top-level field in the JSON Response. +You must identify any fields that are likely incorrect, incomplete, unverifiable, or misleading. + +Be extremely strict in your judgment: +- Treat any missing or partially correct information as potentially incorrect. +- Penalize any unsupported assumptions, factual errors, or extraneous content. +- Incomplete fields should be marked as incorrect. Do not excuse missing data, if a field is null/empty when it should contain information, it is incorrect. +- Even small inaccuracies should cause a field to be flagged as untrustworthy. + +## Output Format + +Output a JSON object with three fields: +1. "explanation": Briefly describe how you evaluated the response, what issues you found, and why certain fields were marked incorrect. +2. "incorrect_fields": An array of objects containing ONLY the fields that are potentially incorrect. Each object has: + - "field_name": The name of the incorrect field + - "explanation": A brief explanation of why this field is incorrect + If all fields appear accurate and trustworthy, output an empty array ([]). +3. "confidence_score": A score between 0 and 100 indicating your confidence in the overall correctness of the response. If you identify any incorrect, incomplete, or unverifiable fields, the score should be very low. Only assign a high score if every field is fully accurate and trustworthy. + +Think through your evaluation systematically and provide clear reasoning for your decisions.""" + + @classmethod + def create(cls, reasoning_effort: ReasoningEffort, **kwargs) -> ReflectionCompletionTemplate: + if reasoning_effort == ReasoningEffort.NONE: + raise ValueError("Per-field scoring only supports reasoning") + + return cls( + prompt_template=cls._PROMPT, + parse_patterns={}, + score_mapper=score_100_mapping, + use_logprobs=False, + so_overall_score_key_name="confidence_score", + **kwargs, + ) + + @classmethod + def construct_response_format(cls, response_json: str) -> type[BaseModel] | None: + response_fields = json.loads(response_json).keys() + ResponseFields = Literal[tuple(response_fields)] # type: ignore + + class IncorrectField(BaseModel): + field_name: ResponseFields # type: ignore + explanation: str + + class ConfidenceModel(IncorrectFieldEvaluationBase): + explanation: str + incorrect_fields: list[IncorrectField] + confidence_score: int = Field(ge=0, le=100) + + return ConfidenceModel + + +class SelfReflectionSOFieldKnowledgeGapConfig(ReflectionSOIncorrectFieldsTemplate): + _PROMPT: ClassVar[str] = f"""Below is a User Request and the proposed Response from an untrustworthy AI assistant: + + +{QUESTION_PLACEHOLDER} + + + +{ANSWER_PLACEHOLDER} + + + +How can I be sure each field in this response is correct? + +Please provide a detailed, field-by-field evaluation of the proposed Response. + +For each top-level field: +- Describe the evidence supporting the field’s content. +- Assess the reliability of that evidence and explain why it can or cannot be trusted. +- Identify any weaknesses: note if the field appears incorrect, incomplete, misleading, or unverifiable. + +Work step by step: systematically check each field for factual accuracy and completeness, then summarize your findings. + +This field-level evaluation should directly guide your overall confidence rating — the more issues you find, the lower the rating should be. +Now rate your confidence on a scale of 0–10 that the response is correct. + + +## Output Format + +Output a JSON object with three fields: +1. "explanation": Briefly describe how you evaluated the response, what issues you found, and why certain fields were marked incorrect. +2. "incorrect_fields": An array of objects containing ONLY the fields that are potentially incorrect. Each object has: + - "field_name": The name of the incorrect field + - "explanation": A brief explanation of why this field is incorrect + If all fields appear accurate and trustworthy, output an empty array ([]). +3. "rating": A rating between 0 and 10. If you identify any incorrect, incomplete, or unverifiable fields, the rating should be very low. Only assign a high rating if every field is fully accurate and trustworthy.""" + + @classmethod + def create(cls, reasoning_effort: ReasoningEffort, **kwargs) -> ReflectionCompletionTemplate: + if reasoning_effort == ReasoningEffort.NONE: + raise ValueError("Per-field scoring only supports reasoning") + + return cls( + prompt_template=cls._PROMPT, + parse_patterns={}, + score_mapper=score_10_mapping, + use_logprobs=False, + so_overall_score_key_name="rating", + **kwargs, + ) + + @classmethod + def construct_response_format(cls, response_json: str) -> type[BaseModel] | None: + response_fields = json.loads(response_json).keys() + ResponseFields = Literal[tuple(response_fields)] # type: ignore + + class IncorrectField(BaseModel): + field_name: ResponseFields # type: ignore + explanation: str + + class RatingModel(IncorrectFieldEvaluationBase): + explanation: str + incorrect_fields: list[IncorrectField] + rating: int = Field(ge=0, le=10) + + return RatingModel + + SELF_REFLECTION_TEMPLATES_BY_WORKFLOW: dict[WorkflowType, list[type[ReflectionCompletionTemplate]]] = { WorkflowType.QA: [ ReflectionCertaintyTemplate, @@ -761,8 +910,10 @@ def create(cls, reasoning_effort: ReasoningEffort, **kwargs) -> ReflectionComple ReflectionRAGIssuesTemplate, ], WorkflowType.STRUCTURED_OUTPUT_SCORING: [ - ReflectionCertaintyTemplate, - ReflectionKnowledgeGapTemplate, + # ReflectionCertaintyTemplate, + # ReflectionKnowledgeGapTemplate, + SelfReflectionSOFieldAccuracyConfig, + SelfReflectionSOFieldKnowledgeGapConfig, ReflectionArgumentTemplate, ReflectionSOPerScoreCorrectnessTemplate, ReflectionSOPerScoreCertaintyTemplate, diff --git a/tlm/types/__init__.py b/tlm/types/__init__.py index 0f9b21a..f78a5a0 100644 --- a/tlm/types/__init__.py +++ b/tlm/types/__init__.py @@ -12,6 +12,7 @@ AnswerChoiceToken, CompletionUsage, CompletionFailure, + SOReflectionScoreConfigType, ) __all__ = [ @@ -28,4 +29,5 @@ "CompletionUsage", "CompletionFailure", "CompletionParams", + "SOReflectionScoreConfigType", ] diff --git a/tlm/types/base.py b/tlm/types/base.py index cdb5371..6601aa0 100644 --- a/tlm/types/base.py +++ b/tlm/types/base.py @@ -58,7 +58,7 @@ class CompletionFailureType(Enum): class FieldMetadata(BaseModel): score: float - explanation: str + explanation: str | None = None class Eval(BaseModel): @@ -102,3 +102,8 @@ class CompletionFailure(BaseModel): CompletionParams = Dict[str, Any] + + +class SOReflectionScoreConfigType(str, Enum): + PER_FIELD = "per_field" + INCORRECT_FIELDS = "incorrect_fields" diff --git a/tlm/types/completion_template.py b/tlm/types/completion_template.py index bf6b8ef..fe04d8d 100644 --- a/tlm/types/completion_template.py +++ b/tlm/types/completion_template.py @@ -2,7 +2,13 @@ from typing import Any, Callable from litellm.litellm_core_utils.get_supported_openai_params import get_supported_openai_params -from .base import ExtractedResponseField, RegexPattern, AnswerChoiceToken, CompletionParams +from .base import ( + ExtractedResponseField, + RegexPattern, + AnswerChoiceToken, + CompletionParams, + SOReflectionScoreConfigType, +) from tlm.config.models import MODELS_WITH_LOGPROBS from tlm.config.provider import ModelProvider @@ -34,6 +40,14 @@ class CompletionTemplate(BaseModel): default=None, description="Key in each field of the completion response schema that contains the reflection score", ) + so_reflection_score_config_type: SOReflectionScoreConfigType | None = Field( + default=None, + description="Type of score configuration for the structured output reflection score", + ) + so_overall_score_key_name: str | None = Field( + default=None, + description="Key in the completion response schema that contains the overall reflection score (for incorrect fields scoring)", + ) use_logprobs: bool | None = Field( default=None, description="Whether to use logprobs to generate the completion, if available", diff --git a/tlm/utils/completion_utils.py b/tlm/utils/completion_utils.py index 002c7f2..3dd49bd 100644 --- a/tlm/utils/completion_utils.py +++ b/tlm/utils/completion_utils.py @@ -23,11 +23,15 @@ CompletionUsage, ExtractedResponseField, CompletionTemplate, + SOReflectionScoreConfigType, ) from tlm.utils.openai_utils import extract_structured_output_field, extract_message_content from tlm.utils.constrain_outputs_utils import constrain_output from tlm.utils.parse_utils import get_parsed_answer_tokens_confidence -from tlm.utils.scoring.per_field_scoring_utils import extract_per_field_reflection_metadata +from tlm.utils.scoring.per_field_scoring_utils import ( + extract_per_field_reflection_metadata, + extract_incorrect_fields_reflection_metadata, +) from tlm.utils.math_utils import harmonic_mean litellm.suppress_debug_info = True @@ -46,6 +50,7 @@ async def generate_completion( template_kwargs: dict[str, Any] = {}, temperature: float | None = None, response_format_model: type[BaseModel] | None = None, + reference_answer: str | None = None, ) -> Completion | CompletionFailure: litellm_params = _build_litellm_params( template, @@ -55,7 +60,7 @@ async def generate_completion( response_format_model, ) - completion = await _generate_completion(litellm_params, template) + completion = await _generate_completion(litellm_params, template, reference_answer) if isinstance(completion, Completion): log_msg = f"""Generated {template.__class__.__name__} completion for model {litellm_params["model"]} with messages: @@ -115,6 +120,7 @@ def _build_litellm_params( async def _generate_completion( litellm_params: CompletionParams, template: CompletionTemplate | None, + reference_answer: str | None = None, ) -> Completion | CompletionFailure: try: response = await acompletion(**litellm_params) @@ -167,7 +173,7 @@ async def _generate_completion( original_response=response, template=template, ) - _parse_completion(completion) + _parse_completion(completion, reference_answer) return completion print(f"unhandled response type: {type(response)}") @@ -183,7 +189,7 @@ def _get_raw_message_content(logprobs: ChoiceLogprobs) -> str | None: return "".join([message_token.token for message_token in logprobs.content]) -def _parse_completion(completion: Completion) -> None: +def _parse_completion(completion: Completion, reference_answer: str | None = None) -> None: """Update the completion with parsed response fields""" if not completion.template: @@ -239,12 +245,13 @@ def _parse_completion(completion: Completion) -> None: explanation = extract_structured_output_field(message_content, "explanation") completion.add_response_field(ExtractedResponseField.EXPLANATION, explanation) - if completion.template.per_field_score_key: + if completion.template.so_reflection_score_config_type == SOReflectionScoreConfigType.PER_FIELD: if isinstance(completion.original_response, Dict): message_content = extract_message_content(completion.original_response) else: message_content = completion.message + assert completion.template.per_field_score_key is not None assert completion.template.score_mapper is not None per_field_metadata = extract_per_field_reflection_metadata( message_content, completion.template.per_field_score_key, completion.template.score_mapper @@ -253,6 +260,29 @@ def _parse_completion(completion: Completion) -> None: harmonic_mean_score = harmonic_mean([metadata.score for metadata in per_field_metadata.values()]) completion.add_response_field(ExtractedResponseField.MAPPED_SCORE, harmonic_mean_score) + elif completion.template.so_reflection_score_config_type == SOReflectionScoreConfigType.INCORRECT_FIELDS: + if isinstance(completion.original_response, Dict): + message_content = extract_message_content(completion.original_response) + else: + message_content = completion.message + + assert reference_answer is not None + per_field_metadata = extract_incorrect_fields_reflection_metadata( + message_content, + reference_answer, + ) + completion.per_field_metadata = per_field_metadata + score_mapper = completion.template.score_mapper + assert score_mapper is not None + + assert completion.template.so_overall_score_key_name is not None + unmapped_overall_score = json.loads(message_content)[completion.template.so_overall_score_key_name] + + completion.add_response_field( + ExtractedResponseField.MAPPED_SCORE, + score_mapper(unmapped_overall_score), + ) + def _get_trimmed_index(message: str, start_idx: int, end_idx: int) -> int: """Returns an adjusted end index that excludes any trailing punctuation and whitespace. diff --git a/tlm/utils/scoring/per_field_scoring_utils.py b/tlm/utils/scoring/per_field_scoring_utils.py index af6efff..f5318bd 100644 --- a/tlm/utils/scoring/per_field_scoring_utils.py +++ b/tlm/utils/scoring/per_field_scoring_utils.py @@ -2,7 +2,12 @@ import numpy as np from typing import Callable -from tlm.types import FieldMetadata +from tlm.types import FieldMetadata, Completion, SOReflectionScoreConfigType +from tlm.utils.math_utils import make_score_asymptotic +from tlm.config.presets import ( + STRUCTURED_OUTPUT_CORRECT_FIELD_SCORE, + STRUCTURED_OUTPUT_INCORRECT_FIELD_SCORE, +) def extract_per_field_reflection_metadata( @@ -27,7 +32,35 @@ def extract_per_field_reflection_metadata( return per_field_metadata -def compute_field_metadata(completion_metadata: list[dict[str, FieldMetadata]]) -> dict[str, FieldMetadata]: +def extract_incorrect_fields_reflection_metadata( + answer: str, + reference_answer: str, +) -> dict[str, FieldMetadata]: + answer_json = json.loads(answer) + + incorrect_fields_list = answer_json["incorrect_fields"] + incorrect_field_names_and_explanations = {item["field_name"]: item["explanation"] for item in incorrect_fields_list} + + field_names = json.loads(reference_answer).keys() + per_field_metadata = {} + + # construct scores and mapped scores for each field for downstream use of per-field score details + for field in field_names: + if field in incorrect_field_names_and_explanations.keys(): + per_field_metadata[field] = FieldMetadata( + score=STRUCTURED_OUTPUT_INCORRECT_FIELD_SCORE, + explanation=incorrect_field_names_and_explanations[field], + ) + else: + per_field_metadata[field] = FieldMetadata(score=STRUCTURED_OUTPUT_CORRECT_FIELD_SCORE) + + return per_field_metadata + + +def compute_field_metadata( + completion_metadata: list[dict[str, FieldMetadata]], + scoring_data: list[Completion] | None = None, +) -> dict[str, FieldMetadata]: score_data: dict[str, dict[str, list]] = {} for metadata_per_field in completion_metadata: @@ -43,10 +76,54 @@ def compute_field_metadata(completion_metadata: list[dict[str, FieldMetadata]]) composite_metadata = {} for field_name, data in score_data.items(): - min_score_idx = np.argmin(data["scores"]) + all_scores = data["scores"] + scores_with_explanation = [] + explanations = [] + + for score, explanation in zip(data["scores"], data["explanations"]): + if explanation: + scores_with_explanation.append(score) + explanations.append(explanation) + + # get explanation from the SR completion with the lowest score + if scores_with_explanation: + min_score_idx = np.argmin(scores_with_explanation) + explanation = explanations[min_score_idx] + else: + explanation = None + + # linearly rescale the scores from [min_possible_score, max_possible_score] to [0.001, 0.999] + avg_score = float(np.mean(all_scores)) + + if scoring_data is not None and len(scoring_data) > 0: + num_total_templates = len(scoring_data) + + num_incorrect_fields_templates = sum( + completion.template is not None + and completion.template.so_reflection_score_config_type == SOReflectionScoreConfigType.INCORRECT_FIELDS + for completion in scoring_data + ) + + num_non_incorrect_fields_templates = num_total_templates - num_incorrect_fields_templates + + max_possible_score = ( + STRUCTURED_OUTPUT_CORRECT_FIELD_SCORE * num_incorrect_fields_templates + + 1.0 * num_non_incorrect_fields_templates + ) / num_total_templates + min_possible_score = 1 - max_possible_score + + score_range = max_possible_score - min_possible_score + if score_range > 0: + normalized_score = max(avg_score - min_possible_score, 0.0) / score_range + scaled_score = make_score_asymptotic(normalized_score) + else: + scaled_score = avg_score + else: + scaled_score = avg_score + composite_metadata[field_name] = { - "score": np.mean(data["scores"]), - "explanation": data["explanations"][min_score_idx], + "score": scaled_score, + "explanation": explanation, } return composite_metadata # type: ignore[return-value]