Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
731 changes: 76 additions & 655 deletions docs/tutorials/tlm_structured_outputs/index.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions tlm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import sys
from openai.types.chat import ChatCompletion
from openai.lib._parsing._completions import type_to_response_format_param

from tlm.config.base import BaseConfig
from tlm.config.schema import Config
Expand Down Expand Up @@ -162,6 +163,10 @@ async def _async_inference(
)
model = openai_kwargs.get("model")
config = BaseConfig.from_input(self.config, workflow_type, model)

if openai_kwargs.get("response_format"):
openai_kwargs["response_format"] = type_to_response_format_param(openai_kwargs["response_format"])

return await tlm_inference(
completion_params=openai_kwargs,
response=response,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ async def execute(self) -> None:
template_kwargs=template_kwargs,
temperature=0.0,
response_format_model=template.construct_response_format(answer),
reference_answer=answer,
)
)
for template in self.completion_templates
Expand Down
6 changes: 5 additions & 1 deletion tlm/components/scores/self_reflection_score_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ async def execute(self) -> None:
for completion in self_reflection_completions_flat
if (metadata := completion.per_field_metadata) is not None
]
composite_reflection_metadata = compute_field_metadata(reflection_metadata)

scoring_data = [
completion for completion in self_reflection_completions_flat if completion.per_field_metadata is not None
]
composite_reflection_metadata = compute_field_metadata(reflection_metadata, scoring_data=scoring_data)

self.execution_context.add("self_reflection_metadata_per_field", composite_reflection_metadata)
4 changes: 4 additions & 0 deletions tlm/config/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,7 @@ def from_inference_params(
WorkflowType.DEFAULT: {NUM_CONSISTENCY_COMPLETIONS: 0, NUM_SELF_REFLECTION_COMPLETIONS: 0},
},
}

# These values were benchmarked on 10/2025; there was no significant difference when using values in the 0.8-0.95 range.
STRUCTURED_OUTPUT_CORRECT_FIELD_SCORE: float = 0.9
STRUCTURED_OUTPUT_INCORRECT_FIELD_SCORE: float = 0.1
4 changes: 4 additions & 0 deletions tlm/templates/per_field_scoring_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ class PerFieldCertaintyEvaluation(PerFieldScoreEvaluationBase):

class PerFieldCorrectnessEvaluation(PerFieldScoreEvaluationBase):
confidence: Literal["Certain", "Mostly Certain", "Somewhat Certain", "Uncertain", "Likely Incorrect"]


# Base class for incorrect-field evaluation response-format models.
class IncorrectFieldEvaluationBase(BaseModel):
    """Marker base class for incorrect-field evaluation response models."""
163 changes: 157 additions & 6 deletions tlm/templates/reflection_completion_templates.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from typing import Callable, ClassVar

from pydantic import BaseModel
from typing import Callable, ClassVar, Literal
import json
from pydantic import BaseModel, Field

from tlm.types.base import SOReflectionScoreConfigType
from tlm.config.presets import ReasoningEffort, WorkflowType
from tlm.templates.keywords import (
ANSWER_PLACEHOLDER,
Expand Down Expand Up @@ -33,7 +34,11 @@
)
from tlm.types import AnswerChoiceToken, ExtractedResponseField, RegexPattern, CompletionTemplate
from tlm.utils.response_format_utils import construct_per_field_response_format_model
from tlm.templates.per_field_scoring_models import PerFieldCorrectnessEvaluation, PerFieldCertaintyEvaluation
from tlm.templates.per_field_scoring_models import (
PerFieldCorrectnessEvaluation,
PerFieldCertaintyEvaluation,
IncorrectFieldEvaluationBase,
)


class ReflectionCompletionTemplate(CompletionTemplate, ABC):
Expand Down Expand Up @@ -77,6 +82,7 @@ class ReflectionSOPerFieldScoreTemplate(ReflectionCompletionTemplate):
per_field_score_response_format: ClassVar[type[BaseModel]]

