diff --git a/scripts/response_rewrite.py b/scripts/response_rewrite.py index 31945cb1..6c19f568 100644 --- a/scripts/response_rewrite.py +++ b/scripts/response_rewrite.py @@ -507,7 +507,7 @@ def main(): variants_dataset, ["fcs", "fcs_plus1", "fcs_reflection"], llm ) - system_prompt = ModelConfig.from_model_id(args.target_model).system_prompt + system_prompt = str(ModelConfig.from_model_id(args.target_model).system_prompt) # Generate conversation format for each variant, which can be used in SimPO/DPO/etc. fcs_convo = make_preference_conversations(final_dataset, "fcs", system_prompt) diff --git a/skythought/evals/inference_and_check.py b/skythought/evals/inference_and_check.py index 9f8d47b2..7a534565 100644 --- a/skythought/evals/inference_and_check.py +++ b/skythought/evals/inference_and_check.py @@ -237,7 +237,7 @@ def generate_responses_for_dataset( # Prepare conversations conversations = handler.make_conversations( remaining_data, - model_config.system_prompt, + str(model_config.system_prompt), model_config.user_template, model_config.assistant_prefill, ) diff --git a/skythought/evals/models/base.py b/skythought/evals/models/base.py index 3f85cf58..425b215f 100644 --- a/skythought/evals/models/base.py +++ b/skythought/evals/models/base.py @@ -28,6 +28,9 @@ def validate_and_extract_string(self): def string(self): return self._string + def __str__(self) -> str: + return self._string + def read_yaml(path: str): with open(path, "r") as f: