From dfae31279c47640c8c374bd67050955e671db923 Mon Sep 17 00:00:00 2001
From: alexjyc
Date: Mon, 4 Aug 2025 15:13:37 -0400
Subject: [PATCH] prompt optimizer module update

---
 README.md | 25 +-
 src/example.py | 85 +++++++
 src/extract_ocr.py | 318 +++++++++++++++++++++++++
 src/fv_extract.py | 574 +++++++++++++++------------------------------
 src/model.py | 32 ++-
 src/ocr_load.py | 11 +-
 src/optimizer.py | 300 +++++++++++++++++++++++
 src/w2_extract.py | 326 ++++++++++++-------------
 8 files changed, 1102 insertions(+), 569 deletions(-)
 create mode 100644 src/example.py
 create mode 100644 src/extract_ocr.py
 create mode 100644 src/optimizer.py

diff --git a/README.md b/README.md
index 7372a6c..54f371f 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,8 @@ Launching Streamlit: `streamlit run src/w2_app.py`

 Import zipped directory of the PDFs

-### Python
+## Document Extraction with OCR
+#### For working with the OCR pipeline and optimizing prompts:
 ```bash
 python src/w2_extract.py --fieldpath --filepath [options]
 ```
@@ -38,9 +39,31 @@ python src/w2_extract.py --fieldpath --filepath [options]
 - `--spatial_ocr`: Enable spatial OCR with coordinate information
 - `--prompt_opt`: Enable prompt optimization with evaluation
 - `--label_file FILE`: Path to label file for evaluation (required when using --prompt_opt)
+- `--train_file FILE`: Path to the training file for prompt optimization (required when using --prompt_opt)
+- `--train_label FILE`: Path to the training label file for prompt optimization (required when using --prompt_opt)

 Should be able to handle PDF, directories, and zipped directory paths

+## OCR-based Extraction and Prompt Optimization
+#### For working with pre-extracted OCR data and optimizing prompts:
+```bash
+python src/extract_ocr.py --fieldpath --test_file [options]
+```
+
+##### Required Arguments:
+- `--fieldpath`: Path to the field definitions file (.json, .yaml, or .yml)
+- `--test_file`: Path to the test OCR data file (.json)
+
+##### Optional Arguments:
+- `--test_label`: Path to the test labels file (.csv) for evaluation
+- `--training_file`: Path to the training OCR data file (.json) for prompt optimization
+- `--training_label`: Path to the training label file (.csv) for prompt optimization
+- `--file_out`: Path to save the output CSV or Excel file
+- `--model_type {gpt-4o-mini,gpt-4.1,gpt-4.1-mini,gpt-o3}`: Model type for extraction (default: gpt-4o-mini); the Qwen and Mistral open-source models listed in `src/extract_ocr.py` are also accepted
+- `--prompt_opt`: Enable prompt optimization (requires --training_file and --training_label)
+- `--opt_iterations`: Number of optimization iterations (default: 3)
+- `--max_workers`: Number of worker threads for parallel processing (default: 4)
+
 ## Evaluation

 ### Standalone Evaluation
