
Commit a8ca7a1

FEAT: Add GroqChatTarget (#704) (#705)
1 parent 3dbd738 commit a8ca7a1

6 files changed: +803 -0 lines changed

doc/_toc.yml

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ chapters:
 - file: code/targets/5_multi_modal_targets
 - file: code/targets/6_rate_limiting
 - file: code/targets/7_http_target
+- file: code/targets/groq_chat_target
 - file: code/targets/open_ai_completions
 - file: code/targets/playwright_target
 - file: code/targets/prompt_shield_target
doc/code/targets/groq_chat_target.ipynb

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {
    "lines_to_next_cell": 0
   },
   "source": [
    "# GroqChatTarget\n",
    "\n",
    "This example demonstrates how to use the `GroqChatTarget` class in PyRIT to send a prompt\n",
    "to a Groq model and retrieve a response.\n",
    "\n",
    "## Setup\n",
    "Before running this example, you need to set the following environment variables:\n",
    "\n",
    "```\n",
    "export GROQ_API_KEY=\"your_api_key_here\"\n",
    "export GROQ_MODEL_NAME=\"llama3-8b-8192\"\n",
    "```\n",
    "\n",
    "Alternatively, you can pass these values as arguments when initializing `GroqChatTarget`:\n",
    "\n",
    "```python\n",
    "groq_target = GroqChatTarget(model_name=\"llama3-8b-8192\", api_key=\"your_api_key_here\")\n",
    "```\n",
    "\n",
    "You can also limit the request rate using `max_requests_per_minute`.\n",
    "\n",
    "## Example\n",
    "The following code initializes `GroqChatTarget`, sends a prompt using `PromptSendingOrchestrator`,\n",
    "and retrieves a response."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[22m\u001b[39mConversation ID: 7ae4ae98-a23b-4330-9c3e-5fd9e8c37854\n",
      "\u001b[1m\u001b[34muser: Why is the sky blue ?\n",
      "\u001b[22m\u001b[33massistant: The sky appears blue because of a phenomenon called Rayleigh scattering, which is the scattering of light by small particles or molecules in the atmosphere.\n",
      "\n",
      "When sunlight enters Earth's atmosphere, it encounters tiny molecules of gases such as nitrogen (N2) and oxygen (O2). These molecules scatter the light in all directions, but they scatter shorter (blue) wavelengths more than longer (red) wavelengths. This is known as Rayleigh scattering.\n",
      "\n",
      "As a result of this scattering, the blue light is dispersed throughout the atmosphere, reaching our eyes from all directions. This is why the sky appears blue during the daytime, as the blue light is being scattered in all directions and reaching our eyes from all parts of the sky.\n",
      "\n",
      "In addition to Rayleigh scattering, there are other factors that can affect the color of the sky, such as:\n",
      "\n",
      "* Mie scattering: This is the scattering of light by larger particles, such as dust, pollen, and water droplets. Mie scattering can give the sky a more orange or pinkish hue during sunrise and sunset.\n",
      "* Scattering by cloud droplets: Clouds can scatter light in a way that gives the sky a more white or gray appearance.\n",
      "* Atmospheric conditions: Factors such as pollution, dust, and water vapor can also affect the color of the sky, making it appear more hazy or brownish.\n",
      "\n",
      "Overall, the combination of Rayleigh scattering and other atmospheric effects is what gives the sky its blue color during the daytime.\n"
     ]
    }
   ],
   "source": [
    "\n",
    "from pyrit.common import IN_MEMORY, initialize_pyrit\n",
    "from pyrit.orchestrator import PromptSendingOrchestrator\n",
    "from pyrit.prompt_target import GroqChatTarget\n",
    "\n",
    "initialize_pyrit(memory_db_type=IN_MEMORY)\n",
    "\n",
    "groq_target = GroqChatTarget()\n",
    "\n",
    "prompt = \"Why is the sky blue ?\"\n",
    "\n",
    "orchestrator = PromptSendingOrchestrator(objective_target=groq_target)\n",
    "\n",
    "response = await orchestrator.send_prompts_async(prompt_list=[prompt]) # type: ignore\n",
    "await orchestrator.print_conversations_async() # type: ignore"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all"
  },
  "kernelspec": {
   "display_name": "pyrt_env",
   "language": "python",
   "name": "pyrt_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
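
The notebook's setup section lists GROQ_API_KEY and GROQ_MODEL_NAME as shell exports. As a minimal sketch (not part of this commit), the same variables can also be set from Python before the target is constructed, since GroqChatTarget resolves them at initialization when no explicit arguments are passed; the values below are placeholders:

import os

from pyrit.common import IN_MEMORY, initialize_pyrit
from pyrit.prompt_target import GroqChatTarget

initialize_pyrit(memory_db_type=IN_MEMORY)

# Placeholder values for illustration only; GroqChatTarget reads these
# environment variables when constructed without explicit arguments.
os.environ["GROQ_API_KEY"] = "your_api_key_here"
os.environ["GROQ_MODEL_NAME"] = "llama3-8b-8192"

groq_target = GroqChatTarget()  # picks up the variables set above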
doc/code/targets/groq_chat_target.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
# ---
# jupyter:
#   jupytext:
#     cell_metadata_filter: -all
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.7
#   kernelspec:
#     display_name: pyrt_env
#     language: python
#     name: pyrt_env
# ---

# %% [markdown]
# # GroqChatTarget
#
# This example demonstrates how to use the `GroqChatTarget` class in PyRIT to send a prompt
# to a Groq model and retrieve a response.
#
# ## Setup
# Before running this example, you need to set the following environment variables:
#
# ```
# export GROQ_API_KEY="your_api_key_here"
# export GROQ_MODEL_NAME="llama3-8b-8192"
# ```
#
# Alternatively, you can pass these values as arguments when initializing `GroqChatTarget`:
#
# ```python
# groq_target = GroqChatTarget(model_name="llama3-8b-8192", api_key="your_api_key_here")
# ```
#
# You can also limit the request rate using `max_requests_per_minute`.
#
# ## Example
# The following code initializes `GroqChatTarget`, sends a prompt using `PromptSendingOrchestrator`,
# and retrieves a response.
# %%

from pyrit.common import IN_MEMORY, initialize_pyrit
from pyrit.orchestrator import PromptSendingOrchestrator
from pyrit.prompt_target import GroqChatTarget

initialize_pyrit(memory_db_type=IN_MEMORY)

groq_target = GroqChatTarget()

prompt = "Why is the sky blue ?"

orchestrator = PromptSendingOrchestrator(objective_target=groq_target)

response = await orchestrator.send_prompts_async(prompt_list=[prompt])  # type: ignore
await orchestrator.print_conversations_async()  # type: ignore
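
The markdown above mentions `max_requests_per_minute` but the example does not exercise it. A minimal sketch of explicit configuration, assuming the constructor signature added in this commit (the key, model name, and rate limit below are placeholder values):

from pyrit.common import IN_MEMORY, initialize_pyrit
from pyrit.orchestrator import PromptSendingOrchestrator
from pyrit.prompt_target import GroqChatTarget

initialize_pyrit(memory_db_type=IN_MEMORY)

# Explicit arguments take precedence over the GROQ_API_KEY / GROQ_MODEL_NAME
# environment variables; max_requests_per_minute throttles outgoing requests.
groq_target = GroqChatTarget(
    model_name="llama3-8b-8192",
    api_key="your_api_key_here",
    max_requests_per_minute=30,
)

orchestrator = PromptSendingOrchestrator(objective_target=groq_target)

await orchestrator.send_prompts_async(prompt_list=["Why is the sky blue ?"])  # type: ignore
await orchestrator.print_conversations_async()  # type: ignore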

pyrit/prompt_target/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget
 from pyrit.prompt_target.crucible_target import CrucibleTarget
 from pyrit.prompt_target.gandalf_target import GandalfLevel, GandalfTarget
+from pyrit.prompt_target.groq_chat_target import GroqChatTarget
 from pyrit.prompt_target.http_target.http_target import HTTPTarget
 from pyrit.prompt_target.http_target.http_target_callback_functions import (
     get_http_target_json_response_callback_function,
@@ -34,6 +35,7 @@
     "CrucibleTarget",
     "GandalfLevel",
     "GandalfTarget",
+    "GroqChatTarget",
     "get_http_target_json_response_callback_function",
     "get_http_target_regex_matching_callback_function",
     "HTTPTarget",
pyrit/prompt_target/groq_chat_target.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion

from pyrit.common import default_values
from pyrit.exceptions import EmptyResponseException, PyritException, pyrit_target_retry
from pyrit.models import ChatMessageListDictContent
from pyrit.prompt_target.openai.openai_chat_target import OpenAIChatTarget

logger = logging.getLogger(__name__)


class GroqChatTarget(OpenAIChatTarget):
    """
    A chat target for interacting with Groq's OpenAI-compatible API.

    This class extends `OpenAIChatTarget` and ensures compatibility with Groq's API,
    which requires `msg.content` to be a string instead of a list of dictionaries.

    Attributes:
        API_KEY_ENVIRONMENT_VARIABLE (str): The environment variable for the Groq API key.
        MODEL_NAME_ENVIRONMENT_VARIABLE (str): The environment variable for the Groq model name.
        GROQ_API_BASE_URL (str): The fixed API base URL for Groq.
    """

    API_KEY_ENVIRONMENT_VARIABLE = "GROQ_API_KEY"
    MODEL_NAME_ENVIRONMENT_VARIABLE = "GROQ_MODEL_NAME"
    GROQ_API_BASE_URL = "https://api.groq.com/openai/v1/"

    def __init__(self, *, model_name: str = None, api_key: str = None, max_requests_per_minute: int = None, **kwargs):
        """
        Initializes GroqChatTarget with the correct API settings.

        Args:
            model_name (str, optional): The model to use. Defaults to `GROQ_MODEL_NAME` env variable.
            api_key (str, optional): The API key for authentication. Defaults to `GROQ_API_KEY` env variable.
            max_requests_per_minute (int, optional): Rate limit for requests.
        """

        kwargs.pop("endpoint", None)
        kwargs.pop("deployment_name", None)

        super().__init__(
            deployment_name=model_name,
            endpoint=self.GROQ_API_BASE_URL,
            api_key=api_key,
            is_azure_target=False,
            max_requests_per_minute=max_requests_per_minute,
            **kwargs,
        )

    def _initialize_non_azure_vars(self, deployment_name: str, endpoint: str, api_key: str):
        """
        Initializes variables to communicate with the (non-Azure) OpenAI API, in this case Groq.

        Args:
            deployment_name (str): The model name.
            endpoint (str): The API base URL.
            api_key (str): The API key.

        Raises:
            ValueError: If _deployment_name or _api_key is missing.
        """
        self._api_key = default_values.get_required_value(
            env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key
        )
        if not self._api_key:
            raise ValueError("API key for Groq is missing. Ensure GROQ_API_KEY is set in the environment.")

        self._deployment_name = default_values.get_required_value(
            env_var_name=self.MODEL_NAME_ENVIRONMENT_VARIABLE, passed_value=deployment_name
        )
        if not self._deployment_name:
            raise ValueError("Model name for Groq is missing. Ensure GROQ_MODEL_NAME is set in the environment.")

        # Ignoring mypy type error. The OpenAI client and Azure OpenAI client have the same private base class
        self._async_client = AsyncOpenAI(  # type: ignore
            api_key=self._api_key, default_headers=self._extra_headers, base_url=endpoint
        )

    @pyrit_target_retry
    async def _complete_chat_async(self, messages: list[ChatMessageListDictContent], is_json_response: bool) -> str:
        """
        Completes asynchronous chat request.

        Sends a chat message to the OpenAI chat model and retrieves the generated response.
        This method modifies the request structure to ensure compatibility with Groq,
        which requires `msg.content` as a string instead of a list of dictionaries.
        msg.content -> msg.content[0].get("text")

        Args:
            messages (list[ChatMessageListDictContent]): The chat message objects containing the role and content.
            is_json_response (bool): Boolean indicating if the response should be in JSON format.

        Returns:
            str: The generated response message.
        """
        response: ChatCompletion = await self._async_client.chat.completions.create(
            model=self._deployment_name,
            max_completion_tokens=self._max_completion_tokens,
            max_tokens=self._max_tokens,
            temperature=self._temperature,
            top_p=self._top_p,
            frequency_penalty=self._frequency_penalty,
            presence_penalty=self._presence_penalty,
            n=1,
            stream=False,
            seed=self._seed,
            messages=[{"role": msg.role, "content": msg.content[0].get("text")} for msg in messages],  # type: ignore
            response_format={"type": "json_object"} if is_json_response else None,
        )
        finish_reason = response.choices[0].finish_reason
        extracted_response: str = ""
        # finish_reason="stop" means API returned complete message and
        # "length" means API returned incomplete message due to max_tokens limit.
        if finish_reason in ["stop", "length"]:
            extracted_response = self._parse_chat_completion(response)
            # Handle empty response
            if not extracted_response:
                logger.log(logging.ERROR, "The chat returned an empty response.")
                raise EmptyResponseException(message="The chat returned an empty response.")
        else:
            raise PyritException(message=f"Unknown finish_reason {finish_reason}")

        return extracted_response
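
As an aside on the Groq-specific compatibility shim: the list comprehension in `_complete_chat_async` flattens each message's content list into a plain string. A standalone sketch of that transform, where SimpleNamespace merely stands in for `ChatMessageListDictContent` (the real class lives in pyrit.models and carries more fields):

from types import SimpleNamespace

# Stand-ins for ChatMessageListDictContent objects: a role plus a list of content dicts.
messages = [
    SimpleNamespace(role="system", content=[{"text": "You are a helpful assistant."}]),
    SimpleNamespace(role="user", content=[{"text": "Why is the sky blue ?"}]),
]

# The same msg.content -> msg.content[0].get("text") flattening done above.
groq_messages = [{"role": msg.role, "content": msg.content[0].get("text")} for msg in messages]

print(groq_messages)
# [{'role': 'system', 'content': 'You are a helpful assistant.'},
#  {'role': 'user', 'content': 'Why is the sky blue ?'}]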

0 commit comments