def __init_subclass__(cls, **kwargs):
cls.so_reflection_score_config_type = SOReflectionScoreConfigType.PER_FIELD
super().__init_subclass__(**kwargs)
if not hasattr(cls, "per_field_score_response_format"):
raise TypeError(f"{cls.__name__} must define 'per_field_score_response_format' class variable")
Expand All @@ -86,6 +92,13 @@ def construct_response_format(cls, response_json: str) -> type[BaseModel] | None
return construct_per_field_response_format_model(response_json, cls.per_field_score_response_format)


class ReflectionSOIncorrectFieldsTemplate(ReflectionCompletionTemplate):
    """Base for structured-output reflection templates that score a response by
    listing the fields judged incorrect, rather than scoring every field.
    """

    def __init_subclass__(cls, **kwargs):
        # Tag subclasses with the incorrect-fields scoring mode; set before
        # delegating to the parent hook so any parent-side logic sees it.
        cls.so_reflection_score_config_type = SOReflectionScoreConfigType.INCORRECT_FIELDS
        cls.per_field_score_key = "score"  # use a default key name for the per-field score
        super().__init_subclass__(**kwargs)


class ReflectionCertaintyTemplate(ReflectionCompletionTemplate):
_SHARED_PROMPT: ClassVar[str] = f"""You are an evaluator that verifies whether AI responses are factually correct.
Below is a User Question and the Answer proposed by an untrustworthy AI Assistant.
Expand Down Expand Up @@ -742,6 +755,142 @@ def create(cls, reasoning_effort: ReasoningEffort, **kwargs) -> ReflectionComple
)


class SelfReflectionSOFieldAccuracyConfig(ReflectionSOIncorrectFieldsTemplate):
    """Incorrect-fields reflection template: asks the evaluator to list the
    top-level response fields it believes are wrong, plus an overall 0-100
    confidence score reported under ``confidence_score``.
    """

    _PROMPT: ClassVar[str] = f"""You are an evaluator that identifies factual inaccuracies in AI responses.
Below is a User Request and the Response provided by an untrustworthy AI Assistant.
Your task is to find and list any fields in the Response that are factually incorrect, incomplete, or unreliable.

<request>
{QUESTION_PLACEHOLDER}
</request>

<response>
{ANSWER_PLACEHOLDER}
</response>


## Instructions

Carefully evaluate the factual accuracy of each top-level field in the JSON Response.
You must identify any fields that are likely incorrect, incomplete, unverifiable, or misleading.

Be extremely strict in your judgment:
- Treat any missing or partially correct information as potentially incorrect.
- Penalize any unsupported assumptions, factual errors, or extraneous content.
- Incomplete fields should be marked as incorrect. Do not excuse missing data, if a field is null/empty when it should contain information, it is incorrect.
- Even small inaccuracies should cause a field to be flagged as untrustworthy.

## Output Format

Output a JSON object with three fields:
1. "explanation": Briefly describe how you evaluated the response, what issues you found, and why certain fields were marked incorrect.
2. "incorrect_fields": An array of objects containing ONLY the fields that are potentially incorrect. Each object has:
- "field_name": The name of the incorrect field
- "explanation": A brief explanation of why this field is incorrect
If all fields appear accurate and trustworthy, output an empty array ([]).
3. "confidence_score": A score between 0 and 100 indicating your confidence in the overall correctness of the response. If you identify any incorrect, incomplete, or unverifiable fields, the score should be very low. Only assign a high score if every field is fully accurate and trustworthy.