diff --git a/src/example.py b/src/example.py
new file mode 100644
index 0000000..f55e8e4
--- /dev/null
+++ b/src/example.py
@@ -0,0 +1,85 @@
+class Example:
+    def __init__(self, base=None, fields=None, context=None, mode=None, output=None):
+        # Internal storage
+        self._store = {}
+        self._demos = []
+        self._input_keys = {'fields', 'context', 'mode'}
+
+        # Initialize from a base Example if provided
+        if base and isinstance(base, type(self)):
+            self._store = base._store.copy()
+            self._input_keys = base._input_keys.copy()
+
+        # Initialize from a dict if provided
+        elif base and isinstance(base, dict):
+            self._store = base.copy()
+
+        if fields is not None:
+            self._store['fields'] = fields
+        if context is not None:
+            self._store['context'] = context
+        if mode is not None:
+            self._store['mode'] = mode
+        if output is not None:
+            self._store['output'] = 
output + + def __getattr__(self, key): + if key.startswith("__") and key.endswith("__"): + raise AttributeError + if key in self._store: + return self._store[key] + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'") + + def __setattr__(self, key, value): + if key.startswith("_") or key in dir(self.__class__): + super().__setattr__(key, value) + else: + self._store[key] = value + + def __getitem__(self, key): + return self._store[key] + + def __setitem__(self, key, value): + self._store[key] = value + + def __delitem__(self, key): + del self._store[key] + + def __contains__(self, key): + return key in self._store + + def keys(self): + return self._store.keys() + + def values(self): + return self._store.values() + + def get(self, key, default=None): + return self._store.get(key, default) + + def inputs(self): + if self._input_keys is None: + raise ValueError("Inputs have not been set for this example.") + + d = {key: self._store[key] for key in self._store if key in self._input_keys} + new_instance = type(self)(base=d) + new_instance._input_keys = self._input_keys + return new_instance + + def labels(self): + input_keys = self.inputs().keys() + d = {key: self._store[key] for key in self._store if key not in input_keys} + return type(self)(d) + + def copy(self, **kwargs): + return type(self)(base=self, **kwargs) + + def without(self, *keys): + copied = self.copy() + for key in keys: + if key in copied._store: + del copied._store[key] + return copied + + def to_dict(self): + return self._store.copy() \ No newline at end of file diff --git a/src/extract_ocr.py b/src/extract_ocr.py new file mode 100644 index 0000000..0176e12 --- /dev/null +++ b/src/extract_ocr.py @@ -0,0 +1,318 @@ +from concurrent.futures import ThreadPoolExecutor, as_completed +import os +import json +import tempfile +import argparse +from typing import List +import pandas as pd +import yaml + +from example import Example +from fv_extract import FVExtraction +from evaluation import compare_datasets +from optimizer import PromptOptimizer + + +def load_fields_from_path(filepath: str) -> dict[str, dict]: + """Load field definitions from a file path (JSON or YAML).""" + _, ext = os.path.splitext(filepath) + + with open(filepath, 'r') as f: + if ext.lower() in ['.yaml', '.yml']: + custom_fields = yaml.safe_load(f) + elif ext.lower() == '.json': + custom_fields = json.load(f) + else: + raise ValueError(f"Unsupported file format: {ext}. Use .json, .yaml, or .yml") + + return custom_fields + + +def load_fields(file) -> dict[str, dict]: + """Load field definitions from a file-like object (for compatibility).""" + _, ext = os.path.splitext(file.name) + + with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp_file: + tmp_file.write(file.getvalue()) + tmp_path = tmp_file.name + + try: + with open(tmp_path, 'r') as f: + if ext.lower() in ['.yaml', '.yml']: + custom_fields = yaml.safe_load(f) + elif ext.lower() == '.json': + custom_fields = json.load(f) + else: + raise ValueError(f"Unsupported file format: {ext}. 
Use .json, .yaml, or .yml") + + return custom_fields + finally: + if os.path.exists(tmp_path): + os.unlink(tmp_path) + + +def create_trainset(ocr_cache, labels_df, fields, ocr_mode: str="ocr") -> List[Example]: + """Create training dataset from OCR cache and labels.""" + dataset = [] + + for idx, row in labels_df.iterrows(): + filename = row['file'] + + # Check if we have OCR data for this file + if filename not in ocr_cache: + print(f"Warning: No OCR data found for {filename}") + continue + + ocr_pages = ocr_cache[filename] + + # Create expected output dictionary + expected_fields = {} + for field in fields.keys(): + value = row[field] + expected_fields[field] = value + + example = Example( + fields=str(fields), + context=str(ocr_pages), + mode=ocr_mode, + output=expected_fields + ) + + dataset.append(example) + print(f"Added training example for {filename}") + + return dataset + + +def process_single_ocr_data(filename_and_data): + """Process a single OCR data entry with the extractor.""" + filename, ocr_data, extractor = filename_and_data + try: + # Handle different OCR data structures + if isinstance(ocr_data, list): + if len(ocr_data) == 1: + ocr_content = ocr_data[0].get('content', str(ocr_data[0])) + else: + ocr_content = '\n'.join([ + item.get('content', str(item)) if isinstance(item, dict) else str(item) + for item in ocr_data + ]) + elif isinstance(ocr_data, dict): + ocr_content = ocr_data.get('content', str(ocr_data)) + else: + ocr_content = str(ocr_data) + + # Extract using the extractor + result = extractor.fv_extract(ocr=ocr_content, include_rationale=False) + + return filename, [result], None + + except Exception as e: + return filename, None, str(e) + + +def optimize_prompt_from_external_data( + fields: dict[str, dict], + test_file: str, + test_label: str = None, + training_file: str = None, + training_label: str = None, + file_out: str = None, + model_type: str="gpt-4o-mini", + prompt_opt: bool=False, + opt_iterations: int = 3, + max_workers: int=4 + ): + """Optimize prompts using external OCR data and evaluate on test data.""" + base_extractor = FVExtraction(fields=fields, model_type=model_type) + extractor_to_use = base_extractor + + if prompt_opt and training_file and training_label: + print("Loading training OCR data...") + try: + with open(training_file, 'r') as f: + external_ocr_data = json.load(f) + except Exception as e: + raise ValueError(f"Failed to load external OCR data from {training_file}: {e}") + + # Load training labels + print("Loading training labels...") + try: + labels_df = pd.read_csv(training_label, keep_default_na=False, na_filter=False) + except Exception as e: + raise ValueError(f"Failed to load training labels from {training_label}: {e}") + + print("Creating training dataset...") + dataset = create_trainset( + ocr_cache=external_ocr_data, + labels_df=labels_df, + fields=fields + ) + + print(f"Optimizing prompts with {len(dataset)} training examples...") + optimizer = PromptOptimizer(fields=fields, model_type=model_type) + optimized_extractor = optimizer.optimize_prompt( + extractor=base_extractor, + train_data=dataset, + iterations=opt_iterations + ) + extractor_to_use = optimized_extractor + print("Prompt optimization completed!") + + print("Loading test OCR data...") + try: + with open(test_file, 'r') as f: + test_ocr_data = json.load(f) + except Exception as e: + raise ValueError(f"Failed to load test OCR data from {test_file}: {e}") + + print(f"Processing {len(test_ocr_data)} test files...") + results = {} + failed_files = [] + + with 
ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit all tasks + future_to_filename = { + executor.submit(process_single_ocr_data, (filename, ocr_data, extractor_to_use)): filename + for filename, ocr_data in test_ocr_data.items() + } + + # Collect results as they complete + for future in as_completed(future_to_filename): + filename = future_to_filename[future] + try: + filename, result, error = future.result() + if error is None and result is not None: + results[filename] = result + print(f"✓ Successfully processed {filename}") + else: + failed_files.append((filename, error)) + print(f"✗ Failed to process {filename}: {error}") + except Exception as e: + failed_files.append((filename, e)) + print(f"✗ Exception processing {filename}: {e}") + + if failed_files: + print(f"\nFailed to process {len(failed_files)} files:") + for filename, error in failed_files: + print(f" - {filename}: {error}") + + print(f"Successfully processed {len(results)} files") + + # Convert results to DataFrame + entries = [] + for filename, file_results in results.items(): + if file_results: + for pnum, page_result in enumerate(file_results): + temp = {'file': filename, 'page': pnum + 1} + temp.update(page_result.get('result', {})) + entries.append(temp) + + df = pd.DataFrame(entries) + print("Test results:") + print(df) + + # Save results if output file specified + if file_out is not None: + if file_out.endswith('.csv'): + df.to_csv(file_out, index=False) + elif file_out.endswith('.xlsx'): + with pd.ExcelWriter(file_out, engine='xlsxwriter') as writer: + df.to_excel(writer, index=False, sheet_name='Sheet1') + worksheet = writer.sheets['Sheet1'] + + # Adjust column widths + for col_idx, col in enumerate(df.columns, start=1): + max_length = max(df[col].astype(str).map(len).max(), len(col)) + 2 + worksheet.set_column(col_idx - 1, col_idx - 1, max_length) + print(f"Test results saved to: {file_out}") + + # Evaluate against test labels if provided + if test_label: + print("Evaluating test results...") + try: + total_comparisons, correct_comparisons, mismatches = compare_datasets( + labels_files=[test_label], + preds_file=file_out if file_out else None, + preds_df=df if file_out is None else None + ) + accuracy = correct_comparisons / total_comparisons if total_comparisons > 0 else 0 + print(f"Test accuracy: {accuracy:.2%}") + print(f"Total comparisons: {total_comparisons}") + print(f"Correct: {correct_comparisons}") + print(f"Mismatches: {len(mismatches)}") + except Exception as e: + print(f"Warning: Could not evaluate results: {e}") + + return df + + +def main(): + """Main function for command-line interface.""" + parser = argparse.ArgumentParser(description="Extract fields from OCR data using prompt optimization.") + + # Required arguments + parser.add_argument("--fieldpath", type=str, required=True, + help="Path to the field definitions file (.json, .yaml, or .yml)") + parser.add_argument("--test_file", type=str, required=True, + help="Path to the test OCR data file (.json)") + + # Optional arguments + parser.add_argument("--test_label", type=str, default=None, + help="Path to the test labels file (.csv) for evaluation") + parser.add_argument("--training_file", type=str, default=None, + help="Path to the training OCR data file (.json) for prompt optimization") + parser.add_argument("--training_label", type=str, default=None, + help="Path to the training labels file (.csv) for prompt optimization") + parser.add_argument("--file_out", type=str, default=None, + help="Path to save the output CSV or Excel 
file") + parser.add_argument("--model_type", default="gpt-4o-mini", + choices=["gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini", "gpt-o3", + "Qwen/Qwen2.5-7B-Instruct-Turbo", "Qwen/Qwen3-8B", + "mistralai/Mistral-7B-Instruct-v0.2"], + help="Model type used for field extraction") + parser.add_argument("--prompt_opt", action="store_true", + help="Enable prompt optimization (requires training_file and training_label)") + parser.add_argument("--opt_iterations", default=3, type=int, + help="Number of optimization iterations") + parser.add_argument("--max_workers", default=4, type=int, + help="Number of worker threads for parallel processing") + + args = parser.parse_args() + + # Validate arguments + if args.prompt_opt and (not args.training_file or not args.training_label): + parser.error("--prompt_opt requires both --training_file and --training_label") + + try: + # Load field definitions + print(f"Loading field definitions from {args.fieldpath}") + custom_fields = load_fields_from_path(args.fieldpath) + print(f"Loaded {len(custom_fields)} field definitions") + + # Run extraction with optimization + result_df = optimize_prompt_from_external_data( + fields=custom_fields, + test_file=args.test_file, + test_label=args.test_label, + training_file=args.training_file, + training_label=args.training_label, + file_out=args.file_out, + model_type=args.model_type, + prompt_opt=args.prompt_opt, + opt_iterations=args.opt_iterations, + max_workers=args.max_workers + ) + + print("Processing completed successfully!") + print(f"Final results shape: {result_df.shape}") + + except Exception as e: + import traceback + traceback.print_exc() + print(f"Error during extraction: {str(e)}") + exit(1) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/fv_extract.py b/src/fv_extract.py index cafab63..73d39d8 100644 --- a/src/fv_extract.py +++ b/src/fv_extract.py @@ -1,23 +1,15 @@ -from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor, as_completed from dotenv import load_dotenv +from transformers import AutoTokenizer import tiktoken + +from field_validate import validate load_dotenv() import json -from langchain_openai import ChatOpenAI from pydantic import BaseModel, create_model -from typing import Optional, TypedDict, Union, Dict, List, Type - -class FieldErrorAnalysis(TypedDict): - expected: Union[str, float] - actual: Union[str, float] - reason: str - filename: str +from typing import Optional, Dict, List, Type -class ErrorAnalysis(TypedDict): - """Model for structured error analysis output""" - response: Dict[str, List[FieldErrorAnalysis]] +from model import get_model class TypeMapping: STRING_TO_TYPE: Dict[str, Type] = { @@ -32,103 +24,151 @@ def string_to_type(cls, type_string: str) -> Type: """Convert a string representation to its corresponding Python type.""" return cls.STRING_TO_TYPE.get(type_string) -class OptimizedExtraction: +class FVExtraction: def __init__(self, fields: dict[str, dict], model_type: str="gpt-4o-mini", max_attempts: int = 3): + self.llm = get_model(model_type) self.fields = fields self.model_type = model_type self.max_attempts = max_attempts self.base_template_parts = self._initialize_base_template() - - self.iteration_mismatches = [] - self.iteration_error_analyses = [] - self.current_iteration = 0 - self.document_error_counts = defaultdict(int) - self.field_error_counts = defaultdict(int) + self.meta_prompt = None + self.positive_exemplars = [] + def _initialize_base_template(self) -> Dict[str, str]: """Initialize the base 
prompt template parts.""" return { + 'meta_prompt_instructions': """ + [Meta-prompt Instructions]: + {meta_instructions} + ====End of Meta-prompt Instructions==== + """, 'base': """ -Extract value from a key:value list dictionary, OCR result, and/or coordinate layout given a description and type for each field in a `fields` dictionary. + Extract value from a key:value dictionary list, OCR result, and/or coordinate layout given a description and type for each field in a fields dictionary. -Field names may not exactly match keys in the input data. Match values using semantic similarity, type expectations, and (when available) spatial relationships or OCR confidence. + **CRITICAL RULE: All extracted values MUST come exclusively from the INPUT DATA section. Do not use any information from examples, previous extractions, or external knowledge.** -Return key-value results in a JSON dictionary string, where: -- If a field is **not found**, return `null`. -- If a field is found but has **no value**, return an **empty string**. -- Values must respect their declared types (e.g., float, string). -- Output must be compatible with `json.loads()` (i.e., double-quoted strings). + Return key-value results in a JSON dictionary string, where: + - If a field is **not found**, return `null`. + - If a field is found but has **no value**, return an **empty string**. + - Values must respect their declared types (e.g., float, string). + - Output must be compatible with `json.loads()` (i.e., double-quoted strings). + - Provide the final answer with rationale. -{mode_instructions} -""", + {mode_instructions} + """, 'kvs': """ -Use only the key-value dictionary provided. Match field descriptions to dictionary keys using: -- Semantic similarity -- Field type expectations (e.g., float, string) -Normalize keys for comparison (e.g., lowercase, remove punctuation). -""", + Matching Strategy: + 1. Normalize for comparison: Convert keys to lowercase, remove extra spaces/punctuation + 2. Priority matching order: + - Exact field identifier match (e.g., field 'a' matches key starting with 'a ') + - Semantic similarity (e.g., "SSN" matches "social security number") + - Keyword overlap (e.g., "wages" in field matches "wages" in key) + - Type-based inference (e.g., 9-digit number for SSN field) + Value Selection: + - If value is a list, typically use the first clean entry + - Trim whitespace from extracted values (preserve essential formatting) + - Validate against expected type and format + """, 'ocr': """ -Use the raw OCR text to identify values by: -- Finding lines that contain field-like descriptions -- Inferring label-value pairs by proximity in the OCR line order -- Matching expected data format or type (e.g., 9-digit number for SSN) -""", + Extraction Strategy: + 1. Pattern Recognition: Look for "field_description: value" or "field_code value" patterns + 2. Proximity Analysis: Find values on the same line or immediately following field descriptions + 3. Multi-line Handling: For addresses/complex fields, capture continuation lines + 4. Format Validation: Use expected patterns (9 digits for SSN, decimal for wages, etc.) 
+ + Value Extraction: + - Extract values that appear after field identifiers + - Handle multi-line values (especially addresses) + - Clean extracted values but preserve essential formatting + """, 'spatial': """ -Use OCR text and corresponding coordinates (x, y, confidence) to: -- Identify field-value pairs using spatial alignment (e.g., same row or column) -- Prefer high-confidence OCR results -- Choose the value closest in position and most similar in meaning to the field description -""", - 'error_header': "\n==Previous Extraction Errors - CRITICAL: Learn from these patterns==\n", - 'field_error': "\nField '{field}' ({desc}):\nPredicted (Extracted) Value:{pred}\nWhat went wrong:{reason}\nExpected (True) Value:{true}\n\n", - 'error_analysis_section': "", - 'example': """ -==Example== -Example Fields dict: -{{ - 'a': {{'desc': 'a: Employee's Social Security Number (SSN) [9 digits]', 'type': 'string', 'example': '123456789', 'required': True}}, - '1': {{'desc': '1 Wages tips other comp', 'type': 'float', 'example': 1234.56, 'required': True}}, - 'e': {{'desc': 'e: Employee's Name', 'type': 'string', 'example': 'John Smith', 'required': True}}, - 'f': {{'desc': 'f: Employee's Address', 'type': 'string', 'example': '123 Main St, City, State, ZIP', 'required': True}}, - '14': {{'desc': '14: Other', 'type': 'string', 'example': 'DCP-REG 916.18 CA-SDI 66.53', 'required': False}}, - '17': {{'desc': '17: Testing', "type": "string", 'example': '6,281.52', 'required': False}} -}} - -Example Key-Value dict: -{{ - 'a Employee's social security number' : ['987654321 ', '987654321 ', '987654321 ', '987654321 '] - 'e/f Employee's name, address and ZIP code' : ['Yi Ding 6709 SETZLER PARKWAY BROOKLYN PARK MN 55445 ', 'Yi Ding 6709 SETZLER PARKWAY BROOKLYN PARK MN 55445 ', 'Yi Ding 6709 SETZLER PARKWAY BROOKLYN PARK MN 55445 ', 'Yi Ding 6709 SETZLER PARKWAY BROOKLYN PARK MN 55445 '] - '14 Other' : ['', '', '', ''] -}} - -Example Return: -''' -{{ - 'a': '987654321', - 'e': 'Yi Ding', - 'f': '6709 SETZLER PARKWAY BROOKLYN PARK MN 55445' - '14': '' - '17': null -}} -'''""", + 1. Confidence Filtering: Prioritize OCR results with confidence > 0.8 + 2. Spatial Relationships: + - Same row: y-coordinates within ±5 pixels + - Same column: x-coordinates within ±10 pixels + - Adjacent positioning: closest text element to field label + 3. 
Alignment Detection: Find horizontally aligned label-value pairs + + Value Selection: + - Choose the highest confidence text element that's spatially related to the field + - Prefer values that are properly aligned (same row/column) + - Validate extracted values against expected type and format + """, + 'w2_example': """ + ==Positive Examples== + Example Fields dict: + {{ + 'a': {{'desc': 'a: Employee's Social Security Number (SSN) [9 digits]', 'type': 'string', 'example': '123456789', 'required': True}}, + '1': {{'desc': '1 Wages tips other comp', 'type': 'float', 'example': 1234.56, 'required': True}}, + 'e': {{'desc': 'e: Employee's Name', 'type': 'string', 'example': 'John Smith', 'required': True}}, + 'f': {{'desc': 'f: Employee's Address', 'type': 'string', 'example': '123 Main St, City, State, ZIP', 'required': True}}, + '12a': {{'desc': '12a: code', "type": "string", 'example': 'D 6,281.52', 'required': False}} + '14': {{'desc': '14: Other', 'type': 'string', 'example': 'DCP-REG 916.18 CA-SDI 66.53', 'required': False}}, + + }} + + Example Key-Value dict: + {{ + 'a Employee's social security number' : ['987654321 ', '987654321 ', '987654321 ', '987654321 '] + '1 Wages tips other comp' : ['52345.67 ', '52345.67 ', '52345.67 ', '52345.67 '] + 'e Employee's name, address and ZIP code' : ['Yi Ding 6709 SETZLER PARKWAY BROOKLYN PARK MN 55445 ', 'Yi Ding 6709 SETZLER PARKWAY BROOKLYN PARK MN 55445 ', 'Yi Ding 6709 SETZLER PARKWAY BROOKLYN PARK MN 55445 ', 'Yi Ding 6709 SETZLER PARKWAY BROOKLYN PARK MN 55445 '], + '12a code see inst, for box 12' : ['D 6,281.52', 'D 6,281.52', 'D 6,281.52', 'D 6,281.52'], + '14 Other' : ['DCP-REG 916.18 CA-SDI 66.53', 'DCP-REG 916.18', 'CA-SDI 66.53', 'DCP-REG 916.18 CA-SDI 66.53'] + }} + + Example OCR text result: + ''' + a Employee's SSN 987654321 + 1 Wages tips other comp 52345.67 + e Employee's name, address and ZIP code + Yi Ding 6709 SETZLER PARKWAY BROOKLYN PARK MN 55445 + 12a code see inst, for box 12 + D 6,281.52 + 14 Other DCP-REG 916.18 + CA-SDI 66.53 + ''' + + Example JSON output: + ''' + {{ + 'a': '987654321', + 'e': 'Yi Ding', + 'f': '6709 SETZLER PARKWAY BROOKLYN PARK MN 55445', + '1': 52345.67, + '12a': 'D 6,281.52', + '14': 'DCP-REG 916.18 CA-SDI 66.53' + }} + ''' + ===END OF POSITIVE EXAMPLES=== + """, + 'mismatch': """ + === PREVIOUS EXTRACTION EXEMPLARIES === + {mismatches} + ===END OF PREVIOUS EXTRACTION EXEMPLARIES=== + """, 'input': """ -==Input== -Input Fields dict: -{fields} - -Input Data (key:value list dictionary, OCR result, and/or coordinate layout): -{input_data} - -Return: -''' -{{output}} -''' -""" + === INPUT DATA === + Input Fields dict: + {fields} + + Input OCR Data (key:value dict, OCR text, or OCR text with coordinates): + {input_data} + ===END OF INPUT DATA=== + """ } def generate_prompt(self, fields: dict, kvs: Dict = None, ocr: str = None, coordinates: List = None) -> tuple[str, int]: # Build prompt efficiently using string templates + prompt_parts = [] + + if self.meta_prompt: + tag = self.meta_prompt.tag + instructions = "\n".join(self.meta_prompt.instruction) + meta_section = self.base_template_parts['meta_prompt_instructions'].format(meta_instructions=f"<{tag}>\n{instructions}\n") + prompt_parts.append(meta_section) + if kvs: mode_instructions = self.base_template_parts['kvs'] elif ocr: @@ -136,19 +176,14 @@ def generate_prompt(self, fields: dict, kvs: Dict = None, ocr: str = None, coord elif ocr and coordinates: mode_instructions = self.base_template_parts['spatial'] - prompt_parts = 
[self.base_template_parts['base'].format(mode_instructions=mode_instructions)] - - if self.iteration_error_analyses: - prompt_parts.append(self.base_template_parts['error_analysis_section']) - - # Add example section - prompt_parts.append(self.base_template_parts['example']) + prompt_parts.append(self.base_template_parts['base'].format(mode_instructions=mode_instructions)) + prompt_parts.append(self.base_template_parts['w2_example']) input_data = [] if kvs: input_data.append(f"Input Key-Value dict:\n{kvs}") if ocr: - input_data.append(f"Input OCR result:\n{ocr}") + input_data.append(f"Input OCR text result:\n{ocr}") if coordinates: input_data.append(f"Input coordinates:\n{coordinates}") @@ -163,8 +198,6 @@ def generate_prompt(self, fields: dict, kvs: Dict = None, ocr: str = None, coord final_prompt = ''.join(prompt_parts) token_count = self.count_tokens(final_prompt) - - # Join all parts efficiently return final_prompt, token_count def count_tokens(self, prompt: str) -> int: @@ -178,7 +211,7 @@ def count_tokens(self, prompt: str) -> int: Returns: int: The number of tokens in the text """ - model_encodings = { + openai_model_encodings = { "gpt-4o-mini": "o200k_base", "gpt-4o": "o200k_base", "gpt-4": "cl100k_base", @@ -188,306 +221,63 @@ def count_tokens(self, prompt: str) -> int: "gpt-4.1": "cl100k_base" # Fallback encoding } - encoding_name = model_encodings.get(self.model_type, "cl100k_base") # Default fallback + hf_model_tokenizers = { + "Qwen/Qwen3-8B": "Qwen/Qwen3-8B", + "mistralai/Mistral-7B-Instruct-v0.2": "mistralai/Mistral-7B-Instruct-v0.2" + } + # claude token count tbd try: - encoding = tiktoken.get_encoding(encoding_name) - return len(encoding.encode(prompt)) + if self.model_type in openai_model_encodings: + encoding_name = openai_model_encodings[self.model_type] + encoding = tiktoken.get_encoding(encoding_name) + return len(encoding.encode(prompt)) + elif self.model_type in hf_model_tokenizers: + tokenizer = AutoTokenizer.from_pretrained(hf_model_tokenizers[self.model_type]) + return len(tokenizer.encode(prompt)) + else: + print(f"Warning: Unknown model type {self.model_type}, using fallback estimate.") + return len(prompt) // 4 except Exception as e: print(f"Warning: Could not count tokens for model {self.model_type}: {e}") return len(prompt) // 4 + def structured_model(self, fields: dict, include_rationale: bool) -> Type: + fields_model = create_model( + "DynamicFieldOutputs", + **{key: Optional[TypeMapping.string_to_type(field["type"])] + for key, field in fields.items()} + ) - def _get_all_accumulated_mismatches(self) -> List[Dict]: - """Get all mismatches from all iterations flattened.""" - all_mismatches = [] - for iteration_mismatches in self.iteration_mismatches: - all_mismatches.extend(iteration_mismatches['mismatches']) - return all_mismatches - - def _find_new_mismatches(self, current_mismatches: Dict) -> Dict: - """Find mismatches that don't exist in previous iterations.""" - previous_mismatches = self._get_all_accumulated_mismatches() - - new_file_mismatches = {} - - for filename, (mismatches, ocr_data) in current_mismatches.items(): - new_mismatches = [] - for mismatch in mismatches: - # Check if this mismatch is new - is_new = True - for prev_mismatch in previous_mismatches: - if (mismatch.get('name') == prev_mismatch.get('name') and - mismatch.get('column') == prev_mismatch.get('column') and - str(mismatch.get('raw_value')) == str(prev_mismatch.get('raw_value')) and - str(mismatch.get('true_value')) == str(prev_mismatch.get('true_value'))): - is_new = False - 
break - - if is_new: - new_mismatches.append(mismatch) - - if new_mismatches: - new_file_mismatches[filename] = (new_mismatches, ocr_data) - - return new_file_mismatches - - def _update_error_counts(self, new_mismatches: Dict) -> None: - """Update document and field error counts for prioritization.""" - for filename, (mismatches, _) in new_mismatches.items(): - self.document_error_counts[filename] += len(mismatches) - for mismatch in mismatches: - field = mismatch.get('column') - if field: - self.field_error_counts[field] += 1 - - def _prioritize_errors_for_template(self) -> List[Dict]: - """ - Prioritize errors for inclusion in the template based on: - 1. Fields with more errors (prioritized) - 2. Documents with more errors (prioritized) - 3. Recency (more recent iterations prioritized) - """ - all_errors = [] - - # Collect all errors with metadata - for iter_data in self.iteration_error_analyses: - iteration = iter_data['iteration'] - analysis = iter_data['analysis'] - - for field, errors in analysis.items(): - if field in self.fields: - field_error_count = self.field_error_counts.get(field, 0) - - for error in errors: - error_filename = error.get('filename') - doc_error_count = self.document_error_counts.get(error_filename, 0) if error_filename else 0 - - all_errors.append({ - 'field': field, - 'error': error, - 'field_error_count': field_error_count, - 'doc_error_count': doc_error_count, - 'iteration': iteration, - }) - - # Sort by priority: field errors (desc), document errors (desc), recency (desc) - all_errors.sort(key=lambda x: ( - -x['field_error_count'], # More field errors = higher priority - -x['doc_error_count'], # More document errors = higher priority - -x['iteration'] # More recent = higher priority - )) - - return all_errors - - def _select_exemplars_with_limits(self, prioritized_errors: List[Dict]) -> List[Dict]: - """ - Select exemplars with limits on total count and per-field count. - """ - selected_errors = [] - field_counts = defaultdict(int) - - for error_data in prioritized_errors: - field = error_data['field'] - - # Check if we've reached global limit - if len(selected_errors) >= 12: - break - - # Check if we've reached per-field limit - if field_counts[field] >= 3: - continue - - selected_errors.append(error_data) - field_counts[field] += 1 - - return selected_errors - - def add_iteration_mismatches(self, new_mismatches: Dict) -> None: - """ - Add mismatches from a new iteration and update error analysis incrementally. 
- - Args: - new_mismatches (List[Dict]): Mismatches from the current iteration - """ - if not new_mismatches: - print("No mismatches to add for this iteration") - return - - # Find only the truly new mismatches (not seen in previous iterations) - unique_new_mismatches = self._find_new_mismatches(new_mismatches) - print(f"Found {len(unique_new_mismatches)} new mismatches") - - if not unique_new_mismatches: - print("No new unique mismatches found - all were seen in previous iterations") - return - - self._update_error_counts(unique_new_mismatches) - - # Add to iteration tracking - all_new_mismatches = [] - ocr_context = {} - - for filename, (mismatches, ocr_data) in unique_new_mismatches.items(): - all_new_mismatches.extend(mismatches) - ocr_context[filename] = ocr_data - - # Add to iteration tracking with OCR context - self.iteration_mismatches.append({ - 'mismatches': all_new_mismatches, - 'ocr_context': ocr_context, - 'file_mismatches': unique_new_mismatches - }) - self.current_iteration += 1 - - print(f"Iteration {self.current_iteration}: Added {len(unique_new_mismatches)} new mismatches") - print(f"Total mismatches across all iterations: {len(self._get_all_accumulated_mismatches())}") - - # Generate error analysis for ONLY the new mismatches - - new_error_analysis = {} - with ThreadPoolExecutor(max_workers=5) as executor: # Limit concurrent API calls - # Submit all tasks - future_to_file = { - executor.submit(self.get_error_analysis, mismatches, ocr_data, filename): filename - for filename, (mismatches, ocr_data) in unique_new_mismatches.items() - } - - # Collect results as they complete - for future in as_completed(future_to_file): - filename = future_to_file[future] - try: - response = future.result(timeout=30) # 30 second timeout per call - if response and 'response' in response: - for field, errors in response['response'].items(): - new_error_analysis[field] = new_error_analysis.get(field, []) + errors - except Exception as e: - print(f"Error analyzing {filename}: {e}") - - # Store the error analysis for this iteration - self.iteration_error_analyses.append({ - 'iteration': self.current_iteration, - 'analysis': new_error_analysis, - 'mismatch_count': len(unique_new_mismatches) - }) - - self._rebuild_error_analysis_section() # Update the error analysis section - - def _rebuild_error_analysis_section(self) -> None: - """Rebuild the complete error analysis section from all iterations.""" - if not self.iteration_error_analyses: - self.base_template_parts['error_analysis_section'] = "" - return - - prioritized_errors = self._prioritize_errors_for_template() - selected_errors = self._select_exemplars_with_limits(prioritized_errors) - - + output_model = create_model( + "ExtractorOutput", + result=(fields_model, ...) + ) - error_sections = [] - - # Add main header - error_sections.append(self.base_template_parts['error_header']) - - # Process each iteration's error analysis - for error_data in selected_errors: - field = error_data['field'] - error = error_data['error'] - - error_sections.append( - self.base_template_parts['field_error'].format( - field=field, - desc=self.fields[field]['desc'], - pred=error['actual'], - reason=error['reason'], - true=error['expected'] - ) + if include_rationale: + output_model = create_model( + "ExtractorOutput", + result=(fields_model, ...), + logic=(str, ...) 
) - - # Update the error analysis section - self.base_template_parts['error_analysis_section'] = ''.join(error_sections) + return output_model - total_iterations = len(self.iteration_error_analyses) - total_mismatches = len(self._get_all_accumulated_mismatches()) - print(f"Rebuilt error analysis section with {total_iterations} iterations, {total_mismatches} total mismatches") - - def get_error_analysis(self, mismatches: list, ocr_data: str, filename: str) -> dict: - llm = ChatOpenAI(model=self.model_type, temperature=0) - llm = llm.with_structured_output(ErrorAnalysis) - - prompt = """ - You are an expert in document field extraction error analysis. - Your job is to take a list of mismatches between 'expected' vs. 'predicted' values, - With given field information and OCR context, lets think step-by-step - - 1. Root Cause Analysis: Identify why the extraction model failed - 2. Self-Reflection: Generate insights about what the model might have "thought" during extraction - 3. Evidence Mapping: Connect OCR artifacts to extraction decisions - 4. Pattern Recognition: Identify systematic vs. isolated errors - - RESPONSE FORMAT: - Return a dictionary where: - - The 'response' key contains a dictionary - - Each field name maps to a LIST of error objects - - Each error object has: - - expected: ground truth value - - actual: model predicted value - - reason: a concise reasoning or explanation of why the model failed after self-reflection - - filename: the filename of the document (use: {filename}) - - INPUT DATA: - Fields: - {fields} - - filename: - {filename} - - OCR Context: - {ocr_context} - - Mismatches: - {mismatches} - """ - - response = llm.invoke(prompt.format(fields=self.fields, ocr_context=ocr_data, mismatches=mismatches, filename=filename)) - return response - - def structured_model(self, fields: dict) -> Type: - output_format = create_model( - "DynamicFieldOutputs", - **{key: Optional[TypeMapping.string_to_type(field["type"])] for key, field in fields.items()} - - ) - return output_format + return output_model def output_filling(self, model_cls: dict, response: BaseModel) -> tuple[Type, List[str]]: missing = set() filled_model = response.model_dump() for name in model_cls: - if model_cls[name]: - continue - elif filled_model[name] is None or filled_model[name] == "": - missing.add(name) - else: + if name in filled_model: model_cls[name] = filled_model[name] - - return model_cls, missing - - def get_response(self, fields: dict, kvs: dict = None, ocr: str = None, coordinates: list = None): - llm = ChatOpenAI(model=self.model_type, temperature=0) - output_format = self.structured_model(fields) - - llm = llm.with_structured_output(output_format) - prompt, token_count = self.generate_prompt(fields=fields, kvs=kvs, ocr=ocr, coordinates=coordinates) - - print(f"Prompt token count: {token_count}") - - response = llm.invoke(prompt) + + if not validate(self.fields[name]['type'], model_cls[name], self.fields[name]['required']): + missing.add(name) - return response + return model_cls, missing - def fv_extract(self, kvs: dict[str, list[str]] = None, ocr: str = None, coordinates: list = None) -> dict[str, str]: + def fv_extract(self, kvs: dict[str, list[str]] = None, ocr: str = None, coordinates: list = None, include_rationale: bool = False) -> dict: ''' Extract values from a key-value dictionary based on field descriptions and types. 
@@ -501,7 +291,7 @@ def fv_extract(self, kvs: dict[str, list[str]] = None, ocr: str = None, coordina
 
         # Track the number of attempts and Initialize results and incomplete fields
         attempts = 0
-        results = {k: None for k in self.fields.keys()}
+        results, rationale = {k: None for k in self.fields.keys()}, None
         incomplete_fields = {k: self.fields[k] for k in self.fields.keys()}
 
         # Iterate until all fields are filled or max attempts reached
@@ -509,15 +299,16 @@
             try:
                 print('Attempt #', attempts + 1)
 
-                # Query the OpenAI API with the prompt
-                response = self.get_response(fields=incomplete_fields, kvs=kvs, ocr=ocr, coordinates=coordinates)
-                results, missing = self.output_filling(results, response)
+                output_format = self.structured_model(incomplete_fields, include_rationale=include_rationale)
+                prompt, token_count = self.generate_prompt(fields=incomplete_fields, kvs=kvs, ocr=ocr, coordinates=coordinates)
+                response = self.llm.get_response(prompt, output_format=output_format)
+                results, missing = self.output_filling(results, response.result if include_rationale else response)
+
+                if include_rationale and hasattr(response, 'logic'):
+                    rationale = response.logic
+
+                incomplete_fields = {k: incomplete_fields[k] for k in missing}
 
-                # Check for incomplete fields
-                incomplete_fields = {
-                    k: incomplete_fields[k]
-                    for k in missing
-                }
 
                 # If all fields are filled, break the loop
                 if len(incomplete_fields) == 0:
@@ -529,4 +320,7 @@
             # If JSON decoding fails, return the raw result
             print(f"Failed to decode JSON: {response}")
 
-        return results
+        return {
+            'result': results,
+            'rationale': rationale
+        }
\ No newline at end of file
diff --git a/src/model.py b/src/model.py
index 277a893..404cf9f 100644
--- a/src/model.py
+++ b/src/model.py
@@ -168,54 +168,62 @@ def __init__(self, model_name: str):
             )
 
         try:
-            if model_name == "gpt-4":
+            if model_name in ["gpt-4o-mini", "gpt-4.1", "gpt-o3"]:
                 api_key = os.getenv('OPENAI_API_KEY')
                 if not api_key:
                     raise ValueError("OPENAI_API_KEY not found in environment variables")
 
-                logger.info(f"Initializing OpenAI GPT-4 model...")
+                logger.info(f"Initializing OpenAI {model_name} model...")
                 self.model = ChatOpenAI(
-                    model="gpt-4-turbo-preview",
+                    model=model_name,
                     temperature=0,
                     max_tokens=150,
                     openai_api_key=api_key
                 )
-                logger.info(f"GPT-4 model initialized successfully")
+                logger.info(f"{model_name} model initialized successfully")
 
             elif model_name == "gemini":
                 api_key = os.getenv('GOOGLE_API_KEY')
                 if not api_key:
                     raise ValueError("GOOGLE_API_KEY not found in environment variables")
 
-                logger.info(f"Initializing Google Gemini model...")
+                logger.info("Initializing Google Gemini model...")
                 self.model = ChatGoogleGenerativeAI(
                     model="gemini-pro",
                     temperature=0,
                     max_output_tokens=150,
                     google_api_key=api_key
                 )
-                logger.info(f"Gemini model initialized successfully")
+                logger.info("Gemini model initialized successfully")
 
             elif model_name == "claude":
                 api_key = os.getenv('ANTHROPIC_API_KEY')
                 if not api_key:
                     raise ValueError("ANTHROPIC_API_KEY not found in environment variables")
 
-                logger.info(f"Initializing Anthropic Claude model...")
+                logger.info("Initializing Anthropic Claude model...")
                 self.model = ChatAnthropic(
                     model="claude-3-opus-20240229",
                     temperature=0,
                     max_tokens=150,
                     anthropic_api_key=api_key
                 )
-                logger.info(f"Claude model initialized successfully")
+                logger.info("Claude model initialized successfully")
 
             else:
                 raise ValueError(f"Unsupported 
LangChain model: {model_name}") except Exception as e: logger.error(f"Error initializing {model_name} model: {e}") raise - def get_response(self, prompt: str, max_new_tokens: int = 100) -> str: + def get_response(self, prompt: str, max_new_tokens: int = 100, **kwargs) -> str: """Get response from the LangChain model.""" try: + model = self.model + output_format = kwargs.pop('output_format', None) + if output_format: + model = model.with_structured_output(output_format) # Simple invoke - max_tokens is already set in constructor - response = self.model.invoke(prompt) - return response.content.strip() + response = model.invoke(prompt) + + if hasattr(response, "content"): + return response.content.strip() + else: + return response except Exception as e: logger.error(f"Error in {self.model_name} query: {str(e)}") logger.error(f"Full error details: {e}") @@ -237,7 +245,7 @@ def get_model(model_name: str) -> BaseModel: """ if model_name in ["Qwen/Qwen3-8B", "mistralai/Mistral-7B-Instruct-v0.2"]: return OpenSourceModel(model_name) - elif model_name in ["gpt-4", "gemini", "claude"]: + elif model_name in ["gpt-4o-mini", "gpt-4.1", "gpt-o3", "gemini", "claude"]: return LangChainModel(model_name) else: raise ValueError( diff --git a/src/ocr_load.py b/src/ocr_load.py index 3395613..96f3311 100644 --- a/src/ocr_load.py +++ b/src/ocr_load.py @@ -1,14 +1,13 @@ from abc import ABC, abstractmethod import base64 -import io import logging import os import sys -from typing import Optional, List, Dict, Any, Union +from typing import Optional, List, Dict, Any from dataclasses import dataclass from pathlib import Path import boto3 -from pdf2image import convert_from_path +import pymupdf from mistralai import Mistral @@ -44,14 +43,12 @@ def process(self, filepath: str) -> List[OCRResult]: # currently pdf needs to be converted to image -> is this the best way to do this? if ext == 'pdf': - images = convert_from_path(file_path) + images = pymupdf.open(file_path) results = [] for page_num, image in enumerate(images): print(f"Processing page {page_num + 1}") - img_byte_arr = io.BytesIO() - image.save(img_byte_arr, format='PNG') - file_bytes = img_byte_arr.getvalue() + file_bytes = image.get_pixmap().tobytes() response = client.detect_document_text(Document={'Bytes': file_bytes}) diff --git a/src/optimizer.py b/src/optimizer.py new file mode 100644 index 0000000..2c5f8e2 --- /dev/null +++ b/src/optimizer.py @@ -0,0 +1,300 @@ +import pandas as pd +from typing import List, Dict, Tuple + +from pydantic import BaseModel, Field +import tqdm + +from example import Example +from fv_extract import FVExtraction +from evaluation import values_match +from model import get_model + +class MetaPromptModule(BaseModel): + tag: str = Field(..., min_length=1) + instruction: List[str] = Field(..., min_length=3, max_length=8) + +class MetaPromptOutput(BaseModel): + response: List[MetaPromptModule] + + +class PromptOptimizer: + """ + Separate class for handling prompt optimization using training data and labels. + Extracts the optimization logic from FVExtraction and w2_extract.py. + """ + + def __init__(self, fields: dict[str, dict], model_type: str = "gpt-4o-mini", ): + self.fields = fields + self.model_type = model_type + + self.llm = get_model(model_type) + + def split_data( + self, + train_data: List[Example], + train_ratio: float = 0.7, + random_seed: int = 42 + ) -> Tuple[List[Dict], List[Dict], pd.DataFrame, pd.DataFrame]: + """ + Split data into train and evaluation sets. 
+ + Args: + ocr_data: List of OCR data dictionaries + label_file: Path to labels CSV file + + Returns: + Tuple of (train_ocr_data, eval_ocr_data, train_labels_df, eval_labels_df) + """ + import random + + available_data = [data for data in train_data] + + if len(available_data) == 0: + raise ValueError("No files found with corresponding labels") + + # Split data + random.seed(random_seed) + random.shuffle(available_data) + + train_size = int(len(available_data) * train_ratio) + train_data = available_data[:train_size] + eval_data = available_data[train_size:] + + # Split labels + train_df = [data for data in train_data] + eval_df = [data for data in eval_data] + + print(f"Split: {len(train_df)} training files, {len(eval_df)} evaluation files") + + return train_df, eval_df + + def evaluate_with_demos(self, example_data: List[Example]): + total_comparisons, correct_comparisons, positive_cases, mismatches = 0, 0, [], [] + field_analysis = {} + + for data in example_data: + pred, true, rationale = data._demos + perfect_match = True + + for col in self.fields.keys(): + total_comparisons += 1 + if values_match(pred[col], true[col], col): + correct_comparisons += 1 + else: + mismatches.append({ + 'column': col, + 'pred_value': pred[col], + 'true_value': true[col], + 'rationale': rationale + }) + field_analysis[col] = field_analysis.get(col, 0) + 1 + perfect_match = False + + if perfect_match: + positive_cases.append({ + 'pred_value': pred, + 'true_value': true, + 'rationale': rationale + }) + + return total_comparisons, correct_comparisons, positive_cases, mismatches, field_analysis + + def evaluate_score(self, example_data: List[Example]): + total_comparisons, correct_comparisons = 0, 0 + + for data in example_data: + pred, true, _ = data._demos + + for col in self.fields.keys(): + total_comparisons += 1 + if values_match(pred[col], true[col], col): + correct_comparisons += 1 + + return total_comparisons, correct_comparisons + + + def predict_from_example(self, extractor: FVExtraction, train_data: List[Example], max_workers: int = 4): + """ + Extract fields from OCR data using the current extractor. 
+ + Args: + ocr_data: List of OCR data dictionaries + spatial: Whether to use spatial coordinates + """ + from concurrent.futures import ThreadPoolExecutor, as_completed + + def process_single_example(data: Example): + """Process a single example in a worker thread.""" + try: + ocr_context = data.context + mode = data.mode + label = data.labels().get('output', {}) + + if mode == "kv": + page_result = extractor.fv_extract(kvs=ocr_context, include_rationale=True) + elif mode == "ocr": + page_result = extractor.fv_extract(ocr=ocr_context, include_rationale=True) + elif mode == "spatial": + page_result = extractor.fv_extract(coordinates=ocr_context, include_rationale=True) + else: + raise ValueError(f"Unsupported mode: {mode}") + + return data, (page_result['result'], label, page_result['rationale']), None + + except Exception as e: + return data, None, str(e) + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit all tasks + future_to_example = { + executor.submit(process_single_example, data): data + for data in train_data + } + + # Process results as they complete + completed = 0 + total = len(train_data) + + for future in tqdm.tqdm(as_completed(future_to_example), total=total, desc="Processing examples"): + data = future_to_example[future] + try: + data, result, error = future.result() + if error is None and result is not None: + data._demos = result + completed += 1 + else: + print(f"Error processing example: {error}") + except Exception as e: + print(f"Exception processing example: {e}") + + print(f"Successfully processed {completed}/{total} examples") + + + def generate_meta_prompt_description(self, base_prompt_template: Dict, positive_cases: List[Example], mismatches: List[Dict], field_analysis: Dict): + prompt = """ + You will generate **3 to 5 distinct meta‑prompt modules**—each with its own `tag` and a set of **3–8 actionable guidance strings**—to improve OCR field extraction accuracy. + + Each module should: + 1. Target a **different strategy** (e.g., regex normalization, semantic validation, fallback heuristics, cross‑field checks). + 2. Be concise: each guidance item must be **one to two sentences**. + 3. Be **actionable** and **testable** (e.g., “If the extracted date does not match YYYY‑MM‑DD, re‑parse with US format and validate chronological order”). + + Before generating, analyze: + A. **Field Descriptions** for patterns, data types, typical formats. + B. **Positive Exemplars** to identify success drivers. + C. **Failed Exemplars & Error Analysis** to pinpoint recurring extraction issues. + D. **Prompt Limitations** and any systematic biases + + **Output Format** (JSON array): + ```json + [ + {{ + "tag": "", + "instruction": [ + "...", + "...", + "..." 
+ ] + }}, + // up to 5 modules + ] + ``` + + ===Input=== + Field descriptions: + {field} + + Base prompt: + {base_prompt} + + Positive extraction exemplars: + {positive_cases} + + Failure extraction exemplars (mismatches): + {new_mismatches} + + Field error analysis: + {field_analysis} + + ===End of Input=== + """ + formatted_prompt = prompt.format( + field = self.fields, + base_prompt = base_prompt_template, + positive_cases = positive_cases, + new_mismatches = mismatches, + field_analysis = field_analysis + ) + response = self.llm.get_response(prompt=formatted_prompt, output_format=MetaPromptOutput) + + return response + + + def optimize_prompt( + self, + extractor: FVExtraction, + train_data: List[Example], + iterations: int = 3 + ) -> FVExtraction: + """ + Main optimization function that trains the prompt using mismatches. + + Args: + train_data: Either directory path or pre-processed OCR data + label_file: Path to labels CSV file + ocr_method: OCR method to use + spatial: Whether to use spatial coordinates + + Returns: + Optimized FVExtraction instance + """ + + # Split into train/eval + train_raw_df, eval_raw_df = self.split_data(train_data) + + best_accuracy = -1 + best_meta = None + + for i in range(iterations): + print(f"\n=== Optimization Iteration {i + 1}/{iterations} ===") + + # Extract fields from OCR data + self.predict_from_example(extractor, train_raw_df) + + # Evaluate + total_comparisons, correct_comparisons, positive_cases, mismatches, field_analysis = self.evaluate_with_demos(train_raw_df) + + if best_accuracy == -1: + best_accuracy = correct_comparisons / total_comparisons + + print(f"Accuracy: {correct_comparisons / total_comparisons * 100:.2f}%") + print(f"Positive cases: {len(positive_cases)}") + + response = self.generate_meta_prompt_description(extractor.base_template_parts, positive_cases, mismatches, field_analysis) + meta_prompts = response.response + print(f"Number of Meta prompt descriptions: {len(meta_prompts)}") + + + iteration_best_accuracy = best_accuracy + iteration_best_meta = best_meta + for meta in meta_prompts: + extractor.meta_prompt = meta + + self.predict_from_example(extractor, eval_raw_df) + + total_comparisons, correct_comparisons = self.evaluate_score(eval_raw_df) + + accuracy = correct_comparisons / total_comparisons if total_comparisons else 0 + print(f"Accuracy: {accuracy * 100:.2f}%") + if accuracy > iteration_best_accuracy: + iteration_best_accuracy = accuracy + iteration_best_meta = meta + + if iteration_best_accuracy > best_accuracy: + best_accuracy = iteration_best_accuracy + best_meta = iteration_best_meta + + extractor.meta_prompt = best_meta + print(f"Best meta prompt: {best_meta}") + + return extractor \ No newline at end of file diff --git a/src/w2_extract.py b/src/w2_extract.py index d26f8e5..abe6afe 100644 --- a/src/w2_extract.py +++ b/src/w2_extract.py @@ -5,6 +5,7 @@ import pickle import threading import time +from typing import List import pandas as pd import zipfile @@ -15,10 +16,12 @@ import yaml +from example import Example from kv_extract import KVExtractDocument -from fv_extract import OptimizedExtraction +from fv_extract import FVExtraction import ocr_load from evaluation import compare_datasets +from optimizer import PromptOptimizer w2_fields = { 'a': {"desc": "a Employee's Social Security Number (SSN)", "type": "string", 'example': '123456789', 'required': True}, @@ -149,15 +152,6 @@ def clear_batch_results(batch_id): pickle.dump(batch_data, f) print(f"Cleared field extraction results, kept OCR cache: {batch_file}") 
-def group_mismatches_by_file(mismatches): - """Group mismatches by filename.""" - from collections import defaultdict - grouped = defaultdict(list) - for mismatch in mismatches: - filename = mismatch['name'] - grouped[filename].append(mismatch) - return dict(grouped) - def get_cached_file_data(filename, batch_id, ocr_method): """Get cached OCR data for a specific file.""" _, _, ocr_cache = load_batch_results(batch_id) @@ -189,8 +183,78 @@ def load_fields(file) -> dict[str, dict]: if os.path.exists(tmp_path): os.unlink(tmp_path) +def create_trainset(ocr_cache, labels_df, fields, ocr_mode: str="ocr") -> List[Example]: + dataset = [] + + for idx, row in labels_df.iterrows(): + filename = row['file'] + + # Check if we have OCR data for this file + if filename not in ocr_cache: + print(f"Warning: No OCR data found for {filename}") + continue + + ocr_pages = ocr_cache[filename] + + + # Create expected output dictionary + expected_fields = {} + for field in fields.keys(): + value = row[field] + expected_fields[field] = value + + example = Example( + fields=str(fields), + context=str(ocr_pages), + mode=ocr_mode, + output=expected_fields + ) + + dataset.append(example) + print(f"Added training example for {filename}") + + return dataset + +def get_ocr_data(args): + filepath, filename, batch_id, ocr_method = args + + print(f"Processing {filepath}") + + file_cache_key = get_file_cache_key(filepath, ocr_method) + + # Load existing batch results + _, _, ocr_cache = load_batch_results(batch_id) -def extract_file(filepath: str, fields: dict[str, dict]=w2_fields, max_attempts: int=3, batch_id: str=None, filename: str=None, model_type: str="gpt-4o-mini", ocr_method: str="textract-kv", spatial: bool=False, extractor: OptimizedExtraction=None, save_batch: bool=True) -> list[dict[str, str]]: + if file_cache_key in ocr_cache: + print(f"Using cached OCR results for {filename}") + pages = ocr_cache[file_cache_key] + else: + print(f"Running OCR for {filename}") + # Run OCR + if ocr_method == "mistral": + ocr_method_obj = ocr_load.MistralOCR() + ocr = ocr_load.OCR(filepath, ocr_method_obj) + ocr_results = ocr.get_plain_text() + pages = ocr_results + elif ocr_method == "textractocr": + ocr_method_obj = ocr_load.TextractOCR() + ocr = ocr_load.OCR(filepath, ocr_method_obj) + ocr_results = ocr.get_plain_text() + pages = ocr_results + elif ocr_method == "textract-kv": + kv_extractor = KVExtractDocument(filepath) + pages = kv_extractor.pages + else: + raise ValueError(f"Unsupported OCR method: {ocr_method}") + + # Cache OCR results + ocr_cache[file_cache_key] = pages + save_batch_results(batch_id, _, _, ocr_cache) + + return filename, pages + + +def extract_file(filepath: str, fields: dict[str, dict]=w2_fields, max_attempts: int=3, batch_id: str=None, filename: str=None, model_type: str="gpt-4o-mini", ocr_method: str="textract-kv", spatial: bool=False, extractor: FVExtraction=None) -> list[dict[str, str]]: ''' Extracts W2 fields from a PDF file using Textract. 
@@ -223,20 +287,14 @@ def extract_file(filepath: str, fields: dict[str, dict]=w2_fields, max_attempts:
         ocr = ocr_load.OCR(filepath, ocr_method_obj)
         ocr_results = ocr.get_text_blocks()
         pages = ocr_results
-        print("========mistral OCR results==========")
-        print(json.dumps(pages[0], indent=4))
     elif ocr_method == "textractocr":
         ocr_method_obj = ocr_load.TextractOCR()
         ocr = ocr_load.OCR(filepath, ocr_method_obj)
         ocr_results = ocr.get_text_blocks()
         pages = ocr_results
-        print("========textractocr OCR results==========")
-        print(json.dumps(pages[0]['value'], indent=4))
     elif ocr_method == "textract-kv":
         kv_extractor = KVExtractDocument(filepath)
         pages = kv_extractor.pages
-        print("========textract-kv Image-to-text results==========")
-        print(pages[0].kvs)
     else:
         raise ValueError(f"Unsupported OCR method: {ocr_method}")
 
@@ -249,7 +307,7 @@ def extract_file(filepath: str, fields: dict[str, dict]=w2_fields, max_attempts:
         return results
 
     if extractor is None:
-        extractor = OptimizedExtraction(fields=fields, model_type=model_type, max_attempts=max_attempts)
+        extractor = FVExtraction(fields=fields, model_type=model_type, max_attempts=max_attempts)
 
     # Process each page
     # kv_document.pages or ocr_results
@@ -271,8 +329,7 @@ def extract_file(filepath: str, fields: dict[str, dict]=w2_fields, max_attempts:
             processed_files[filename] = processed_pages
             file_results[filename] = results
 
-            if save_batch:
-                save_batch_results(batch_id, processed_files, file_results, ocr_cache)
+            save_batch_results(batch_id, processed_files, file_results, ocr_cache)
 
             print(f"Processed page {page_num + 1} of {filename}")
 
@@ -310,7 +367,9 @@ def process_single_file_parallel(args):
         print(f"Error processing {filename} after {processing_time:.2f} seconds: {str(e)}")
         return filename, None, e
 
-def extract_dir(dirpath: str, fields=w2_fields, file_out=None, batch_id: str=None, model_type: str="gpt-4.1", ocr_method: str="textract-kv", spatial: bool=False, prompt_opt: bool=False, label_file: str=None, test_file: str=None, test_label: str=None, max_workers: int=4) -> pd.DataFrame:
+
+def extract_dir(dirpath: str, fields=w2_fields, file_out=None, batch_id: str=None, model_type: str="gpt-4o-mini", ocr_method: str="mistral", spatial: bool=False, prompt_opt: bool=False, label_file: str=None, training_file: str=None, training_label: str=None, max_workers: int=4) -> pd.DataFrame:
+
     '''
     Extracts W2 fields from all PDF/JPG/PNG files in a directory.
@@ -324,13 +383,61 @@ def extract_dir(dirpath: str, fields=w2_fields, file_out=None, batch_id: str=Non
     '''
     if batch_id is None:
         batch_id = get_batch_id_dir(dirpath)
-
+
+    # List all PDF/JPG/PNG files in the directory
     files = [f for f in os.listdir(dirpath) if f.endswith('.pdf') or f.endswith('.png') or f.endswith('.jpg')]
     paths = [os.path.join(dirpath, f) for f in files]
 
-    processed_files, file_results, ocr_cache = load_batch_results(batch_id)
+    base_extractor = FVExtraction(fields=fields, model_type=model_type)
+    extractor_to_use = base_extractor
+
+    if prompt_opt and training_file and training_label:
+        training_files = [f for f in os.listdir(training_file) if f.endswith('.pdf') or f.endswith('.png') or f.endswith('.jpg')]
+        training_paths = [os.path.join(training_file, f) for f in training_files]
+
+        process_args = [
+            (filepath, filename, batch_id, ocr_method)
+            for filepath, filename in zip(training_paths, training_files)
+        ]
+
+        results = {}
+        failed_files = []
+
+        if process_args:
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                # Submit all tasks
+                future_to_file = {
+                    executor.submit(get_ocr_data, args): args[1]
+                    for args in process_args
+                }
+
+                # Collect results as they complete
+                for future in as_completed(future_to_file):
+                    filename = future_to_file[future]
+                    try:
+                        filename, result = future.result()
+                        results[filename] = result
+                    except Exception as e:
+                        failed_files.append((filename, e))
+                        print(f"✗ Exception processing {filename}: {e}")
+
+        if failed_files:
+            print(f"\nFailed to process {len(failed_files)} files:")
+            for filename, error in failed_files:
+                print(f"  - {filename}: {error}")
+
+        labels_df = pd.read_csv(training_label, keep_default_na=False, na_filter=False)
+        dataset = create_trainset(ocr_cache=results, labels_df=labels_df, fields=fields, ocr_mode="ocr" if ocr_method != "textract-kv" else "kv")
+
+        optimizer = PromptOptimizer(fields=fields, model_type=model_type)
+        optimized_extractor = optimizer.optimize_prompt(extractor=base_extractor, train_data=dataset)
+        extractor_to_use = optimized_extractor
+
+
+    processed_files, file_results, ocr_cache = load_batch_results(batch_id)
+
     pending_files = []
     pending_paths = []
     completed_results = {}
@@ -346,13 +453,12 @@ def extract_dir(dirpath: str, fields=w2_fields, file_out=None, batch_id: str=Non
             pending_files.append(filename)
             pending_paths.append(filepath)
 
-    print(f"Processing {len(pending_files)} files ({len(files) - len(pending_files)} already completed)")
-    base_extractor = OptimizedExtraction(fields=fields, model_type=model_type) if prompt_opt else None
+    print(f"Processing {len(pending_files)} files ({len(files) - len(pending_files)} already completed)")
 
     # Prepare arguments for parallel processing
     process_args = [
-        (filepath, filename, fields, batch_id, model_type, ocr_method, spatial, base_extractor)
+        (filepath, filename, fields, batch_id, model_type, ocr_method, spatial, extractor_to_use)
         for filepath, filename in zip(pending_paths, pending_files)
     ]
 
@@ -402,139 +508,39 @@ def extract_dir(dirpath: str, fields=w2_fields, file_out=None, batch_id: str=Non
                 temp.update(page)
                 entries.append(temp)
 
-    # Optionally convert to DataFrame and save as CSV
-    # df = pd.DataFrame.from_dict(entries, orient='index')
     df = pd.DataFrame(entries)
     print(df)
 
-    if prompt_opt and label_file:
-        print("Running prompt optimization with evaluation...")
-
-        test_files = [f for f in os.listdir(test_file) if f.endswith('.pdf') or f.endswith('.png') or f.endswith('.jpg')]
-        test_paths = [os.path.join(test_file, f) for f in test_files]
+
+    # Save to CSV or Excel if specified
+    if file_out is not None:
+        # CSV
+        if file_out.endswith('.csv'):
+            df.to_csv(file_out)
 
-        # Create temporary directory for intermediate results
-        with tempfile.TemporaryDirectory() as temp_dir:
-            # Save initial results
-            temp_preds_file = os.path.join(temp_dir, "temp_preds.csv")
-            df.to_csv(temp_preds_file, index=False)
-
-            # First evaluation
-            total_comparisons, correct_comparisons, mismatches = compare_datasets(
-                labels_files=[label_file],
-                preds_file=temp_preds_file
-            )
-            print(f"Initial accuracy: {correct_comparisons/total_comparisons:.2%}")
-            print(f"Initial mismatches: {len(mismatches)}")
-            grouped_mismatches = group_mismatches_by_file(mismatches)
-
-            file_mismatches = {}
-            for filename, mismatches in grouped_mismatches.items():
-                cached_data = get_cached_file_data(filename, batch_id, ocr_method)
-                file_mismatches[filename] = (mismatches, cached_data)
-
-            base_extractor.add_iteration_mismatches(file_mismatches)
-
-            # Second run with mismatches
-            print("Running second extraction with mismatches...")
-            clear_batch_results(batch_id)
-
-            process_args = [
-                (filepath, filename, fields, batch_id, model_type, ocr_method, spatial, base_extractor)
-                for filepath, filename in zip(test_paths, test_files)
-            ]
-
-            results = {}
-            with ThreadPoolExecutor(max_workers=max_workers) as executor:
-                future_to_file = {
-                    executor.submit(process_single_file_parallel, args): args[1]
-                    for args in process_args
-                }
-
-                for future in as_completed(future_to_file):
-                    filename = future_to_file[future]
-                    try:
-                        filename, result, error = future.result()
-                        if error is None and result is not None:
-                            results[filename] = result
-                    except Exception as e:
-                        print(f"Error in second run for {filename}: {e}")
-
-            # Combine second run results
-            entries = []
-            for filename in files:
-                if filename in results:
-                    for pnum, page in enumerate(results[filename]):
-                        temp = {'file': filename, 'page': pnum + 1}
-                        temp.update(page)
-                        entries.append(temp)
-
-            df = pd.DataFrame(entries)
-            df.to_csv(temp_preds_file, index=False)
-
-            # Second evaluation
-            total_comparisons, correct_comparisons, new_mismatches = compare_datasets(
-                labels_files=[label_file],
-                preds_file=temp_preds_file
-            )
-            print(f"Second run accuracy: {correct_comparisons/total_comparisons:.2%}")
-            print(f"Second run mismatches: {len(new_mismatches)}")
-            print(new_mismatches)
-
-            new_grouped_mismatches = group_mismatches_by_file(new_mismatches)
-
-            new_file_mismatches = {}
-            for filename, mismatches in new_grouped_mismatches.items():
-                cached_data = get_cached_file_data(filename, batch_id, ocr_method)
-                new_file_mismatches[filename] = (mismatches, cached_data)
-
-            base_extractor.add_iteration_mismatches(new_file_mismatches)
-
-            # Final run with updated mismatches
-            print("Running final extraction with updated mismatches...")
-            clear_batch_results(batch_id)
-
-            process_args = [
-                (filepath, filename, fields, batch_id, model_type, ocr_method, spatial, base_extractor)
-                for filepath, filename in zip(test_paths, test_files)
-            ]
-
-            results = {}
-            with ThreadPoolExecutor(max_workers=max_workers) as executor:
-                future_to_file = {
-                    executor.submit(process_single_file_parallel, args): args[1]
-                    for args in process_args
-                }
-
-                for future in as_completed(future_to_file):
-                    filename = future_to_file[future]
-                    try:
-                        filename, result, error = future.result()
-                        if error is None and result is not None:
-                            results[filename] = result
-                    except Exception as e:
-                        print(f"Error in final run for {filename}: {e}")
-
-            # Combine final run results
-            entries = []
-            for filename in test_files:
-                if filename in results:
-                    for pnum, page in enumerate(results[filename]):
-                        temp = {'file': filename, 'page': pnum + 1}
-                        temp.update(page)
-                        entries.append(temp)
-
-            df = pd.DataFrame(entries)
-            df.to_csv(temp_preds_file, index=False)
-
-            # Final evaluation
-            total_comparisons, correct_comparisons, mismatches = compare_datasets(
-                labels_files=[test_label],
-                preds_file=temp_preds_file
-            )
-            print(f"Final accuracy: {correct_comparisons/total_comparisons:.2%}")
+        # Excel
+        elif file_out.endswith('.xlsx'):
+            # Use xlsxwriter to save DataFrame to Excel
+            with pd.ExcelWriter(file_out, engine='xlsxwriter') as writer:
+                df.to_excel(writer, index=False, sheet_name='Sheet1')
+                worksheet = writer.sheets['Sheet1']
+
+                # Adjust column widths
+                for col_idx, col in enumerate(df.columns, start=1):
+                    max_length = max(df[col].astype(str).map(len).max(), len(col)) + 2  # Add padding
+                    worksheet.set_column(col_idx - 1, col_idx - 1, max_length)
+
+        if label_file:
+            total_comparisons, correct_comparisons, final_mismatches = compare_datasets(
+                labels_files=[label_file],
+                preds_file=file_out
+            )
+            print(f"Final test accuracy: {correct_comparisons/total_comparisons:.2%}")
+            print(f"Final test mismatches: {len(final_mismatches)}")
+
+    return df
+
 
 def extract_zip(file, fields=w2_fields, file_out=None) -> pd.DataFrame:
     '''
@@ -587,14 +593,14 @@ def extract_zip(file, fields=w2_fields, file_out=None) -> pd.DataFrame:
     parser.add_argument("--type", default=None, choices=['file', 'dir', 'zip'], type=str, help="Type of extraction: 'file' for single file, 'dir' for directory, 'zip' for zip file.")
     parser.add_argument("--file_out", default=None, type=str, help="Path to save the output CSV or Excel file. None to skip saving.")
     parser.add_argument("--max_attempts", default=3, type=int, help="Maximum number of attempts to extract fields.")
-    parser.add_argument("--model_type_or_path", default="gpt-4o-mini", choices=["gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini", "gpt-o3"], help="Model type used for fv extraction")
+    parser.add_argument("--model_type_or_path", default="gpt-4o-mini", choices=["gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini", "gpt-o3", "Qwen/Qwen2.5-7B-Instruct-Turbo", "Qwen/Qwen3-8B", "mistralai/Mistral-7B-Instruct-v0.2"], help="Model type used for fv extraction")
     parser.add_argument("--ocr_method", default="textract-kv", choices=["mistral", "textractocr", "textract-kv"], help="OCR method to use")
     parser.add_argument("--spatial_ocr", action="store_true", help="Enable spatial OCR with coordinate information")
    parser.add_argument("--prompt_opt", action="store_true", help="Enable prompt optimization with evaluation")
     parser.add_argument("--label_file", type=str, help="Path to the label file for evaluation when prompt_opt is True")
-    parser.add_argument("--test_file", type=str, help="Path to the test file path for evaluation when prompt_opt is True")
-    parser.add_argument("--test_label", type=str, help="Path to the test label file for evaluation when prompt_opt is True")
-    parser.add_argument("--save_batch", action="store_true", help="Save batch results")
+    parser.add_argument("--train_file", type=str, help="Path to the training file path for prompt optimization when prompt_opt is True")
+    parser.add_argument("--train_label", type=str, help="Path to the training label file for prompt optimization when prompt_opt is True")
+
     args = parser.parse_args()
 
     try:
@@ -621,18 +627,20 @@ def extract_zip(file, fields=w2_fields, file_out=None) -> pd.DataFrame:
         if args.type == 'file':
            # If a single PDF file is provided, extract from that file
            filename = os.path.basename(args.filepath)
-            result_df = extract_file(args.filepath, custom_fields, args.max_attempts, batch_id=0, filename=filename, model_type=args.model_type_or_path, ocr_method=args.ocr_method, spatial=args.spatial_ocr, save_batch=args.save_batch)
+            result_df = extract_file(args.filepath, custom_fields, args.max_attempts, batch_id=0, filename=filename, model_type=args.model_type_or_path, ocr_method=args.ocr_method, spatial=args.spatial_ocr)
         elif args.type == 'dir':
             # If a directory is provided, extract from all PDF files in that directory
-            result_df = extract_dir(args.filepath, custom_fields, args.file_out, model_type=args.model_type_or_path, ocr_method=args.ocr_method, spatial=args.spatial_ocr, prompt_opt=args.prompt_opt, label_file=args.label_file, test_file=args.test_file, test_label=args.test_label)
+            result_df = extract_dir(args.filepath, custom_fields, args.file_out, model_type=args.model_type_or_path, ocr_method=args.ocr_method, spatial=args.spatial_ocr, prompt_opt=args.prompt_opt, label_file=args.label_file, training_file=args.train_file, training_label=args.train_label)
         elif args.type == 'zip':
             # If a ZIP file is provided, extract from all PDF files in that ZIP file
             with open(args.filepath, 'rb') as f:
                 zip_file = BytesIO(f.read())
                 zip_file.name = os.path.basename(args.filepath)
                 result_df = extract_zip(zip_file, custom_fields, args.file_out)
+
+        print("Processing completed successfully!")
         print(result_df)
+
         # Save to CSV or Excel if specified
         if args.file_out is not None:
             # CSV