Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion eq-bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def main():
help="Set the language of the question dataset. Currently supported: en, de")
parser.add_argument('-r', type=int, default=5,
help="Set the number of retries to attempt if a benchmark run fails. Default 5.")
parser.add_argument('--longoutput', action='store_true',
help="Remove token limit restrictions on outputs to handle very long responses.")

args = parser.parse_args()
resume = not args.w

Expand Down Expand Up @@ -110,6 +113,11 @@ def main():
base_filename, extension = questions_fn.rsplit('.', 1)
# Append the language identifier to the questions filename
questions_fn = f"{base_filename}_{language}.{extension}"

if args.longoutput:
COMPLETION_TOKENS = 4096 # Allow for very long outputs
else:
COMPLETION_TOKENS = 600 if REVISE else 60

# Creative writing Judge params
judge_params = {
Expand Down Expand Up @@ -242,7 +250,7 @@ def main():
ooba_params_global=ooba_params_global, fast_download=args.f,
hf_access_token=hf_access_token, ooba_request_timeout=ooba_request_timeout,
questions_fn=questions_fn, openai_client=openai_client, language=language,
REVISE=REVISE, benchmark_types=args.benchmarks, judge_params = judge_params)
REVISE=REVISE, benchmark_types=args.benchmarks, judge_params = judge_params, completion_tokens=COMPLETION_TOKENS)
except KeyboardInterrupt:
if ooba_instance:
ooba_instance.stop()
Expand Down
12 changes: 3 additions & 9 deletions lib/eq_bench_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
def process_question(question_id, q, model_path, prompt_type, model, tokenizer, results, run_index,
run_iter, verbose, n_question_attempts, inference_engine, ooba_instance,
launch_ooba, ooba_request_timeout, openai_client, eqbench_version, language,
REVISE):
REVISE, completion_tokens):
"""
Process a single question and update the results.
:param question_id: ID of the question.
Expand All @@ -24,6 +24,7 @@ def process_question(question_id, q, model_path, prompt_type, model, tokenizer,
:param verbose: Verbose output flag.
:param n_question_attempts: Number of attempts per question.
:param language: language of the test questions ("en" default, "de" also supported)
:param completion_tokens: Maximum number of tokens for model output.
:return: Updated results.
"""

Expand All @@ -34,16 +35,9 @@ def process_question(question_id, q, model_path, prompt_type, model, tokenizer,
else:
ref_fullscale = None

COMPLETION_TOKENS = 60
if REVISE:
COMPLETION_TOKENS = 600

if eqbench_version == 'v2' and not REVISE:
prompt = remove_revision_instructions(prompt, language)




tries = 0
success = False
temp = 0.01 # Low temp is important for consistency of results
Expand All @@ -52,7 +46,7 @@ def process_question(question_id, q, model_path, prompt_type, model, tokenizer,
prev_result_inference = None
prev_result_parsed_answers = None
while tries < n_question_attempts and not success:
inference = run_query(model_path, prompt_type, prompt, [], COMPLETION_TOKENS, model, tokenizer, temp, inference_engine, ooba_instance, launch_ooba, ooba_request_timeout, openai_client)
inference = run_query(model_path, prompt_type, prompt, [], completion_tokens, model, tokenizer, temp, inference_engine, ooba_instance, launch_ooba, ooba_request_timeout, openai_client)

try:
if verbose:
Expand Down
24 changes: 14 additions & 10 deletions lib/run_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def load_model_and_launch_ooba(model_path, lora_path, quantization, inference_en
raise Exception("Ooba failed to launch.")
return model, tokenizer, ooba_instance

def process_questions(benchmark_type, model, ooba_instance, inference_engine, results, model_path, prompt_type, tokenizer, launch_ooba, ooba_request_timeout, run_index, run_iter, verbose, n_attempts, openai_client, questions, eqbench_version, language, REVISE, judge_params, test_model_outputs, process_fn):
def process_questions(benchmark_type, model, ooba_instance, inference_engine, results, model_path, prompt_type, tokenizer, launch_ooba, ooba_request_timeout, run_index, run_iter, verbose, n_attempts, openai_client, questions, eqbench_version, language, REVISE, judge_params, test_model_outputs, process_fn, completion_tokens):
if benchmark_type == 'judgemark':
for model_name, model_outputs in test_model_outputs.items():
print('########################')
Expand All @@ -174,7 +174,7 @@ def process_questions(benchmark_type, model, ooba_instance, inference_engine, re
scores = process_fn(prompt_id, prompt_data, None, None, None, None, results, run_index,
run_iter, verbose, 0, inference_engine, ooba_instance,
launch_ooba, ooba_request_timeout, openai_client, judge_params,
test_model_response, model_name)
test_model_response, model_name, completion_tokens)
model_scores.append(scores)
safe_dump(results, RAW_RESULTS_PATH)

Expand All @@ -187,11 +187,11 @@ def process_questions(benchmark_type, model, ooba_instance, inference_engine, re
if benchmark_type == 'eq-bench':
process_fn(question_id, q, model_path, prompt_type, model, tokenizer, results, run_index, run_iter, verbose,
n_attempts, inference_engine, ooba_instance, launch_ooba, ooba_request_timeout, openai_client, eqbench_version,
language, REVISE)
language, REVISE, completion_tokens)
elif benchmark_type == 'creative-writing':
scores = process_fn(question_id, q, model_path, prompt_type, model, tokenizer, results, run_index,
run_iter, verbose, n_attempts, inference_engine, ooba_instance, launch_ooba,
ooba_request_timeout, openai_client, judge_params)
ooba_request_timeout, openai_client, judge_params, completion_tokens)
if scores:
if verbose:
print(scores)
Expand Down Expand Up @@ -330,10 +330,14 @@ def run_generic_benchmark(run_id, model_path, lora_path, prompt_type, quantizati
ooba_params_global, fast_download,
hf_access_token, ooba_request_timeout,
questions_fn=None, openai_client=None, language='en',
REVISE=False, benchmark_type='eq-bench', judge_params={}):
REVISE=False, benchmark_type='eq-bench', judge_params={}, completion_tokens=None):

questions, process_fn, scoring_fn, save_result_to_db_fn, run_index, eqbench_version, test_model_outputs = setup_benchmark(benchmark_type, run_id, model_path, lora_path, prompt_type, quantization, inference_engine, ooba_params, include_patterns, exclude_patterns, language, judge_params, questions_fn)

if completion_tokens is None:
if benchmark_type == 'eq-bench':
completion_tokens = 600 if (REVISE or eqbench_version == 'v1') else 60

results = initialize_results(run_index, benchmark_type, resume, n_iterations, run_id, model_path, lora_path, prompt_type, quantization, inference_engine, ooba_params, include_patterns, exclude_patterns, judge_params, language, eqbench_version)

initialize_iterations(results, run_index, n_iterations, benchmark_type, resume)
Expand Down Expand Up @@ -364,7 +368,7 @@ def run_generic_benchmark(run_id, model_path, lora_path, prompt_type, quantizati
run_index, run_iter,
verbose)

process_questions(benchmark_type, model, ooba_instance, inference_engine, results, model_path, prompt_type, tokenizer, launch_ooba, ooba_request_timeout, run_index, run_iter, verbose, n_attempts, openai_client, questions, eqbench_version, language, REVISE, judge_params, test_model_outputs, process_fn)
process_questions(benchmark_type, model, ooba_instance, inference_engine, results, model_path, prompt_type, tokenizer, launch_ooba, ooba_request_timeout, run_index, run_iter, verbose, n_attempts, openai_client, questions, eqbench_version, language, REVISE, judge_params, test_model_outputs, process_fn, completion_tokens)

if benchmark_type == 'judgemark':
compute_judgemark_results(results, run_index, test_model_outputs, verbose)
Expand Down Expand Up @@ -460,7 +464,7 @@ def run_benchmark(run_id, model_path, lora_path, prompt_type, quantization,
ooba_params_global='', fast_download=False,
hf_access_token=None, ooba_request_timeout=300,
questions_fn=None, openai_client=None, language='en',
REVISE=False, benchmark_types=[], judge_params={}):
REVISE=False, benchmark_types=[], judge_params={}, completion_tokens=None):

