diff --git a/evaluation/async_evaluate.py b/evaluation/async_evaluate.py index ef4d088..396e4c4 100644 --- a/evaluation/async_evaluate.py +++ b/evaluation/async_evaluate.py @@ -53,7 +53,7 @@ def __init__( if not os.path.exists(self.model_config_path): raise FileNotFoundError(f"Model configuration for {model_name} not found in {model_configs_dir}") - + with open(self.model_config_path, "r") as file: self.model_config = json.load(file) @@ -84,11 +84,11 @@ def __init__( if not os.path.exists(haystack_path): raise FileNotFoundError(f"Haystack file not found at {haystack_path}") - + self.haystack_path = haystack_path self.haystack = BookHaystack(self.haystack_path) - + self.results_dir = results_dir os.makedirs(results_dir, exist_ok=True) @@ -103,14 +103,14 @@ def __init__( self.seed = seed self.prevent_duplicate = prevent_duplicate self.distractor = distractor - + self.log_placements = log_placements_dir != "" self.log_placements_dir = log_placements_dir self.test_name = test_name self.eval_name = f"{model_name}_book_{test_name}_{int(time.time())}" if test_name != "" else f"{model_name}_book_{int(time.time())}" - + def _evaluate_response(self, response: str, gold_answers = None) -> int: if gold_answers is None: gold_answers = self.gold_answers @@ -132,7 +132,7 @@ def get_hash(self, test_config: dict) -> str: del new_config["results"] del new_config["eval_name"] return hashlib.sha256(json.dumps(new_config, sort_keys=True).encode()).hexdigest() - + def evaluate(self) -> None: np.random.seed(self.seed) @@ -170,7 +170,7 @@ def evaluate(self) -> None: print(f"Results already exist at {results_path}") print("Skipping evaluation") return - + if self.prevent_duplicate: for result_filename in os.listdir(self.results_dir): with open(os.path.join(self.results_dir, result_filename), "r") as file: @@ -180,7 +180,7 @@ def evaluate(self) -> None: print(f"Duplicate test found with similar hash at {self.results_dir} -- TEST_HASH:", outputs["test_hash"]) print("Skipping evaluation") return - + async_tasks = [] for i in tqdm(np.linspace(self.document_depth_percent_min, self.document_depth_percent_max, self.document_depth_percent_intervals)): needle_depth = i / 100 @@ -200,7 +200,7 @@ def evaluate(self) -> None: retrieval_question = self.retrieval_question placement_output = self.haystack.generate_w_needle_placement( - needle=needle, + needle=needle, token_count_func=self.api_connector.token_count, encoding_func=self.api_connector.encode, decoding_func=self.api_connector.decode, @@ -216,7 +216,7 @@ def evaluate(self) -> None: async_tasks.append(self.api_connector.generate_response( system_prompt=self.system_prompt, user_prompt=filled_template, - max_tokens=self.model_config["max_tokens"], + max_tokens=self.model_config["max_tokens"], temperature=self.model_config["temperature"], top_p=self.model_config["top_p"] )) @@ -238,14 +238,14 @@ def evaluate(self) -> None: outputs["results"][i]["metric"] = self._evaluate_response(responses[i]["response"], gold_answers=[outputs["results"][i]["selected_character"]]) if "{CHAR}" in self.needle else self._evaluate_response(responses[i]["response"]) for k, v in responses[i].items(): outputs["results"][i][k] = v - + # Save results by model name, haystack type, timestamp with open(results_path, "w") as file: json.dump(outputs, file, indent=4) - + print(f"Results saved at {results_path}") - + if __name__ == "__main__": @@ -261,7 +261,6 @@ def evaluate(self) -> None: parser.add_argument("--system_prompt", type=str, help="System prompt for the model") parser.add_argument("--use_default_system_prompt", type=bool, default=False, help="Use default system prompt") parser.add_argument("--task_template", type=str, help="Task template for the model") - parser.add_argument("--system_prompt", type=str, help="System prompt for the model") parser.add_argument("--context_length", type=int, help="Context length for the needle placement") parser.add_argument("--document_depth_percent_min", type=int, default=0, help="Minimum document depth percentage") parser.add_argument("--document_depth_percent_max", type=int, default=100, help="Maximum document depth percentage")