FEAT new target class for AWS Bedrock Anthropic Claude models #699

Open: kmarsh77 wants to merge 49 commits into Azure:main from kmarsh77:main (+302 −1). The diff below shows changes from 7 of the 49 commits.
Commits (49):

- ed1e3ff Adding AWS Bedrock Anthropic Claude target class (kmarsh77)
- 63a9b2e Adding unit tests for AWSBedrockClaudeTarget class (kmarsh77)
- 5de223d Add optional aws dependency (boto3) (kmarsh77)
- ac87b28 Update aws_bedrock_claude_target.py (kmarsh77)
- 45785d6 Adding bedrock claude target class (kmarsh77)
- f7a8767 Update __init__.py for new target classes (kmarsh77)
- 57252d0 Unit test for AWSBedrockClaudeChatTarget (kmarsh77)
- f254145 Delete pyrit/prompt_target/aws_bedrock_claude_target.py (kmarsh77)
- f0bc2bc Update __init__.py (kmarsh77)
- 2ebb519 Delete tests/unit/test_aws_bedrock_claude_target.py (kmarsh77)
- 258f287 Update aws_bedrock_claude_chat_target.py (kmarsh77)
- 408c308 Update test_aws_bedrock_claude_chat_target.py (kmarsh77)
- 5865ac6 Update pyproject.toml (kmarsh77)
- 01addd1 Update aws_bedrock_claude_chat_target.py (kmarsh77)
- 627396a Update pyrit/prompt_target/aws_bedrock_claude_chat_target.py (kmarsh77)
- 59dcb7e Update pyrit/prompt_target/aws_bedrock_claude_chat_target.py (kmarsh77)
- b5c4924 Update aws_bedrock_claude_chat_target.py (kmarsh77)
- 5d8d7e0 Update aws_bedrock_claude_chat_target.py (kmarsh77)
- 185bcff Update aws_bedrock_claude_chat_target.py (kmarsh77)
- 2630b43 Update test_aws_bedrock_claude_chat_target.py (kmarsh77)
- 7bcb075 Merge branch 'main' into main (kmarsh77)
- 6d0485d Merge branch 'Azure:main' into main (kmarsh77)
- 0e0e300 Updates to address complaints from pre-commit hooks (kmarsh77)
- 187cb16 Merge branch 'main' into main (romanlutz)
- 3fd876b Merge branch 'main' into main (romanlutz)
- ef3ef17 Merge branch 'main' into main (romanlutz)
- 6d531d5 Update pyrit/prompt_target/aws_bedrock_claude_chat_target.py (romanlutz)
- e7c3c54 Adding exceptions for when boto3 isn't installed (kmarsh77)
- 8ddb596 Adding exceptions for when boto3 isn't installed (kmarsh77)
- b80cbda Merge branch 'main' of https://github.com/kmarsh77/PyRIT (kmarsh77)
- b772a9c Adding noqa statements to pass pre-commit checks (kmarsh77)
- e4b10d3 Merge branch 'Azure:main' into main (kmarsh77)
- d88919a Update tests/unit/test_aws_bedrock_claude_chat_target.py (romanlutz)
- fbf6a86 Merge branch 'main' into main (romanlutz)
- ee1a220 Fixing merge conflict in pyproject.toml (kmarsh77)
- 294cdc9 changing import error message (kmarsh77)
- 0e84c64 Merge branch 'Azure:main' into main (kmarsh77)
- 0915d70 Fixed invalid converted_value_data_type in test_aws_bedrock_claude_ch… (kmarsh77)
- 37e92b5 Merge branch 'Azure:main' into main (kmarsh77)
- c913bd4 moving test_aws_bedrock_claude_chat_target.py to tests/unit/target fo… (kmarsh77)
- 8bfed3d Adding ignore statements after test_send_prompt_async and test_comple… (kmarsh77)
- 9d3059b Adding ignore after boto3 use in aws_bedrock_claude_chat_target.py (kmarsh77)
- b5f5df0 Removing ignore statements (kmarsh77)
- 5541705 removing ignore statements (kmarsh77)
- e827cb4 putting boto3.client inside try statement (kmarsh77)
- 78af630 fixing (kmarsh77)
- 590167e Moving boto3 import to within _complete_chat_async (kmarsh77)
- 3d83b84 Merge branch 'Azure:main' into main (kmarsh77)
- b7fa256 Merge branch 'Azure:main' into main (kmarsh77)
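Several of the commits above ("Adding exceptions for when boto3 isn't installed", "changing import error message", "Moving boto3 import to within _complete_chat_async") deal with making boto3 an optional dependency that fails loudly only when actually used. A minimal sketch of that pattern — the helper name and the `pyrit[aws]` extra are our assumptions, not code from this PR:

```python
from importlib import import_module


def require_optional_dependency(module_name: str, extra: str):
    """Import an optional dependency, failing with an actionable message.

    Deferring this call to the method that needs the client (rather than
    importing at module top level) keeps `import pyrit` working even when
    the optional package is not installed.
    """
    try:
        return import_module(module_name)
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError(
            f"Could not import {module_name}. "
            f"You may need to install it via 'pip install pyrit[{extra}]'."
        ) from e


# Hypothetical use inside a method body:
#     boto3 = require_optional_dependency("boto3", "aws")
```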
pyrit/prompt_target/aws_bedrock_claude_chat_target.py (new file, +154 lines):
```python
import asyncio
import json
import logging
from typing import MutableSequence, Optional

import boto3
from botocore.exceptions import ClientError

from pyrit.chat_message_normalizer import ChatMessageNop, ChatMessageNormalizer
from pyrit.models import (
    ChatMessageListDictContent,
    PromptRequestResponse,
    construct_response_from_request,
)
from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute

logger = logging.getLogger(__name__)


class AWSBedrockClaudeChatTarget(PromptChatTarget):
    """
    Initializes an AWS Bedrock target for any of the Anthropic Claude models.

    Local AWS credentials (typically stored in ~/.aws) are used for authentication.
    See the following for more information:
    https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html

    Parameters:
        model_id (str): The model ID for the target Claude model.
        max_tokens (int): The maximum number of tokens to generate.
        temperature (float, optional): The amount of randomness injected into the response.
        top_p (float, optional): Use nucleus sampling.
        top_k (int, optional): Only sample from the top K options for each subsequent token.
        verify (bool, optional): Whether to perform SSL certificate verification.
    """

    def __init__(
        self,
        *,
        model_id: str,
        max_tokens: int,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        verify: bool = True,
        chat_message_normalizer: ChatMessageNormalizer = ChatMessageNop(),
        max_requests_per_minute: Optional[int] = None,
    ) -> None:
        super().__init__(max_requests_per_minute=max_requests_per_minute)

        self._model_id = model_id
        self._max_tokens = max_tokens
        self._temperature = temperature
        self._top_p = top_p
        self._top_k = top_k
        self._verify = verify
        self.chat_message_normalizer = chat_message_normalizer
        self._system_prompt = ""

    @limit_requests_per_minute
    async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
        self._validate_request(prompt_request=prompt_request)
        request_piece = prompt_request.request_pieces[0]

        prompt_req_res_entries = self._memory.get_conversation(conversation_id=request_piece.conversation_id)
        prompt_req_res_entries.append(prompt_request)

        logger.info(f"Sending the following prompt to the prompt target: {prompt_request}")

        messages = self._build_chat_messages(prompt_req_res_entries)
        response = await self._complete_chat_async(messages=messages)

        return construct_response_from_request(request=request_piece, response_text_pieces=[response])

    def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
        if len(prompt_request.request_pieces) != 1:
            raise ValueError("This target only supports a single prompt request piece.")

        if prompt_request.request_pieces[0].converted_value_data_type != "text":
            raise ValueError("This target only supports text prompt input.")

    async def _complete_chat_async(self, messages: list[ChatMessageListDictContent]) -> str:
        # Note: the region is currently hardcoded.
        brt = boto3.client(service_name="bedrock-runtime", region_name="us-east-1", verify=self._verify)

        request = json.dumps(self._construct_request_body(messages))

        try:
            # boto3 is synchronous, so run the blocking call in a worker thread.
            response = await asyncio.to_thread(brt.invoke_model, modelId=self._model_id, body=request)
        except ClientError as e:
            raise RuntimeError(f"Can't invoke '{self._model_id}'. Reason: {e}") from e

        model_response = json.loads(response["body"].read())
        answer = model_response["content"][0]["text"]

        logger.info(f'Received the following response from the prompt target "{answer}"')
        return answer

    def _build_chat_messages(
        self, prompt_req_res_entries: MutableSequence[PromptRequestResponse]
    ) -> list[ChatMessageListDictContent]:
        chat_messages: list[ChatMessageListDictContent] = []
        for prompt_req_resp_entry in prompt_req_res_entries:
            content = []
            role = None
            for prompt_request_piece in prompt_req_resp_entry.request_pieces:
                role = prompt_request_piece.role
                if role == "system":
                    # Bedrock doesn't allow a message with role == "system",
                    # but it does let you specify the system prompt in a parameter.
                    self._system_prompt = prompt_request_piece.converted_value
                elif prompt_request_piece.converted_value_data_type == "text":
                    content.append({"type": "text", "text": prompt_request_piece.converted_value})
                else:
                    raise ValueError(
                        f"Multimodal data type {prompt_request_piece.converted_value_data_type} "
                        "is not yet supported."
                    )

            if not role:
                raise ValueError("No role could be determined from the prompt request pieces.")

            chat_messages.append(ChatMessageListDictContent(role=role, content=content))
        return chat_messages

    def _construct_request_body(self, messages_list: list[ChatMessageListDictContent]) -> dict:
        # System-role entries are sent via the top-level "system" field instead.
        messages = [
            {"role": message.role, "content": message.content}
            for message in messages_list
            if message.role != "system"
        ]

        data = {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": self._max_tokens,
            "system": self._system_prompt,
            "messages": messages,
        }

        # Compare against None so an explicit 0 or 0.0 is still sent.
        if self._temperature is not None:
            data["temperature"] = self._temperature
        if self._top_p is not None:
            data["top_p"] = self._top_p
        if self._top_k is not None:
            data["top_k"] = self._top_k

        return data

    def is_json_response_supported(self) -> bool:
        """Indicates that this target supports JSON response format."""
        return True
```
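The body that `_construct_request_body` produces follows the Bedrock Anthropic Messages API schema. A standalone sketch of the same shaping logic, using plain dicts instead of PyRIT's message objects (the helper name `build_claude_body` is ours, not part of the PR):

```python
import json
from typing import Optional


def build_claude_body(
    messages: list[dict],
    *,
    max_tokens: int,
    system_prompt: str = "",
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
) -> str:
    """Shape a Bedrock Anthropic Messages API request body as a JSON string."""
    body = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": max_tokens,
        # System instructions go in a top-level field, not in the message list.
        "system": system_prompt,
        # Drop any system-role entries from the message list itself.
        "messages": [m for m in messages if m["role"] != "system"],
    }
    # Optional sampling knobs are included only when explicitly set;
    # comparing against None keeps a deliberate temperature of 0.0.
    for key, value in (("temperature", temperature), ("top_p", top_p), ("top_k", top_k)):
        if value is not None:
            body[key] = value
    return json.dumps(body)


example = build_claude_body(
    [
        {"role": "system", "content": "Be terse."},
        {"role": "user", "content": [{"type": "text", "text": "Hi"}]},
    ],
    max_tokens=100,
    system_prompt="Be terse.",
    temperature=0.0,
)
```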
pyrit/prompt_target/aws_bedrock_claude_target.py (new file, +104 lines; removed again later by commit f254145):
```python
import asyncio
import json
import logging
from typing import Optional

import boto3
from botocore.exceptions import ClientError

from pyrit.models import PromptRequestResponse, construct_response_from_request
from pyrit.prompt_target import PromptTarget, limit_requests_per_minute

logger = logging.getLogger(__name__)


class AWSBedrockClaudeTarget(PromptTarget):
    """
    Initializes an AWS Bedrock target for any of the Anthropic Claude models.

    Local AWS credentials (typically stored in ~/.aws) are used for authentication.
    See the following for more information:
    https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html

    Parameters:
        model_id (str): The model ID for the target Claude model.
        max_tokens (int): The maximum number of tokens to generate.
        temperature (float, optional): The amount of randomness injected into the response.
        top_p (float, optional): Use nucleus sampling.
        top_k (int, optional): Only sample from the top K options for each subsequent token.
        verify (bool, optional): Whether to perform SSL certificate verification.
    """

    def __init__(
        self,
        *,
        model_id: str,
        max_tokens: int,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        verify: bool = True,
        max_requests_per_minute: Optional[int] = None,
    ) -> None:
        super().__init__(max_requests_per_minute=max_requests_per_minute)

        self._model_id = model_id
        self._max_tokens = max_tokens
        self._temperature = temperature
        self._top_p = top_p
        self._top_k = top_k
        self._verify = verify

    @limit_requests_per_minute
    async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
        self._validate_request(prompt_request=prompt_request)
        request = prompt_request.request_pieces[0]

        logger.info(f"Sending the following prompt to the prompt target: {request}")

        response = await self._complete_text_async(request.converted_value)

        return construct_response_from_request(request=request, response_text_pieces=[response])

    def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
        if len(prompt_request.request_pieces) != 1:
            raise ValueError("This target only supports a single prompt request piece.")

        if prompt_request.request_pieces[0].converted_value_data_type != "text":
            raise ValueError("This target only supports text prompt input.")

    async def _complete_text_async(self, text: str) -> str:
        brt = boto3.client(service_name="bedrock-runtime", verify=self._verify)

        native_request = {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": self._max_tokens,
            "messages": [{"role": "user", "content": text}],
        }

        # Compare against None so an explicit 0 or 0.0 is still sent.
        if self._temperature is not None:
            native_request["temperature"] = self._temperature
        if self._top_p is not None:
            native_request["top_p"] = self._top_p
        if self._top_k is not None:
            native_request["top_k"] = self._top_k

        request = json.dumps(native_request)

        try:
            # boto3 is synchronous, so run the blocking call in a worker thread.
            response = await asyncio.to_thread(brt.invoke_model, modelId=self._model_id, body=request)
        except ClientError as e:
            raise RuntimeError(f"Can't invoke '{self._model_id}'. Reason: {e}") from e

        model_response = json.loads(response["body"].read())
        answer = model_response["content"][0]["text"]

        logger.info(f'Received the following response from the prompt target "{answer}"')
        return answer
```
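Both targets handle the `invoke_model` result the same way: read the streaming body, JSON-decode it, and pull the text of the first content block. That parsing step can be exercised without AWS by substituting an in-memory stream, since botocore's `StreamingBody` also exposes a `read()` method (the helper name and fake response below are ours):

```python
import io
import json


def extract_answer(response: dict) -> str:
    """Mirror the targets' response handling: read the body stream,
    decode JSON, and return the text of the first content block."""
    model_response = json.loads(response["body"].read())
    return model_response["content"][0]["text"]


# Stand-in for what bedrock-runtime returns from invoke_model.
fake_response = {
    "body": io.BytesIO(json.dumps({"content": [{"text": "Hello from Claude"}]}).encode("utf-8"))
}
answer = extract_answer(fake_response)
```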
tests/unit/test_aws_bedrock_claude_chat_target.py (new file, +73 lines):
```python
import json
from unittest.mock import MagicMock, patch

import pytest

from pyrit.models import ChatMessageListDictContent, PromptRequestPiece, PromptRequestResponse
from pyrit.prompt_target.aws_bedrock_claude_chat_target import AWSBedrockClaudeChatTarget


@pytest.fixture
def aws_target() -> AWSBedrockClaudeChatTarget:
    return AWSBedrockClaudeChatTarget(
        model_id="anthropic.claude-v2",
        max_tokens=100,
        temperature=0.7,
        top_p=0.9,
        top_k=50,
        verify=True,
    )


@pytest.fixture
def mock_prompt_request() -> PromptRequestResponse:
    request_piece = PromptRequestPiece(
        role="user",
        original_value="Hello, Claude!",
        converted_value="Hello, how are you?",
    )
    return PromptRequestResponse(request_pieces=[request_piece])


@pytest.mark.asyncio
async def test_send_prompt_async(aws_target, mock_prompt_request):
    with patch("boto3.client", new_callable=MagicMock) as mock_boto:
        mock_client = mock_boto.return_value
        mock_client.invoke_model.return_value = {
            "body": MagicMock(read=MagicMock(return_value=json.dumps({"content": [{"text": "I'm good, thanks!"}]})))
        }

        response = await aws_target.send_prompt_async(prompt_request=mock_prompt_request)

        assert response.request_pieces[0].converted_value == "I'm good, thanks!"


@pytest.mark.asyncio
async def test_validate_request_valid(aws_target, mock_prompt_request):
    aws_target._validate_request(prompt_request=mock_prompt_request)


@pytest.mark.asyncio
async def test_validate_request_invalid_multiple_pieces(aws_target):
    request_pieces = [
        PromptRequestPiece(role="user", original_value="test", converted_value="Text 1", converted_value_data_type="text"),
        PromptRequestPiece(role="user", original_value="test", converted_value="Text 2", converted_value_data_type="text"),
    ]
    invalid_request = PromptRequestResponse(request_pieces=request_pieces)

    with pytest.raises(ValueError, match="This target only supports a single prompt request piece."):
        aws_target._validate_request(prompt_request=invalid_request)


@pytest.mark.asyncio
async def test_validate_request_invalid_data_type(aws_target):
    request_pieces = [
        PromptRequestPiece(
            role="user", original_value="test", converted_value="ImageData", converted_value_data_type="image_path"
        )
    ]
    invalid_request = PromptRequestResponse(request_pieces=request_pieces)

    with pytest.raises(ValueError, match="This target only supports text prompt input."):
        aws_target._validate_request(prompt_request=invalid_request)


@pytest.mark.asyncio
async def test_complete_chat_async(aws_target):
    with patch("boto3.client", new_callable=MagicMock) as mock_boto:
        mock_client = mock_boto.return_value
        mock_client.invoke_model.return_value = {
            "body": MagicMock(read=MagicMock(return_value=json.dumps({"content": [{"text": "Test Response"}]})))
        }

        messages = [ChatMessageListDictContent(role="user", content=[{"type": "text", "text": "Test input"}])]
        response = await aws_target._complete_chat_async(messages=messages)

        assert response == "Test Response"
```
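The targets wrap the synchronous `invoke_model` call in `asyncio.to_thread` so the event loop is not blocked while Bedrock responds. The pattern in isolation, with a stand-in blocking function in place of the boto3 call:

```python
import asyncio
import time


def blocking_call(payload: str) -> str:
    """Stand-in for a synchronous SDK call such as boto3's invoke_model."""
    time.sleep(0.05)  # simulate network latency
    return payload.upper()


async def main() -> str:
    # to_thread runs the blocking function in a worker thread and yields
    # control back to the event loop until the call completes.
    return await asyncio.to_thread(blocking_call, "hello")


result = asyncio.run(main())
```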