Think through your evaluation systematically and provide clear reasoning for your decisions."""

    @classmethod
    def create(cls, reasoning_effort: ReasoningEffort, **kwargs) -> ReflectionCompletionTemplate:
        """Build the template; requires reasoning to be enabled.

        Raises:
            ValueError: if ``reasoning_effort`` is ``NONE`` — this scoring mode
                depends on the model reasoning through each field.
        """
        if reasoning_effort == ReasoningEffort.NONE:
            # Message previously said "Per-field scoring", copy-pasted from the
            # per-field template; this class uses incorrect-fields scoring.
            raise ValueError("Incorrect-fields scoring only supports reasoning")

        return cls(
            prompt_template=cls._PROMPT,
            parse_patterns={},
            score_mapper=score_100_mapping,
            use_logprobs=False,
            so_overall_score_key_name="confidence_score",
            **kwargs,
        )

    @classmethod
    def construct_response_format(cls, response_json: str) -> type[BaseModel] | None:
        """Build a pydantic response-format model keyed to ``response_json``.

        Args:
            response_json: JSON object string of the response being evaluated.

        Returns:
            A model with ``explanation``, ``incorrect_fields`` (each entry's
            ``field_name`` restricted via a dynamic ``Literal`` to the
            response's own top-level keys), and an integer
            ``confidence_score`` constrained to [0, 100]; ``None`` when the
            response has no top-level fields (``Literal`` cannot be built with
            zero options).
        """
        field_names = tuple(json.loads(response_json).keys())
        if not field_names:
            # Nothing to evaluate and Literal[()] is invalid — signal "no
            # structured format" to the caller, as the return type allows.
            return None
        ResponseFields = Literal[field_names]  # type: ignore

        class IncorrectField(BaseModel):
            field_name: ResponseFields  # type: ignore
            explanation: str

        class ConfidenceModel(IncorrectFieldEvaluationBase):
            explanation: str
            incorrect_fields: list[IncorrectField]
            confidence_score: int = Field(ge=0, le=100)

        return ConfidenceModel


class SelfReflectionSOFieldKnowledgeGapConfig(ReflectionSOIncorrectFieldsTemplate):
    """Incorrect-fields reflection template with a knowledge-gap framing:
    field-by-field evidence review, then a 0-10 overall confidence rating
    reported under ``rating``.
    """

    _PROMPT: ClassVar[str] = f"""Below is a User Request and the proposed Response from an untrustworthy AI assistant:

<request>
{QUESTION_PLACEHOLDER}
</request>

<response>
{ANSWER_PLACEHOLDER}
</response>


How can I be sure each field in this response is correct?

Please provide a detailed, field-by-field evaluation of the proposed Response.

For each top-level field:
- Describe the evidence supporting the field’s content.
- Assess the reliability of that evidence and explain why it can or cannot be trusted.
- Identify any weaknesses: note if the field appears incorrect, incomplete, misleading, or unverifiable.

Work step by step: systematically check each field for factual accuracy and completeness, then summarize your findings.

This field-level evaluation should directly guide your overall confidence rating — the more issues you find, the lower the rating should be.
Now rate your confidence on a scale of 0–10 that the response is correct.


## Output Format

Output a JSON object with three fields:
1. "explanation": Briefly describe how you evaluated the response, what issues you found, and why certain fields were marked incorrect.
2. "incorrect_fields": An array of objects containing ONLY the fields that are potentially incorrect. Each object has:
- "field_name": The name of the incorrect field
- "explanation": A brief explanation of why this field is incorrect
If all fields appear accurate and trustworthy, output an empty array ([]).
3. "rating": A rating between 0 and 10. If you identify any incorrect, incomplete, or unverifiable fields, the rating should be very low. Only assign a high rating if every field is fully accurate and trustworthy."""

    @classmethod
    def create(cls, reasoning_effort: ReasoningEffort, **kwargs) -> ReflectionCompletionTemplate:
        """Build the template; requires reasoning to be enabled.

        Raises:
            ValueError: if ``reasoning_effort`` is ``NONE`` — this scoring mode
                depends on the model reasoning through each field.
        """
        if reasoning_effort == ReasoningEffort.NONE:
            # Message previously said "Per-field scoring", copy-pasted from the
            # per-field template; this class uses incorrect-fields scoring.
            raise ValueError("Incorrect-fields scoring only supports reasoning")

        return cls(
            prompt_template=cls._PROMPT,
            parse_patterns={},
            score_mapper=score_10_mapping,
            use_logprobs=False,
            so_overall_score_key_name="rating",
            **kwargs,
        )

    @classmethod
    def construct_response_format(cls, response_json: str) -> type[BaseModel] | None:
        """Build a pydantic response-format model keyed to ``response_json``.

        Args:
            response_json: JSON object string of the response being evaluated.

        Returns:
            A model with ``explanation``, ``incorrect_fields`` (each entry's
            ``field_name`` restricted via a dynamic ``Literal`` to the
            response's own top-level keys), and an integer ``rating``
            constrained to [0, 10]; ``None`` when the response has no
            top-level fields (``Literal`` cannot be built with zero options).
        """
        field_names = tuple(json.loads(response_json).keys())
        if not field_names:
            # Nothing to evaluate and Literal[()] is invalid — signal "no
            # structured format" to the caller, as the return type allows.
            return None
        ResponseFields = Literal[field_names]  # type: ignore

        class IncorrectField(BaseModel):
            field_name: ResponseFields  # type: ignore
            explanation: str

        class RatingModel(IncorrectFieldEvaluationBase):
            explanation: str
            incorrect_fields: list[IncorrectField]
            rating: int = Field(ge=0, le=10)

        return RatingModel


SELF_REFLECTION_TEMPLATES_BY_WORKFLOW: dict[WorkflowType, list[type[ReflectionCompletionTemplate]]] = {
WorkflowType.QA: [
ReflectionCertaintyTemplate,
Expand All @@ -761,8 +910,10 @@ def create(cls, reasoning_effort: ReasoningEffort, **kwargs) -> ReflectionComple
ReflectionRAGIssuesTemplate,
],
WorkflowType.STRUCTURED_OUTPUT_SCORING: [
ReflectionCertaintyTemplate,
ReflectionKnowledgeGapTemplate,
# ReflectionCertaintyTemplate,
# ReflectionKnowledgeGapTemplate,
SelfReflectionSOFieldAccuracyConfig,
SelfReflectionSOFieldKnowledgeGapConfig,
ReflectionArgumentTemplate,
ReflectionSOPerScoreCorrectnessTemplate,
ReflectionSOPerScoreCertaintyTemplate,
Expand Down
2 changes: 2 additions & 0 deletions tlm/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
AnswerChoiceToken,
CompletionUsage,
CompletionFailure,
SOReflectionScoreConfigType,
)

__all__ = [
Expand All @@ -28,4 +29,5 @@
"CompletionUsage",
"CompletionFailure",
"CompletionParams",
"SOReflectionScoreConfigType",
]
7 changes: 6 additions & 1 deletion tlm/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class CompletionFailureType(Enum):

class FieldMetadata(BaseModel):
score: float
explanation: str
explanation: str | None = None


class Eval(BaseModel):
Expand Down Expand Up @@ -102,3 +102,8 @@ class CompletionFailure(BaseModel):


CompletionParams = Dict[str, Any]


class SOReflectionScoreConfigType(str, Enum):
    """How a structured-output reflection template reports its scores."""

    # One score for every top-level field of the response.
    PER_FIELD = "per_field"
    # Only the fields judged incorrect are listed, plus one overall score.
    INCORRECT_FIELDS = "incorrect_fields"
16 changes: 15 additions & 1 deletion tlm/types/completion_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
from typing import Any, Callable
from litellm.litellm_core_utils.get_supported_openai_params import get_supported_openai_params

from .base import ExtractedResponseField, RegexPattern, AnswerChoiceToken, CompletionParams
from .base import (
ExtractedResponseField,
RegexPattern,
AnswerChoiceToken,
CompletionParams,
SOReflectionScoreConfigType,
)

from tlm.config.models import MODELS_WITH_LOGPROBS
from tlm.config.provider import ModelProvider
Expand Down Expand Up @@ -34,6 +40,14 @@ class CompletionTemplate(BaseModel):
default=None,
description="Key in each field of the completion response schema that contains the reflection score",
)
so_reflection_score_config_type: SOReflectionScoreConfigType | None = Field(
default=None,
description="Type of score configuration for the structured output reflection score",
)
so_overall_score_key_name: str | None = Field(
default=None,
description="Key in the completion response schema that contains the overall reflection score (for incorrect fields scoring)",
)
use_logprobs: bool | None = Field(
default=None,
description="Whether to use logprobs to generate the completion, if available",
Expand Down
Loading