|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT license. |
| 3 | + |
import logging
from typing import Optional

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion

from pyrit.common import default_values
from pyrit.exceptions import EmptyResponseException, PyritException, pyrit_target_retry
from pyrit.models import ChatMessageListDictContent
from pyrit.prompt_target.openai.openai_chat_target import OpenAIChatTarget
| 13 | + |
| 14 | +logger = logging.getLogger(__name__) |
| 15 | + |
| 16 | + |
class GroqChatTarget(OpenAIChatTarget):
    """
    A chat target for interacting with Groq's OpenAI-compatible API.

    This class extends `OpenAIChatTarget` and ensures compatibility with Groq's API,
    which requires `msg.content` to be a string instead of a list of dictionaries.

    Attributes:
        API_KEY_ENVIRONMENT_VARIABLE (str): The environment variable for the Groq API key.
        MODEL_NAME_ENVIRONMENT_VARIABLE (str): The environment variable for the Groq model name.
        GROQ_API_BASE_URL (str): The fixed API base URL for Groq.
    """

    API_KEY_ENVIRONMENT_VARIABLE = "GROQ_API_KEY"
    MODEL_NAME_ENVIRONMENT_VARIABLE = "GROQ_MODEL_NAME"
    GROQ_API_BASE_URL = "https://api.groq.com/openai/v1/"

    def __init__(
        self,
        *,
        model_name: Optional[str] = None,
        api_key: Optional[str] = None,
        max_requests_per_minute: Optional[int] = None,
        **kwargs,
    ):
        """
        Initializes GroqChatTarget with the correct API settings.

        Args:
            model_name (str, optional): The model to use. Defaults to `GROQ_MODEL_NAME` env variable.
            api_key (str, optional): The API key for authentication. Defaults to `GROQ_API_KEY` env variable.
            max_requests_per_minute (int, optional): Rate limit for requests.
        """
        # Groq's base URL is fixed and the model name is handled via `model_name`,
        # so drop any conflicting values a caller may have passed through kwargs
        # before forwarding to the OpenAI-compatible base class.
        kwargs.pop("endpoint", None)
        kwargs.pop("deployment_name", None)

        super().__init__(
            deployment_name=model_name,
            endpoint=self.GROQ_API_BASE_URL,
            api_key=api_key,
            is_azure_target=False,
            max_requests_per_minute=max_requests_per_minute,
            **kwargs,
        )

    def _initialize_non_azure_vars(self, deployment_name: str, endpoint: str, api_key: str):
        """
        Initializes variables to communicate with the (non-Azure) OpenAI API, in this case Groq.

        Args:
            deployment_name (str): The model name.
            endpoint (str): The API base URL.
            api_key (str): The API key.

        Raises:
            ValueError: If _deployment_name or _api_key is missing.
        """
        self._api_key = default_values.get_required_value(
            env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key
        )
        if not self._api_key:
            raise ValueError("API key for Groq is missing. Ensure GROQ_API_KEY is set in the environment.")

        self._deployment_name = default_values.get_required_value(
            env_var_name=self.MODEL_NAME_ENVIRONMENT_VARIABLE, passed_value=deployment_name
        )
        if not self._deployment_name:
            raise ValueError("Model name for Groq is missing. Ensure GROQ_MODEL_NAME is set in the environment.")

        # Ignoring mypy type error. The OpenAI client and Azure OpenAI client have the same private base class
        self._async_client = AsyncOpenAI(  # type: ignore
            api_key=self._api_key, default_headers=self._extra_headers, base_url=endpoint
        )

    @pyrit_target_retry
    async def _complete_chat_async(self, messages: list[ChatMessageListDictContent], is_json_response: bool) -> str:
        """
        Completes asynchronous chat request.

        Sends a chat message to the OpenAI chat model and retrieves the generated response.
        This method modifies the request structure to ensure compatibility with Groq,
        which requires `msg.content` as a string instead of a list of dictionaries.
        msg.content -> msg.content[0].get("text")

        Args:
            messages (list[ChatMessageListDictContent]): The chat message objects containing the role and content.
            is_json_response (bool): Boolean indicating if the response should be in JSON format.

        Returns:
            str: The generated response message.

        Raises:
            EmptyResponseException: If the API returns a complete but empty message.
            PyritException: If the API reports an unrecognized finish_reason.
        """
        response: ChatCompletion = await self._async_client.chat.completions.create(
            model=self._deployment_name,
            max_completion_tokens=self._max_completion_tokens,
            max_tokens=self._max_tokens,
            temperature=self._temperature,
            top_p=self._top_p,
            frequency_penalty=self._frequency_penalty,
            presence_penalty=self._presence_penalty,
            n=1,
            stream=False,
            seed=self._seed,
            # Groq rejects list-of-dict content; flatten to the first text entry.
            messages=[{"role": msg.role, "content": msg.content[0].get("text")} for msg in messages],  # type: ignore
            response_format={"type": "json_object"} if is_json_response else None,
        )
        finish_reason = response.choices[0].finish_reason
        extracted_response: str = ""
        # finish_reason="stop" means API returned complete message and
        # "length" means API returned incomplete message due to max_tokens limit.
        if finish_reason in ["stop", "length"]:
            extracted_response = self._parse_chat_completion(response)
            # Handle empty response
            if not extracted_response:
                logger.error("The chat returned an empty response.")
                raise EmptyResponseException(message="The chat returned an empty response.")
        else:
            raise PyritException(message=f"Unknown finish_reason {finish_reason}")

        return extracted_response
0 commit comments