for benchmark_type in benchmark_types:
if benchmark_type == 'eq-bench':
Expand All @@ -476,7 +480,7 @@ def run_benchmark(run_id, model_path, lora_path, prompt_type, quantization,
ooba_params_global, fast_download,
hf_access_token, ooba_request_timeout,
questions_fn, openai_client, language,
REVISE, benchmark_type)
REVISE, benchmark_type, completion_tokens=completion_tokens)

elif benchmark_type == 'creative-writing':
run_generic_benchmark(run_id, model_path, lora_path, prompt_type, quantization,
Expand All @@ -491,7 +495,7 @@ def run_benchmark(run_id, model_path, lora_path, prompt_type, quantization,
ooba_params_global, fast_download,
hf_access_token, ooba_request_timeout,
openai_client=openai_client, judge_params=judge_params,
benchmark_type=benchmark_type)
benchmark_type=benchmark_type, completion_tokens=completion_tokens)

elif benchmark_type == 'judgemark':
run_generic_benchmark(run_id, None, None, None, None,
Expand All @@ -506,7 +510,7 @@ def run_benchmark(run_id, model_path, lora_path, prompt_type, quantization,
ooba_params_global, fast_download,
hf_access_token, ooba_request_timeout,
openai_client=openai_client, judge_params=judge_params,
benchmark_type=benchmark_type)
benchmark_type=benchmark_type, completion_tokens=completion_tokens)



Expand Down
10 changes: 9 additions & 1 deletion lib/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@
import math
from lib.util import safe_dump

def remove_think_blocks(text):
    """Remove all content between <think> and </think> tags.

    Reasoning models may emit chain-of-thought wrapped in <think>...</think>
    blocks; this strips those blocks (tags included) so the downstream answer
    parsers only see the final response text. If a <think> tag is never
    closed -- e.g. the completion was truncated mid-thought by the token
    limit -- everything from that tag to the end of the string is removed as
    well, so stray reasoning text cannot be mis-parsed as answers.

    :param text: Raw inference output, possibly containing think blocks.
    :return: The text with all think blocks removed.
    """
    # DOTALL lets '.' span newlines; non-greedy '.*?' removes each closed
    # block individually rather than everything between the first <think>
    # and the last </think>.
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    # Also drop an unterminated trailing block (truncated output).
    text = re.sub(r'<think>.*\Z', '', text, flags=re.DOTALL)
    return text

# Parse the emotion intensity ratings from the raw inference text
def parse_answers(text, REVISE):
# First remove any think blocks
text = remove_think_blocks(text)

first_pass_answers = {}
revised_answers = {}

Expand All @@ -30,7 +37,8 @@ def parse_answers(text, REVISE):

# we parse answers in German language ("de")
def parse_answers_de(text, REVISE):
#print("Using german parsing.")
# First remove any think blocks
text = remove_think_blocks(text)
first_pass_answers = {}
revised_answers = {}

Expand Down