From f3bf2a18f7d28bfc142996ff1798497b036936b0 Mon Sep 17 00:00:00 2001 From: Robert Blackwell Date: Thu, 24 Jul 2025 15:45:30 +0100 Subject: [PATCH] Add -n parameter, number of generations --- golem.py | 14 +++++++++++++- openai.py | 4 ++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/golem.py b/golem.py index c52840c..fc35942 100755 --- a/golem.py +++ b/golem.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Time-stamp: <2025-07-21 15:39:41 rblackwell> +# Time-stamp: <2025-07-24 09:22:48 rblackwell> """Golem @@ -81,6 +81,7 @@ def ask( url = args.url key = args.key reasoning_effort = args.reasoning_effort + n = args.n if provider == "openai": return ask_openai( @@ -96,6 +97,7 @@ def ask( logprobs, top_logprobs, reasoning_effort, + n, ) if provider == "deepseek": @@ -123,6 +125,7 @@ def ask( logprobs, top_logprobs, reasoning_effort, + n, ) if provider == "xai": @@ -150,6 +153,7 @@ def ask( logprobs, top_logprobs, reasoning_effort, + n, ) if provider == "azure": @@ -209,6 +213,7 @@ def ask( logprobs, top_logprobs, reasoning_effort, + n, ) if top_logprobs is not None: @@ -334,6 +339,13 @@ def make_parser(): help="Skip n records in the JSONL. Useful for restarting after a crash.", ) + parser.add_argument( + "--n", + type=int, + default=None, + help="How many chat completion choices to generate for each input message.", + ) + parser.add_argument( "--repeat", type=str, diff --git a/openai.py b/openai.py index 1d2b70c..dbffcc7 100644 --- a/openai.py +++ b/openai.py @@ -20,6 +20,7 @@ def ask_openai( logprobs, top_logprobs, reasoning_effort, + n, ): """ Make a request to the OpenAI API. @@ -52,6 +53,9 @@ def ask_openai( if top_p is not None: json_data["top_p"] = top_p + if n is not None: + json_data["n"] = n + if max_tokens is not None: json_data["max_tokens"] = max_tokens