-
Notifications
You must be signed in to change notification settings - Fork 591
FEAT new target class for AWS Bedrock Anthropic Claude models #699
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
kmarsh77
wants to merge
49
commits into
Azure:main
Choose a base branch
from
kmarsh77:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+302
−1
Open
Changes from 34 commits
Commits
Show all changes
49 commits
Select commit
Hold shift + click to select a range
ed1e3ff
Adding AWS Bedrock Anthropic Claude target class
kmarsh77 63a9b2e
Adding unit tests for AWSBedrockClaudeTarget class
kmarsh77 5de223d
Add optional aws dependency (boto3)
kmarsh77 ac87b28
Update aws_bedrock_claude_target.py
kmarsh77 45785d6
Adding bedrock claude target class
kmarsh77 f7a8767
Update __init__.py for new target classes
kmarsh77 57252d0
Unit test for AWSBedrockClaudeChatTarget
kmarsh77 f254145
Delete pyrit/prompt_target/aws_bedrock_claude_target.py
kmarsh77 f0bc2bc
Update __init__.py
kmarsh77 2ebb519
Delete tests/unit/test_aws_bedrock_claude_target.py
kmarsh77 258f287
Update aws_bedrock_claude_chat_target.py
kmarsh77 408c308
Update test_aws_bedrock_claude_chat_target.py
kmarsh77 5865ac6
Update pyproject.toml
kmarsh77 01addd1
Update aws_bedrock_claude_chat_target.py
kmarsh77 627396a
Update pyrit/prompt_target/aws_bedrock_claude_chat_target.py
kmarsh77 59dcb7e
Update pyrit/prompt_target/aws_bedrock_claude_chat_target.py
kmarsh77 b5c4924
Update aws_bedrock_claude_chat_target.py
kmarsh77 5d8d7e0
Update aws_bedrock_claude_chat_target.py
kmarsh77 185bcff
Update aws_bedrock_claude_chat_target.py
kmarsh77 2630b43
Update test_aws_bedrock_claude_chat_target.py
kmarsh77 7bcb075
Merge branch 'main' into main
kmarsh77 6d0485d
Merge branch 'Azure:main' into main
kmarsh77 0e0e300
Updates to address complaints from pre-commit hooks
kmarsh77 187cb16
Merge branch 'main' into main
romanlutz 3fd876b
Merge branch 'main' into main
romanlutz ef3ef17
Merge branch 'main' into main
romanlutz 6d531d5
Update pyrit/prompt_target/aws_bedrock_claude_chat_target.py
romanlutz e7c3c54
Adding exceptions for when boto3 isn't installed
kmarsh77 8ddb596
Adding exceptions for when boto3 isn't installed
kmarsh77 b80cbda
Merge branch 'main' of https://github.com/kmarsh77/PyRIT
kmarsh77 b772a9c
Adding noqa statements to pass pre-commit checks
kmarsh77 e4b10d3
Merge branch 'Azure:main' into main
kmarsh77 d88919a
Update tests/unit/test_aws_bedrock_claude_chat_target.py
romanlutz fbf6a86
Merge branch 'main' into main
romanlutz ee1a220
Fixing merge conflict in pyproject.toml
kmarsh77 294cdc9
changing import error message
kmarsh77 0e84c64
Merge branch 'Azure:main' into main
kmarsh77 0915d70
Fixed invalid converted_value_data_type in test_aws_bedrock_claude_ch…
kmarsh77 37e92b5
Merge branch 'Azure:main' into main
kmarsh77 c913bd4
moving test_aws_bedrock_claude_chat_target.py to tests/unit/target fo…
kmarsh77 8bfed3d
Adding ignore statements after test_send_prompt_async and test_comple…
kmarsh77 9d3059b
Adding ignore after boto3 use in aws_bedrock_claude_chat_target.py
kmarsh77 b5f5df0
Removing ignore statements
kmarsh77 5541705
removing ignore statements
kmarsh77 e827cb4
putting boto3.client inside try statement
kmarsh77 78af630
fixing
kmarsh77 590167e
Moving boto3 import to within _complete_chat_async
kmarsh77 3d83b84
Merge branch 'Azure:main' into main
kmarsh77 b7fa256
Merge branch 'Azure:main' into main
kmarsh77 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,203 @@ | ||||||
| # Copyright (c) Microsoft Corporation. | ||||||
| # Licensed under the MIT license. | ||||||
|
|
||||||
| import asyncio | ||||||
| import base64 | ||||||
| import json | ||||||
| import logging | ||||||
| from typing import TYPE_CHECKING, MutableSequence, Optional | ||||||
|
|
||||||
| from pyrit.chat_message_normalizer import ChatMessageNop, ChatMessageNormalizer | ||||||
| from pyrit.models import ( | ||||||
| ChatMessageListDictContent, | ||||||
| PromptRequestResponse, | ||||||
| construct_response_from_request, | ||||||
| ) | ||||||
| from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute | ||||||
|
|
||||||
| logger = logging.getLogger(__name__) | ||||||
|
|
||||||
| if TYPE_CHECKING: | ||||||
| import boto3 | ||||||
| from botocore.exceptions import ClientError | ||||||
|
|
||||||
|
|
||||||
| class AWSBedrockClaudeChatTarget(PromptChatTarget): | ||||||
| """ | ||||||
| This class initializes an AWS Bedrock target for any of the Anthropic Claude models. | ||||||
| Local AWS credentials (typically stored in ~/.aws) are used for authentication. | ||||||
| See the following for more information: | ||||||
| https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html | ||||||
|
|
||||||
| Parameters: | ||||||
| model_id (str): The model ID for target claude model | ||||||
| max_tokens (int): maximum number of tokens to generate | ||||||
| temperature (float, optional): The amount of randomness injected into the response. | ||||||
| top_p (float, optional): Use nucleus sampling | ||||||
| top_k (int, optional): Only sample from the top K options for each subsequent token | ||||||
| enable_ssl_verification (bool, optional): whether or not to perform SSL certificate verification | ||||||
| """ | ||||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| *, | ||||||
| model_id: str, | ||||||
| max_tokens: int, | ||||||
| temperature: Optional[float] = None, | ||||||
| top_p: Optional[float] = None, | ||||||
| top_k: Optional[int] = None, | ||||||
| enable_ssl_verification: bool = True, | ||||||
| chat_message_normalizer: ChatMessageNormalizer = ChatMessageNop(), | ||||||
| max_requests_per_minute: Optional[int] = None, | ||||||
| ) -> None: | ||||||
| super().__init__(max_requests_per_minute=max_requests_per_minute) | ||||||
|
|
||||||
| self._model_id = model_id | ||||||
| self._max_tokens = max_tokens | ||||||
| self._temperature = temperature | ||||||
| self._top_p = top_p | ||||||
| self._top_k = top_k | ||||||
| self._enable_ssl_verification = enable_ssl_verification | ||||||
| self.chat_message_normalizer = chat_message_normalizer | ||||||
|
|
||||||
| self._system_prompt = "" | ||||||
|
|
||||||
| self._valid_image_types = ["jpeg", "png", "webp", "gif"] | ||||||
|
|
||||||
| try: | ||||||
| import boto3 # noqa: F401 | ||||||
| from botocore.exceptions import ClientError # noqa: F401 | ||||||
| except ModuleNotFoundError as e: | ||||||
| logger.error("Could not import boto. You may need to install it via 'pip install pyrit[all]'") | ||||||
|
||||||
| logger.error("Could not import boto. You may need to install it via 'pip install pyrit[all]'") | |
| logger.error("Could not import boto. You may need to install it via 'pip install pyrit[all] or pyrit[aws]'") |
romanlutz marked this conversation as resolved.
Show resolved
Hide resolved
romanlutz marked this conversation as resolved.
Show resolved
Hide resolved
romanlutz marked this conversation as resolved.
Show resolved
Hide resolved
romanlutz marked this conversation as resolved.
Show resolved
Hide resolved
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,95 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT license. | ||
|
|
||
| import json | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| import pytest | ||
|
|
||
| from pyrit.models import ( | ||
| ChatMessageListDictContent, | ||
| PromptRequestPiece, | ||
| PromptRequestResponse, | ||
| ) | ||
| from pyrit.prompt_target.aws_bedrock_claude_chat_target import ( | ||
| AWSBedrockClaudeChatTarget, | ||
| ) | ||
|
|
||
|
|
||
| def is_boto3_installed(): | ||
| try: | ||
| import boto3 # noqa: F401 | ||
|
|
||
| return True | ||
| except ModuleNotFoundError: | ||
| return False | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def aws_target() -> AWSBedrockClaudeChatTarget: | ||
| return AWSBedrockClaudeChatTarget( | ||
| model_id="anthropic.claude-v2", | ||
| max_tokens=100, | ||
| temperature=0.7, | ||
| top_p=0.9, | ||
| top_k=50, | ||
| enable_ssl_verification=True, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def mock_prompt_request(): | ||
| request_piece = PromptRequestPiece( | ||
| role="user", original_value="Hello, Claude!", converted_value="Hello, how are you?" | ||
| ) | ||
| return PromptRequestResponse(request_pieces=[request_piece]) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not is_boto3_installed(), reason="boto3 is not installed") | ||
| @pytest.mark.asyncio | ||
| async def test_send_prompt_async(aws_target, mock_prompt_request): | ||
romanlutz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| with patch("boto3.client", new_callable=MagicMock) as mock_boto: | ||
| mock_client = mock_boto.return_value | ||
| mock_client.invoke_model.return_value = { | ||
| "body": MagicMock(read=MagicMock(return_value=json.dumps({"content": [{"text": "I'm good, thanks!"}]}))) | ||
| } | ||
|
|
||
| response = await aws_target.send_prompt_async(prompt_request=mock_prompt_request) | ||
|
|
||
| assert response.request_pieces[0].converted_value == "I'm good, thanks!" | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not is_boto3_installed(), reason="boto3 is not installed") | ||
| @pytest.mark.asyncio | ||
| async def test_validate_request_valid(aws_target, mock_prompt_request): | ||
| aws_target._validate_request(prompt_request=mock_prompt_request) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not is_boto3_installed(), reason="boto3 is not installed") | ||
| @pytest.mark.asyncio | ||
| async def test_validate_request_invalid_data_type(aws_target): | ||
| request_pieces = [ | ||
| PromptRequestPiece( | ||
| role="user", original_value="test", converted_value="ImageData", converted_value_data_type="video" | ||
| ) | ||
| ] | ||
| invalid_request = PromptRequestResponse(request_pieces=request_pieces) | ||
|
|
||
| with pytest.raises(ValueError, match="This target only supports text and image_path."): | ||
| aws_target._validate_request(prompt_request=invalid_request) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not is_boto3_installed(), reason="boto3 is not installed") | ||
| @pytest.mark.asyncio | ||
| async def test_complete_chat_async(aws_target): | ||
| with patch("boto3.client", new_callable=MagicMock) as mock_boto: | ||
| mock_client = mock_boto.return_value | ||
| mock_client.invoke_model.return_value = { | ||
| "body": MagicMock(read=MagicMock(return_value=json.dumps({"content": [{"text": "Test Response"}]}))) | ||
| } | ||
|
|
||
| response = await aws_target._complete_chat_async( | ||
| messages=[ChatMessageListDictContent(role="user", content=[{"type": "text", "text": "Test input"}])] | ||
| ) | ||
|
|
||
| assert response == "Test Response" | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.