diff --git a/src/ether0/model_prompts.py b/src/ether0/model_prompts.py index 2e42287..d6fd9f9 100644 --- a/src/ether0/model_prompts.py +++ b/src/ether0/model_prompts.py @@ -86,9 +86,9 @@ def get_prompt(self) -> str: case ProblemPrompt.NONE: return "" case ProblemPrompt.THINK_ANSWER: - return XMLAnswerPrompts.REASONING_ANSWER + return XMLAnswerPrompts.REASONING_ANSWER.value case ProblemPrompt.ANSWER: - return XMLAnswerPrompts.ANSWER_ONLY + return XMLAnswerPrompts.ANSWER_ONLY.value case _: assert_never(self) diff --git a/tests/test_model_prompts.py b/tests/test_model_prompts.py index 55f6b3d..80b25a0 100644 --- a/tests/test_model_prompts.py +++ b/tests/test_model_prompts.py @@ -5,11 +5,29 @@ ANSWER_START, THINK_END, THINK_START, + ProblemPrompt, extract_answer_loose, extract_thought_answer_strict, ) +def test_problem_prompt() -> None: + none_prompt = ProblemPrompt.NONE.get_prompt() + assert isinstance(none_prompt, str) + assert "think" not in none_prompt + assert "answer" not in none_prompt + + answer_prompt = ProblemPrompt.ANSWER.get_prompt() + assert isinstance(answer_prompt, str) + assert "think" not in answer_prompt + assert "answer" in answer_prompt + + think_answer_prompt = ProblemPrompt.THINK_ANSWER.get_prompt() + assert isinstance(think_answer_prompt, str) + assert "think" in think_answer_prompt + assert "answer" in think_answer_prompt + + @pytest.mark.parametrize( ("content", "expected"), [