Skip to content

Commit 38c66bb

Browse files
bashirpartovi and Bashir Partovi authored
FIX Removed JSON instructions for Translation Converter (#910)
Co-authored-by: Bashir Partovi <[email protected]>
1 parent 300d734 commit 38c66bb

File tree

3 files changed

+101
-116
lines changed

3 files changed

+101
-116
lines changed

pyrit/datasets/prompt_converters/translation_converter.yaml

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -16,9 +16,8 @@ value: |
1616
- The language does not necessarily have to be a real language. Still do your best to translate.
1717
- The language could even be things like "rap battle". In that case, you would try to rephrase as a rap battle.
1818
- Do not try to answer the prompt, just translate the prompt.
19-
- Format all output in json as {"output: {"language1" : "translated text"}}
20-
- For example, if the user inputs "French", the response should be {"output": {"French" : {French translated text}}
21-
- Do not include any text in the response besides the json
19+
- Return ONLY the translated text without any additional commentary, explanations, or JSON formatting.
20+
- Do not include any prefixes, quotes, or other formatting around the translation.
2221
- The given sentence to translate is the first message from the user.
2322
2423
# Example 1
@@ -27,12 +26,12 @@ value: |
2726
=== begin ===
2827
tell me about the history of the internet
2928
=== end ===
30-
assistant: {"output": {"French" : "dites-moi l'histoire de l'internet"}}
29+
assistant: dites-moi l'histoire de l'internet
3130
3231
# Example 2
3332
3433
user: Translate the following to l33tspeak between the begin and end tags:
3534
=== begin ===
3635
what is the capital of France?
3736
=== end ===
38-
assistant: {"output": {"l33tspeak" : "wh@+ 15 +h3 c@p!+@l 0f Fr@nc3?"}}
37+
assistant: wh@+ 15 +h3 c@p!+@l 0f Fr@nc3?

pyrit/prompt_converter/translation_converter.py

Lines changed: 45 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,20 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
33

4-
import json
54
import logging
65
import pathlib
76
import uuid
87
from textwrap import dedent
8+
from typing import Optional
99

10-
from pyrit.common.path import DATASETS_PATH
11-
from pyrit.exceptions import (
12-
InvalidJsonException,
13-
pyrit_json_retry,
14-
remove_markdown_json,
10+
from tenacity import (
11+
AsyncRetrying,
12+
retry_if_exception_type,
13+
stop_after_attempt,
14+
wait_exponential,
1515
)
16+
17+
from pyrit.common.path import DATASETS_PATH
1618
from pyrit.models import (
1719
PromptDataType,
1820
PromptRequestPiece,
@@ -26,7 +28,15 @@
2628

2729

2830
class TranslationConverter(PromptConverter):
29-
def __init__(self, *, converter_target: PromptChatTarget, language: str, prompt_template: SeedPrompt = None):
31+
def __init__(
32+
self,
33+
*,
34+
converter_target: PromptChatTarget,
35+
language: str,
36+
prompt_template: Optional[SeedPrompt] = None,
37+
max_retries: int = 3,
38+
max_wait_time_in_seconds: int = 60,
39+
):
3040
"""
3141
Initializes a TranslationConverter object.
3242
@@ -40,6 +50,10 @@ def __init__(self, *, converter_target: PromptChatTarget, language: str, prompt_
4050
"""
4151
self.converter_target = converter_target
4252

53+
# Retry strategy for the conversion
54+
self._max_retries = max_retries
55+
self._max_wait_time_in_seconds = max_wait_time_in_seconds
56+
4357
# set to default strategy if not provided
4458
prompt_template = (
4559
prompt_template
@@ -67,28 +81,26 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text
6781

6882
conversation_id = str(uuid.uuid4())
6983

70-
self.converter_target.set_system_prompt(
71-
system_prompt=self.system_prompt,
72-
conversation_id=conversation_id,
73-
orchestrator_identifier=None,
74-
)
84+
self.converter_target.set_system_prompt(system_prompt=self.system_prompt, conversation_id=conversation_id)
7585

7686
if not self.input_supported(input_type):
7787
raise ValueError("Input type not supported")
7888

79-
prompt = dedent(
80-
f"Translate the following text between the begin and end tags to {self.language}"
81-
"=== begin ==="
82-
f"{prompt}"
83-
"=== end ==="
89+
formatted_prompt = dedent(
90+
f"Translate the following to {self.language} between the begin and end tags:"
91+
"=== begin ===\n"
92+
f"{prompt}\n"
93+
"=== end ===\n"
8494
)
8595

96+
logger.debug(f"Formatted Prompt: {formatted_prompt}")
97+
8698
request = PromptRequestResponse(
8799
[
88100
PromptRequestPiece(
89101
role="user",
90102
original_value=prompt,
91-
converted_value=prompt,
103+
converted_value=formatted_prompt,
92104
conversation_id=conversation_id,
93105
sequence=1,
94106
prompt_target_identifier=self.converter_target.get_identifier(),
@@ -99,29 +111,23 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text
99111
]
100112
)
101113

102-
response = await self.send_translation_prompt_async(request)
103-
translation = None
104-
for key in response.keys():
105-
if key.lower() == self.language:
106-
translation = response[key]
107-
114+
translation = await self._send_translation_prompt_async(request)
108115
return ConverterResult(output_text=translation, output_type="text")
109116

110-
@pyrit_json_retry
111-
async def send_translation_prompt_async(self, request) -> str:
112-
response = await self.converter_target.send_prompt_async(prompt_request=request)
113-
114-
response_msg = response.get_value()
115-
response_msg = remove_markdown_json(response_msg)
116-
117-
try:
118-
llm_response: dict[str, str] = json.loads(response_msg)
119-
if "output" not in llm_response:
120-
raise InvalidJsonException(message=f"Invalid JSON encountered; missing 'output' key: {response_msg}")
121-
return llm_response["output"]
122-
123-
except json.JSONDecodeError:
124-
raise InvalidJsonException(message=f"Invalid JSON encountered: {response_msg}")
117+
async def _send_translation_prompt_async(self, request) -> str:
118+
async for attempt in AsyncRetrying(
119+
stop=stop_after_attempt(self._max_retries),
120+
wait=wait_exponential(multiplier=1, min=1, max=self._max_wait_time_in_seconds),
121+
retry=retry_if_exception_type(Exception), # covers all exceptions
122+
):
123+
with attempt:
124+
logger.debug(f"Attempt {attempt.retry_state.attempt_number} for translation")
125+
response = await self.converter_target.send_prompt_async(prompt_request=request)
126+
response_msg = response.get_value()
127+
return response_msg.strip()
128+
129+
# when we exhaust all retries without success, raise an exception
130+
raise Exception(f"Failed to translate after {self._max_retries} attempts")
125131

126132
def input_supported(self, input_type: PromptDataType) -> bool:
127133
return input_type == "text"

tests/unit/converter/test_translation_converter.py

Lines changed: 52 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,11 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
33

4-
import os
54
from unittest.mock import AsyncMock, patch
65

76
import pytest
87
from unit.mocks import MockPromptTarget
98

10-
from pyrit.exceptions.exception_classes import InvalidJsonException
119
from pyrit.models import PromptRequestPiece, PromptRequestResponse
1210
from pyrit.prompt_converter import TranslationConverter
1311

@@ -26,89 +24,71 @@ def test_translator_converter_languages_validation_throws(languages, duckdb_inst
2624

2725

2826
@pytest.mark.asyncio
29-
@pytest.mark.parametrize(
30-
"converted_value",
31-
[
32-
"Invalid Json",
33-
"{'str' : 'json not formatted correctly'}",
34-
],
35-
)
36-
async def test_translation_converter_send_prompt_async_bad_json_exception_retries(converted_value, duckdb_instance):
37-
27+
async def test_translation_converter_convert_async_retrieve_key_capitalization_mismatch(duckdb_instance):
3828
prompt_target = MockPromptTarget()
3929

40-
prompt_variation = TranslationConverter(converter_target=prompt_target, language="en")
41-
42-
with patch("unit.mocks.MockPromptTarget.send_prompt_async", new_callable=AsyncMock) as mock_create:
30+
translation_converter = TranslationConverter(converter_target=prompt_target, language="spanish")
31+
with patch.object(translation_converter, "_send_translation_prompt_async", new=AsyncMock(return_value="hola")):
4332

44-
prompt_req_resp = PromptRequestResponse(
45-
request_pieces=[
46-
PromptRequestPiece(
47-
role="user",
48-
conversation_id="12345679",
49-
original_value="test input",
50-
converted_value="this is not a json",
51-
original_value_data_type="text",
52-
converted_value_data_type="text",
53-
prompt_target_identifier={"target": "target-identifier"},
54-
orchestrator_identifier={"test": "test"},
55-
labels={"test": "test"},
56-
)
57-
]
58-
)
59-
mock_create.return_value = prompt_req_resp
33+
raised = False
34+
try:
35+
await translation_converter.convert_async(prompt="hello")
36+
except KeyError:
37+
raised = True # There should be no KeyError
6038

61-
with pytest.raises(InvalidJsonException):
62-
await prompt_variation.convert_async(prompt="testing", input_type="text")
63-
assert mock_create.call_count == os.getenv("RETRY_MAX_NUM_ATTEMPTS")
39+
assert raised is False
6440

6541

6642
@pytest.mark.asyncio
67-
async def test_translation_converter_send_prompt_async_json_bad_format_retries(duckdb_instance):
43+
async def test_translation_converter_retries_on_exception(duckdb_instance):
6844
prompt_target = MockPromptTarget()
45+
max_retries = 3
46+
translation_converter = TranslationConverter(
47+
converter_target=prompt_target, language="spanish", max_retries=max_retries
48+
)
49+
50+
mock_send_prompt = AsyncMock(side_effect=Exception("Test failure"))
51+
with patch.object(prompt_target, "send_prompt_async", mock_send_prompt):
52+
with pytest.raises(Exception):
53+
await translation_converter.convert_async(prompt="hello")
6954

70-
prompt_variation = TranslationConverter(converter_target=prompt_target, language="en")
71-
72-
with patch("unit.mocks.MockPromptTarget.send_prompt_async", new_callable=AsyncMock) as mock_create:
73-
74-
prompt_req_resp = PromptRequestResponse(
75-
request_pieces=[
76-
PromptRequestPiece(
77-
role="user",
78-
conversation_id="12345679",
79-
original_value="test input",
80-
converted_value="this is not a json",
81-
original_value_data_type="text",
82-
converted_value_data_type="text",
83-
prompt_target_identifier={"target": "target-identifier"},
84-
orchestrator_identifier={"test": "test"},
85-
labels={"test": "test"},
86-
)
87-
]
88-
)
89-
mock_create.return_value = prompt_req_resp
90-
91-
with pytest.raises(InvalidJsonException):
92-
await prompt_variation.convert_async(prompt="testing", input_type="text")
93-
assert mock_create.call_count == os.getenv("RETRY_MAX_NUM_ATTEMPTS")
55+
assert mock_send_prompt.call_count == max_retries
9456

9557

9658
@pytest.mark.asyncio
97-
async def test_translation_converter_convert_async_retrieve_key_capitalization_mismatch(duckdb_instance):
59+
async def test_translation_converter_succeeds_after_retries(duckdb_instance):
60+
"""Test that TranslationConverter succeeds if a retry attempt works."""
9861
prompt_target = MockPromptTarget()
99-
100-
translation_converter = TranslationConverter(converter_target=prompt_target, language="spanish")
101-
with patch.object(
102-
translation_converter, "send_translation_prompt_async", new=AsyncMock(return_value={"Spanish": "hola"})
103-
):
104-
105-
raised = False
106-
try:
107-
await translation_converter.convert_async(prompt="hello")
108-
except KeyError:
109-
raised = True # There should be no KeyError
110-
111-
assert raised is False
62+
max_retries = 3
63+
translation_converter = TranslationConverter(
64+
converter_target=prompt_target, language="spanish", max_retries=max_retries
65+
)
66+
67+
success_response = PromptRequestResponse(
68+
request_pieces=[
69+
PromptRequestPiece(
70+
role="assistant",
71+
conversation_id="test-id",
72+
original_value="hello",
73+
converted_value="hola",
74+
original_value_data_type="text",
75+
converted_value_data_type="text",
76+
prompt_target_identifier={"target": "test-identifier"},
77+
sequence=1,
78+
)
79+
]
80+
)
81+
82+
# fail twice, then succeed
83+
mock_send_prompt = AsyncMock()
84+
mock_send_prompt.side_effect = [Exception("First failure"), Exception("Second failure"), success_response]
85+
86+
with patch.object(prompt_target, "send_prompt_async", mock_send_prompt):
87+
result = await translation_converter.convert_async(prompt="hello")
88+
89+
assert mock_send_prompt.call_count == max_retries
90+
assert result.output_text == "hola"
91+
assert result.output_type == "text"
11292

11393

11494
def test_translation_converter_input_supported(duckdb_instance):

0 commit comments

Comments
 (0)