diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index 3dd12182b..709e8c955 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -28,6 +28,33 @@ class OpenAIChatTarget(OpenAIChatTargetBase): This class facilitates multimodal (image and text) input and text output generation This works with GPT3.5, GPT4, GPT4o, GPT-V, and other compatible models + + Args: + api_key (str): The api key for the OpenAI API + endpoint (str): The endpoint for the OpenAI API + model_name (str): The model name for the OpenAI API + deployment_name (str): For Azure, the deployment name + api_version (str): The api version for the OpenAI API + temperature (float): The temperature for the completion + max_completion_tokens (int): The maximum number of tokens to be returned by the model. + The total length of input tokens and generated tokens is limited by + the model's context length. + max_tokens (int): Deprecated. Use max_completion_tokens instead + top_p (float): The nucleus sampling probability. + frequency_penalty (float): Number between -2.0 and 2.0. Positive values + penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. + presence_penalty (float): Number between -2.0 and 2.0. Positive values + penalize new tokens based on whether they appear in the text so far, + increasing the model's likelihood to talk about new topics. + seed (int): This feature is in Beta. If specified, our system will make a best effort to sample + deterministically, such that repeated requests with the same seed + and parameters should return the same result. + n (int): How many chat completion choices to generate for each input message. + Note that you will be charged based on the number of generated tokens across all + of the choices. Keep n as 1 to minimize costs. + extra_body_parameters (dict): Additional parameters to send in the request body + """ def __init__( @@ -107,6 +134,9 @@ def __init__( if max_completion_tokens and max_tokens: raise ValueError("Cannot provide both max_tokens and max_completion_tokens.") + # Validate endpoint URL + self._warn_if_irregular_endpoint(self.CHAT_URL_REGEX) + self._max_completion_tokens = max_completion_tokens self._max_tokens = max_tokens self._frequency_penalty = frequency_penalty diff --git a/pyrit/prompt_target/openai/openai_dall_e_target.py b/pyrit/prompt_target/openai/openai_dall_e_target.py index 13f8ed027..31a8aa9f3 100644 --- a/pyrit/prompt_target/openai/openai_dall_e_target.py +++ b/pyrit/prompt_target/openai/openai_dall_e_target.py @@ -96,6 +96,9 @@ def __init__( super().__init__(*args, **kwargs) + # Validate endpoint URL + self._warn_if_irregular_endpoint(self.DALLE_URL_REGEX) + def _set_openai_env_configuration_vars(self): self.model_name_environment_variable = "OPENAI_DALLE_MODEL" self.endpoint_environment_variable = "OPENAI_DALLE_ENDPOINT" diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 5cef257f9..72ff395d1 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -121,6 +121,9 @@ def __init__( super().__init__(api_version=api_version, temperature=temperature, top_p=top_p, **kwargs) self._max_output_tokens = max_output_tokens + # Validate endpoint URL for OpenAI Response API + self._warn_if_irregular_endpoint(self.RESPONSE_URL_REGEX) + # Reasoning parameters are not yet supported by PyRIT. # See https://platform.openai.com/docs/api-reference/responses/create#responses-create-reasoning # for more information. diff --git a/pyrit/prompt_target/openai/openai_sora_target.py b/pyrit/prompt_target/openai/openai_sora_target.py index 68df213a8..ffd3cbea6 100644 --- a/pyrit/prompt_target/openai/openai_sora_target.py +++ b/pyrit/prompt_target/openai/openai_sora_target.py @@ -158,6 +158,9 @@ def __init__( # Detect API version self._detected_api_version = self._detect_api_version() + # Validate endpoint URL + self._warn_if_irregular_endpoint(self.SORA_URL_REGEX) + # Set instance variables self._n_seconds = n_seconds self._validate_duration() diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index 02866130c..e71b5aa4c 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -3,13 +3,13 @@ import json import logging +import re from abc import abstractmethod from typing import Optional +from urllib.parse import urlparse -from pyrit.auth.azure_auth import ( - AzureAuth, - get_default_scope, -) +from pyrit.auth import AzureAuth +from pyrit.auth.azure_auth import get_default_scope from pyrit.common import default_values from pyrit.prompt_target import PromptChatTarget @@ -20,6 +20,21 @@ class OpenAITarget(PromptChatTarget): ADDITIONAL_REQUEST_HEADERS: str = "OPENAI_ADDITIONAL_REQUEST_HEADERS" + # Expected URL regex patterns for different OpenAI and AOAI targets + CHAT_URL_REGEX = [ + r"/v1/chat/completions$", # Standard OpenAI & Anthropic endpoints + r"/openai/deployments/[^/]+/chat/completions$", # AOAI pattern + r"/openai/chat/completions$", # Gemini endpoint + ] + SORA_URL_REGEX = [ + r"/videos/v1/video/generations$", # Azure sora1 endpoint + r"/videos/v1/videos$", # Azure sora2 endpoint + r"/v1/videos$", # oai sora2 endpoint + ] + DALLE_URL_REGEX = [r"/images/generations$"] + TTS_URL_REGEX = [r"/audio/speech$"] + RESPONSE_URL_REGEX = [r"/openai/responses$", r"v1/responses$"] + model_name_environment_variable: str endpoint_environment_variable: str api_key_environment_variable: str @@ -126,6 +141,44 @@ def _set_openai_env_configuration_vars(self) -> None: """ raise NotImplementedError + def _warn_if_irregular_endpoint(self, expected_url_regex) -> None: + """ + Validate that the endpoint URL ends with one of the expected routes for this OpenAI target. + + Args: + expected_url_regex: Expected regex pattern(s) for this target. Should be a list of regex strings. + + Prints a warning if the endpoint doesn't match any of the expected routes. + This validation helps ensure the endpoint is configured correctly for the specific API. + """ + if not self._endpoint or not expected_url_regex: + return + + # Use urllib to extract the path part and normalize it + parsed_url = urlparse(self._endpoint) + normalized_route = parsed_url.path.lower().rstrip("/") + + # Check if the endpoint matches any of the expected regex patterns + for regex_pattern in expected_url_regex: + if re.search(regex_pattern, normalized_route): + return + + # No matches found, log warning + if len(expected_url_regex) == 1: + # Convert regex back to human-readable format for the warning + pattern_str = expected_url_regex[0].replace(r"[^/]+", "*").replace("$", "") + expected_routes_str = pattern_str + else: + # Convert all regex patterns to human-readable format + readable_patterns = [p.replace(r"[^/]+", "*").replace("$", "") for p in expected_url_regex] + expected_routes_str = f"one of: {', '.join(readable_patterns)}" + + logger.warning( + f"The provided endpoint URL {parsed_url} does not match any of the expected formats: {expected_routes_str}." + f"This may be intentional, especially if you are using an endpoint other than Azure or OpenAI." + f"For more details and guidance, please see the .env_example file in the repository." + ) + @abstractmethod def is_json_response_supported(self) -> bool: """ diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index a46a90d34..f19911d7e 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -71,6 +71,9 @@ def __init__( if not self._model_name: self._model_name = "tts-1" + # Validate endpoint URL + self._warn_if_irregular_endpoint(self.TTS_URL_REGEX) + self._voice = voice self._response_format = response_format self._language = language diff --git a/tests/unit/target/test_openai_chat_target.py b/tests/unit/target/test_openai_chat_target.py index 02343ec4c..8f73e3dbd 100644 --- a/tests/unit/target/test_openai_chat_target.py +++ b/tests/unit/target/test_openai_chat_target.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. import json +import logging import os from tempfile import NamedTemporaryFile from typing import MutableSequence @@ -806,3 +807,60 @@ def test_set_auth_headers_with_api_key(patch_central_database): assert target._api_key == "test_api_key_456" assert target._headers["Api-Key"] == "test_api_key_456" assert target._headers["Authorization"] == "Bearer test_api_key_456" + + +def test_url_validation_warning_for_incorrect_endpoint(caplog, patch_central_database): + """Test that URL validation warns for incorrect endpoints.""" + with patch.dict(os.environ, {}, clear=True): + with caplog.at_level(logging.WARNING): + target = OpenAIChatTarget( + model_name="gpt-4", + endpoint="https://api.openai.com/v1/wrong/path", # Incorrect endpoint + api_key="test-key", + api_version="2024-10-21", + ) + + # Should have a warning about incorrect endpoint + warning_logs = [record for record in caplog.records if record.levelno >= logging.WARNING] + assert len(warning_logs) >= 1 + endpoint_warnings = [log for log in warning_logs if "The provided endpoint URL" in log.message] + assert len(endpoint_warnings) == 1 + assert "/v1/chat/completions" in endpoint_warnings[0].message + assert "/openai/deployments/*/chat/completions" in endpoint_warnings[0].message + assert target + + +def test_url_validation_no_warning_for_correct_azure_endpoint(caplog, patch_central_database): + """Test that URL validation doesn't warn for correct Azure endpoints.""" + with patch.dict(os.environ, {}, clear=True): + with caplog.at_level(logging.WARNING): + target = OpenAIChatTarget( + model_name="gpt-4", + endpoint="https://myservice.openai.azure.com/openai/deployments/gpt-4/chat/completions", + api_key="test-key", + api_version="2024-10-21", + ) + + # Should not have URL validation warnings + warning_logs = [record for record in caplog.records if record.levelno >= logging.WARNING] + endpoint_warnings = [log for log in warning_logs if "The provided endpoint URL" in log.message] + assert len(endpoint_warnings) == 0 + assert target + + +def test_url_validation_no_warning_for_correct_openai_endpoint(caplog, patch_central_database): + """Test that URL validation doesn't warn for correct OpenAI endpoints.""" + with patch.dict(os.environ, {}, clear=True): + with caplog.at_level(logging.WARNING): + target = OpenAIChatTarget( + model_name="gpt-4", + endpoint="https://api.openai.com/v1/chat/completions", + api_key="test-key", + api_version="2024-10-21", + ) + + # Should not have URL validation warnings + warning_logs = [record for record in caplog.records if record.levelno >= logging.WARNING] + endpoint_warnings = [log for log in warning_logs if "The provided endpoint URL" in log.message] + assert len(endpoint_warnings) == 0 + assert target diff --git a/tests/unit/target/test_tts_target.py b/tests/unit/target/test_tts_target.py index a89d3377b..28676b1aa 100644 --- a/tests/unit/target/test_tts_target.py +++ b/tests/unit/target/test_tts_target.py @@ -36,6 +36,8 @@ def test_tts_initializes(tts_target: OpenAITTSTarget): def test_tts_initializes_calls_get_required_parameters(patch_central_database): with patch("pyrit.common.default_values.get_required_value") as mock_get_required: + mock_get_required.side_effect = lambda env_var_name, passed_value: passed_value + target = OpenAITTSTarget( model_name="deploymenttest", endpoint="endpointtest",