11# Copyright (c) Microsoft Corporation.
22# Licensed under the MIT license.
33
4- import os
54from unittest .mock import AsyncMock , patch
65
76import pytest
87from unit .mocks import MockPromptTarget
98
10- from pyrit .exceptions .exception_classes import InvalidJsonException
119from pyrit .models import PromptRequestPiece , PromptRequestResponse
1210from pyrit .prompt_converter import TranslationConverter
1311
@@ -26,89 +24,71 @@ def test_translator_converter_languages_validation_throws(languages, duckdb_inst
2624
2725
2826@pytest .mark .asyncio
29- @pytest .mark .parametrize (
30- "converted_value" ,
31- [
32- "Invalid Json" ,
33- "{'str' : 'json not formatted correctly'}" ,
34- ],
35- )
36- async def test_translation_converter_send_prompt_async_bad_json_exception_retries (converted_value , duckdb_instance ):
37-
27+ async def test_translation_converter_convert_async_retrieve_key_capitalization_mismatch (duckdb_instance ):
3828 prompt_target = MockPromptTarget ()
3929
40- prompt_variation = TranslationConverter (converter_target = prompt_target , language = "en" )
41-
42- with patch ("unit.mocks.MockPromptTarget.send_prompt_async" , new_callable = AsyncMock ) as mock_create :
30+ translation_converter = TranslationConverter (converter_target = prompt_target , language = "spanish" )
31+ with patch .object (translation_converter , "_send_translation_prompt_async" , new = AsyncMock (return_value = "hola" )):
4332
44- prompt_req_resp = PromptRequestResponse (
45- request_pieces = [
46- PromptRequestPiece (
47- role = "user" ,
48- conversation_id = "12345679" ,
49- original_value = "test input" ,
50- converted_value = "this is not a json" ,
51- original_value_data_type = "text" ,
52- converted_value_data_type = "text" ,
53- prompt_target_identifier = {"target" : "target-identifier" },
54- orchestrator_identifier = {"test" : "test" },
55- labels = {"test" : "test" },
56- )
57- ]
58- )
59- mock_create .return_value = prompt_req_resp
33+ raised = False
34+ try :
35+ await translation_converter .convert_async (prompt = "hello" )
36+ except KeyError :
37+ raised = True # There should be no KeyError
6038
61- with pytest .raises (InvalidJsonException ):
62- await prompt_variation .convert_async (prompt = "testing" , input_type = "text" )
63- assert mock_create .call_count == os .getenv ("RETRY_MAX_NUM_ATTEMPTS" )
39+ assert raised is False
6440
6541
6642@pytest .mark .asyncio
67- async def test_translation_converter_send_prompt_async_json_bad_format_retries (duckdb_instance ):
43+ async def test_translation_converter_retries_on_exception (duckdb_instance ):
6844 prompt_target = MockPromptTarget ()
45+ max_retries = 3
46+ translation_converter = TranslationConverter (
47+ converter_target = prompt_target , language = "spanish" , max_retries = max_retries
48+ )
49+
50+ mock_send_prompt = AsyncMock (side_effect = Exception ("Test failure" ))
51+ with patch .object (prompt_target , "send_prompt_async" , mock_send_prompt ):
52+ with pytest .raises (Exception ):
53+ await translation_converter .convert_async (prompt = "hello" )
6954
70- prompt_variation = TranslationConverter (converter_target = prompt_target , language = "en" )
71-
72- with patch ("unit.mocks.MockPromptTarget.send_prompt_async" , new_callable = AsyncMock ) as mock_create :
73-
74- prompt_req_resp = PromptRequestResponse (
75- request_pieces = [
76- PromptRequestPiece (
77- role = "user" ,
78- conversation_id = "12345679" ,
79- original_value = "test input" ,
80- converted_value = "this is not a json" ,
81- original_value_data_type = "text" ,
82- converted_value_data_type = "text" ,
83- prompt_target_identifier = {"target" : "target-identifier" },
84- orchestrator_identifier = {"test" : "test" },
85- labels = {"test" : "test" },
86- )
87- ]
88- )
89- mock_create .return_value = prompt_req_resp
90-
91- with pytest .raises (InvalidJsonException ):
92- await prompt_variation .convert_async (prompt = "testing" , input_type = "text" )
93- assert mock_create .call_count == os .getenv ("RETRY_MAX_NUM_ATTEMPTS" )
55+ assert mock_send_prompt .call_count == max_retries
9456
9557
9658@pytest .mark .asyncio
97- async def test_translation_converter_convert_async_retrieve_key_capitalization_mismatch (duckdb_instance ):
59+ async def test_translation_converter_succeeds_after_retries (duckdb_instance ):
60+ """Test that TranslationConverter succeeds if a retry attempt works."""
9861 prompt_target = MockPromptTarget ()
99-
100- translation_converter = TranslationConverter (converter_target = prompt_target , language = "spanish" )
101- with patch .object (
102- translation_converter , "send_translation_prompt_async" , new = AsyncMock (return_value = {"Spanish" : "hola" })
103- ):
104-
105- raised = False
106- try :
107- await translation_converter .convert_async (prompt = "hello" )
108- except KeyError :
109- raised = True # There should be no KeyError
110-
111- assert raised is False
62+ max_retries = 3
63+ translation_converter = TranslationConverter (
64+ converter_target = prompt_target , language = "spanish" , max_retries = max_retries
65+ )
66+
67+ success_response = PromptRequestResponse (
68+ request_pieces = [
69+ PromptRequestPiece (
70+ role = "assistant" ,
71+ conversation_id = "test-id" ,
72+ original_value = "hello" ,
73+ converted_value = "hola" ,
74+ original_value_data_type = "text" ,
75+ converted_value_data_type = "text" ,
76+ prompt_target_identifier = {"target" : "test-identifier" },
77+ sequence = 1 ,
78+ )
79+ ]
80+ )
81+
82+ # fail twice, then succeed
83+ mock_send_prompt = AsyncMock ()
84+ mock_send_prompt .side_effect = [Exception ("First failure" ), Exception ("Second failure" ), success_response ]
85+
86+ with patch .object (prompt_target , "send_prompt_async" , mock_send_prompt ):
87+ result = await translation_converter .convert_async (prompt = "hello" )
88+
89+ assert mock_send_prompt .call_count == max_retries
90+ assert result .output_text == "hola"
91+ assert result .output_type == "text"
11292
11393
11494def test_translation_converter_input_supported (duckdb_instance ):
0 commit comments