diff --git a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs
index 0479bc878232..4b3d423ac106 100644
--- a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs
+++ b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs
@@ -18,8 +18,10 @@
using Amazon.Runtime.Internal.Util;
using Microsoft.Extensions.AI;
using System;
+using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
+using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
@@ -35,6 +37,13 @@ internal sealed partial class BedrockChatClient : IChatClient
/// A default logger to use.
private static readonly ILogger DefaultLogger = Logger.GetLogger(typeof(BedrockChatClient));
+ /// The name used for the synthetic tool that enforces response format.
+ private const string ResponseFormatToolName = "generate_response";
+ /// The description used for the synthetic tool that enforces response format.
+ private const string ResponseFormatToolDescription = "Generate response in specified format";
+ /// Maximum nesting depth for Document to JSON conversion to prevent stack overflow.
+ private const int MaxDocumentNestingDepth = 100;
+
/// The wrapped instance.
private readonly IAmazonBedrockRuntime _runtime;
/// Default model ID to use when no model is specified in the request.
@@ -42,11 +51,7 @@ internal sealed partial class BedrockChatClient : IChatClient
/// Metadata describing the chat client.
private readonly ChatClientMetadata _metadata;
- ///
- /// Initializes a new instance of the class.
- ///
- /// The instance to wrap.
- /// Model ID to use as the default when no model ID is specified in a request.
+ /// Initializes a new instance of the class.
public BedrockChatClient(IAmazonBedrockRuntime runtime, string? defaultModelId)
{
Debug.Assert(runtime is not null);
@@ -79,7 +84,34 @@ public async Task GetResponseAsync(
request.InferenceConfig = CreateInferenceConfiguration(request.InferenceConfig, options);
request.AdditionalModelRequestFields = CreateAdditionalModelRequestFields(request.AdditionalModelRequestFields, options);
- var response = await _runtime.ConverseAsync(request, cancellationToken).ConfigureAwait(false);
+ // Execute the request with proper error handling for ResponseFormat scenarios
+ ConverseResponse response;
+ try
+ {
+ response = await _runtime.ConverseAsync(request, cancellationToken).ConfigureAwait(false);
+ }
+ catch (AmazonBedrockRuntimeException ex) when (options?.ResponseFormat is ChatResponseFormatJson)
+ {
+ // Check if this is a ToolChoice validation error (model doesn't support it)
+ bool isToolChoiceNotSupported =
+ ex.ErrorCode == "ValidationException" &&
+ (ex.Message.IndexOf("toolChoice", StringComparison.OrdinalIgnoreCase) >= 0 ||
+ ex.Message.IndexOf("tool_choice", StringComparison.OrdinalIgnoreCase) >= 0 ||
+ ex.Message.IndexOf("ToolChoice", StringComparison.OrdinalIgnoreCase) >= 0);
+
+ if (isToolChoiceNotSupported)
+ {
+ // Provide a more helpful error message when ToolChoice fails due to model limitations
+ throw new NotSupportedException(
+ $"The model '{request.ModelId}' does not support ResponseFormat. " +
+ $"ResponseFormat requires ToolChoice support, which is only available in Claude 3+ and Mistral Large models. " +
+ $"See: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html",
+ ex);
+ }
+
+ // Re-throw other exceptions as-is
+ throw;
+ }
ChatMessage result = new()
{
@@ -89,6 +121,50 @@ public async Task GetResponseAsync(
MessageId = Guid.NewGuid().ToString("N"),
};
+ // Check if ResponseFormat was used and extract structured content
+ bool usingResponseFormat = options?.ResponseFormat is ChatResponseFormatJson;
+ if (usingResponseFormat)
+ {
+ var structuredContent = ExtractResponseFormatContent(response.Output?.Message);
+ if (structuredContent is not null)
+ {
+ // Replace the content with the extracted JSON as a TextContent
+ result.Contents.Add(new TextContent(structuredContent) { RawRepresentation = response.Output?.Message });
+
+ // Skip normal content processing since we've extracted the structured response
+ if (DocumentToDictionary(response.AdditionalModelResponseFields) is { } responseFieldsDict)
+ {
+ result.AdditionalProperties = new(responseFieldsDict);
+ }
+
+ return new(result)
+ {
+ CreatedAt = result.CreatedAt,
+ FinishReason = response.StopReason is not null ? GetChatFinishReason(response.StopReason) : null,
+ Usage = response.Usage is TokenUsage tokenUsage ? CreateUsageDetails(tokenUsage) : null,
+ RawRepresentation = response,
+ };
+ }
+ else
+ {
+ // User requested structured output but didn't get it - this is a contract violation
+ var errorMessage = string.Format(
+ "ResponseFormat was specified but model did not return expected tool use. ModelId: {0}, StopReason: {1}",
+ request.ModelId,
+ response.StopReason?.Value ?? "unknown");
+
+ DefaultLogger.Error(new InvalidOperationException(errorMessage), errorMessage);
+
+ // Always throw when ResponseFormat was requested but not fulfilled
+ throw new InvalidOperationException(
+ $"Model '{request.ModelId}' did not return structured output as requested. " +
+ $"This may indicate the model refused to follow the tool use instruction, " +
+ $"the schema was too complex, or the prompt conflicted with the requirement. " +
+ $"StopReason: {response.StopReason?.Value ?? "unknown"}");
+ }
+ }
+
+ // Normal content processing when not using ResponseFormat or extraction failed
if (response.Output?.Message?.Content is { } contents)
{
foreach (var content in contents)
@@ -182,6 +258,14 @@ public async IAsyncEnumerable GetStreamingResponseAsync(
throw new ArgumentNullException(nameof(messages));
}
+ // Check if ResponseFormat is set - not supported for streaming yet
+ if (options?.ResponseFormat is ChatResponseFormatJson)
+ {
+ throw new NotSupportedException(
+ "ResponseFormat is not yet supported for streaming responses with Amazon Bedrock. " +
+ "Please use GetResponseAsync for structured output.");
+ }
+
ConverseStreamRequest request = options?.RawRepresentationFactory?.Invoke(this) as ConverseStreamRequest ?? new();
request.ModelId ??= options?.ModelId ?? _modelId;
request.Messages = CreateMessages(request.Messages, messages);
@@ -792,7 +876,11 @@ private static Document ToDocument(JsonElement json)
}
}
- /// Creates an from the specified options.
+ /// Creates a from the specified options.
+ ///
+ /// When ResponseFormat is specified, creates a synthetic tool to enforce structured output.
+ /// This conflicts with user-provided tools as Bedrock only supports a single ToolChoice value.
+ ///
private static ToolConfiguration? CreateToolConfig(ToolConfiguration? toolConfig, ChatOptions? options)
{
if (options?.Tools is { Count: > 0 } tools)
@@ -855,6 +943,56 @@ private static Document ToDocument(JsonElement json)
}
}
+ // Handle ResponseFormat by creating a synthetic tool
+ if (options?.ResponseFormat is ChatResponseFormatJson jsonFormat)
+ {
+ // Check for conflict with user-provided tools
+ if (toolConfig?.Tools?.Count > 0)
+ {
+ throw new ArgumentException(
+ "ResponseFormat cannot be used with Tools in Amazon Bedrock. " +
+ "ResponseFormat uses Bedrock's tool mechanism for structured output, " +
+ "which conflicts with user-provided tools.");
+ }
+
+ // Create the synthetic tool with the schema from ResponseFormat
+ toolConfig ??= new();
+ toolConfig.Tools ??= [];
+
+ // Parse the schema if provided, otherwise create an empty object schema
+ Document schemaDoc;
+ if (jsonFormat.Schema.HasValue)
+ {
+ // Schema is already a JsonElement (parsed JSON), convert directly to Document
+ schemaDoc = ToDocument(jsonFormat.Schema.Value);
+ }
+ else
+ {
+ // For JSON mode without schema, create a generic object schema
+ schemaDoc = new Document(new Dictionary
+ {
+ ["type"] = new Document("object"),
+ ["additionalProperties"] = new Document(true)
+ });
+ }
+
+ toolConfig.Tools.Add(new Tool
+ {
+ ToolSpec = new ToolSpecification
+ {
+ Name = ResponseFormatToolName,
+ Description = jsonFormat.SchemaDescription ?? ResponseFormatToolDescription,
+ InputSchema = new ToolInputSchema
+ {
+ Json = schemaDoc
+ }
+ }
+ });
+
+ // Force the model to use the synthetic tool
+ toolConfig.ToolChoice = new ToolChoice { Tool = new() { Name = ResponseFormatToolName } };
+ }
+
if (toolConfig?.Tools is { Count: > 0 } && toolConfig.ToolChoice is null)
{
switch (options!.ToolMode)
@@ -870,6 +1008,96 @@ private static Document ToDocument(JsonElement json)
return toolConfig;
}
+ /// Extracts JSON content from the synthetic ResponseFormat tool use, if present.
+ private static string? ExtractResponseFormatContent(Message? message)
+ {
+ if (message?.Content is null)
+ {
+ return null;
+ }
+
+ foreach (var content in message.Content)
+ {
+ if (content.ToolUse is ToolUseBlock toolUse &&
+ toolUse.Name == ResponseFormatToolName &&
+ toolUse.Input != default)
+ {
+ // Convert the Document back to JSON string
+ return DocumentToJsonString(toolUse.Input);
+ }
+ }
+
+ return null;
+ }
+
+ /// Converts a to a JSON string.
+ private static string DocumentToJsonString(Document document)
+ {
+ using var stream = new MemoryStream();
+ using (var writer = new Utf8JsonWriter(stream, new JsonWriterOptions { Indented = false }))
+ {
+ WriteDocumentAsJson(writer, document);
+ } // Explicit scope to ensure writer is flushed before reading buffer
+
+ return Encoding.UTF8.GetString(stream.ToArray());
+ }
+
+ /// Recursively writes a as JSON.
+ private static void WriteDocumentAsJson(Utf8JsonWriter writer, Document document, int depth = 0)
+ {
+ // Check depth to prevent stack overflow from deeply nested or circular structures
+ if (depth > MaxDocumentNestingDepth)
+ {
+ throw new InvalidOperationException(
+ $"Document nesting depth exceeds maximum of {MaxDocumentNestingDepth}. " +
+ $"This may indicate a circular reference or excessively nested data structure.");
+ }
+
+ if (document.IsBool())
+ {
+ writer.WriteBooleanValue(document.AsBool());
+ }
+ else if (document.IsInt())
+ {
+ writer.WriteNumberValue(document.AsInt());
+ }
+ else if (document.IsLong())
+ {
+ writer.WriteNumberValue(document.AsLong());
+ }
+ else if (document.IsDouble())
+ {
+ writer.WriteNumberValue(document.AsDouble());
+ }
+ else if (document.IsString())
+ {
+ writer.WriteStringValue(document.AsString());
+ }
+ else if (document.IsDictionary())
+ {
+ writer.WriteStartObject();
+ foreach (var kvp in document.AsDictionary())
+ {
+ writer.WritePropertyName(kvp.Key);
+ WriteDocumentAsJson(writer, kvp.Value, depth + 1);
+ }
+ writer.WriteEndObject();
+ }
+ else if (document.IsList())
+ {
+ writer.WriteStartArray();
+ foreach (var item in document.AsList())
+ {
+ WriteDocumentAsJson(writer, item, depth + 1);
+ }
+ writer.WriteEndArray();
+ }
+ else
+ {
+ writer.WriteNullValue();
+ }
+ }
+
/// Creates an from the specified options.
private static InferenceConfiguration CreateInferenceConfiguration(InferenceConfiguration config, ChatOptions? options)
{
diff --git a/extensions/test/BedrockMEAITests/BedrockChatClientTests.cs b/extensions/test/BedrockMEAITests/BedrockChatClientTests.cs
index 8f5099c973d8..b599c643bf3d 100644
--- a/extensions/test/BedrockMEAITests/BedrockChatClientTests.cs
+++ b/extensions/test/BedrockMEAITests/BedrockChatClientTests.cs
@@ -1,11 +1,111 @@
-using Microsoft.Extensions.AI;
+using Amazon.BedrockRuntime.Model;
+using Amazon.Runtime.Documents;
+using Microsoft.Extensions.AI;
using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
using Xunit;
namespace Amazon.BedrockRuntime;
+// Mock implementation to capture requests and control responses
+internal sealed class MockBedrockRuntime : IAmazonBedrockRuntime
+{
+ public ConverseRequest CapturedRequest { get; private set; }
+ public ConverseStreamRequest CapturedStreamRequest { get; private set; }
+ public Func ResponseFactory { get; set; }
+ public Exception ExceptionToThrow { get; set; }
+
+ public Task ConverseAsync(ConverseRequest request, CancellationToken cancellationToken = default)
+ {
+ CapturedRequest = request;
+
+ if (ExceptionToThrow != null)
+ {
+ throw ExceptionToThrow;
+ }
+
+ if (ResponseFactory != null)
+ {
+ return Task.FromResult(ResponseFactory(request));
+ }
+
+ // Default response
+ return Task.FromResult(new ConverseResponse
+ {
+ Output = new ConverseOutput
+ {
+ Message = new Message
+ {
+ Role = ConversationRole.Assistant,
+ Content = new List
+ {
+ new ContentBlock { Text = "Default response" }
+ }
+ }
+ },
+ StopReason = new StopReason("end_turn")
+ });
+ }
+
+ public Task ConverseStreamAsync(ConverseStreamRequest request, CancellationToken cancellationToken = default)
+ {
+ CapturedStreamRequest = request;
+ throw new NotImplementedException("Stream testing not implemented in this mock");
+ }
+
+ public void Dispose() { }
+
+ // Unused interface members - all throw NotImplementedException
+ public IBedrockRuntimePaginatorFactory Paginators => throw new NotImplementedException();
+ public Amazon.Runtime.IClientConfig Config => throw new NotImplementedException();
+
+ // Sync methods
+ public ApplyGuardrailResponse ApplyGuardrail(ApplyGuardrailRequest request) => throw new NotImplementedException();
+ public ConverseResponse Converse(ConverseRequest request) => throw new NotImplementedException();
+ public ConverseStreamResponse ConverseStream(ConverseStreamRequest request) => throw new NotImplementedException();
+ public CountTokensResponse CountTokens(CountTokensRequest request) => throw new NotImplementedException();
+ public GetAsyncInvokeResponse GetAsyncInvoke(GetAsyncInvokeRequest request) => throw new NotImplementedException();
+ public InvokeModelResponse InvokeModel(InvokeModelRequest request) => throw new NotImplementedException();
+ public InvokeModelWithResponseStreamResponse InvokeModelWithResponseStream(InvokeModelWithResponseStreamRequest request) => throw new NotImplementedException();
+ public ListAsyncInvokesResponse ListAsyncInvokes(ListAsyncInvokesRequest request) => throw new NotImplementedException();
+ public StartAsyncInvokeResponse StartAsyncInvoke(StartAsyncInvokeRequest request) => throw new NotImplementedException();
+
+ // Async methods
+ public Task ApplyGuardrailAsync(ApplyGuardrailRequest request, CancellationToken cancellationToken = default) => throw new NotImplementedException();
+ public Task CountTokensAsync(CountTokensRequest request, CancellationToken cancellationToken = default) => throw new NotImplementedException();
+ public Task GetAsyncInvokeAsync(GetAsyncInvokeRequest request, CancellationToken cancellationToken = default) => throw new NotImplementedException();
+ public Task InvokeModelAsync(InvokeModelRequest request, CancellationToken cancellationToken = default) => throw new NotImplementedException();
+ public Task InvokeModelWithResponseStreamAsync(InvokeModelWithResponseStreamRequest request, CancellationToken cancellationToken = default) => throw new NotImplementedException();
+ public Task ListAsyncInvokesAsync(ListAsyncInvokesRequest request, CancellationToken cancellationToken = default) => throw new NotImplementedException();
+ public Task StartAsyncInvokeAsync(StartAsyncInvokeRequest request, CancellationToken cancellationToken = default) => throw new NotImplementedException();
+
+ // Endpoint determination
+ public Amazon.Runtime.Endpoints.Endpoint DetermineServiceOperationEndpoint(Amazon.Runtime.AmazonWebServiceRequest request) => throw new NotImplementedException();
+}
+
+// Simple test implementation of AIFunctionDeclaration
+internal sealed class TestAIFunction : AIFunctionDeclaration
+{
+ public TestAIFunction(string name, string description, JsonElement jsonSchema)
+ {
+ Name = name;
+ Description = description;
+ JsonSchema = jsonSchema;
+ }
+
+ public override string Name { get; }
+ public override string Description { get; }
+ public override JsonElement JsonSchema { get; }
+}
+
public class BedrockChatClientTests
{
+ #region Basic Client Tests
+
[Fact]
[Trait("UnitTest", "BedrockRuntime")]
public void AsIChatClient_InvalidArguments_Throws()
@@ -19,8 +119,8 @@ public void AsIChatClient_InvalidArguments_Throws()
[InlineData("claude")]
public void AsIChatClient_ReturnsInstance(string modelId)
{
- IAmazonBedrockRuntime runtime = new AmazonBedrockRuntimeClient("awsAccessKeyId", "awsSecretAccessKey", RegionEndpoint.USEast1);
- IChatClient client = runtime.AsIChatClient(modelId);
+ var mock = new MockBedrockRuntime();
+ IChatClient client = mock.AsIChatClient(modelId);
Assert.NotNull(client);
Assert.Equal("aws.bedrock", client.GetService()?.ProviderName);
@@ -31,17 +131,135 @@ public void AsIChatClient_ReturnsInstance(string modelId)
[Trait("UnitTest", "BedrockRuntime")]
public void AsIChatClient_GetService()
{
- IAmazonBedrockRuntime runtime = new AmazonBedrockRuntimeClient("awsAccessKeyId", "awsSecretAccessKey", RegionEndpoint.USEast1);
- IChatClient client = runtime.AsIChatClient();
+ var mock = new MockBedrockRuntime();
+ IChatClient client = mock.AsIChatClient();
- Assert.Same(runtime, client.GetService());
- Assert.Same(runtime, client.GetService());
+ Assert.Same(mock, client.GetService());
Assert.Same(client, client.GetService());
-
Assert.Null(client.GetService());
-
- Assert.Null(client.GetService("key"));
Assert.Null(client.GetService("key"));
- Assert.Null(client.GetService("key"));
}
+
+ #endregion
+
+ #region ResponseFormat Tests
+
+ [Fact]
+ [Trait("UnitTest", "BedrockRuntime")]
+ public async Task ResponseFormat_Json_WithSchema_CreatesSyntheticToolWithCorrectSchema()
+ {
+ // Arrange
+ var mock = new MockBedrockRuntime();
+ var client = mock.AsIChatClient("claude-3");
+ var messages = new[] { new ChatMessage(ChatRole.User, "Test") };
+
+ var schemaJson = """
+ {
+ "type": "object",
+ "properties": {
+ "name": { "type": "string" },
+ "age": { "type": "number" }
+ },
+ "required": ["name"]
+ }
+ """;
+ var schemaElement = JsonDocument.Parse(schemaJson).RootElement;
+ var options = new ChatOptions
+ {
+ ResponseFormat = ChatResponseFormat.ForJsonSchema(schemaElement,
+ schemaName: "PersonSchema",
+ schemaDescription: "A person object")
+ };
+
+ // Act
+ try
+ {
+ await client.GetResponseAsync(messages, options);
+ }
+ catch
+ {
+ // We're testing request creation
+ }
+
+ // Assert
+ var tool = mock.CapturedRequest.ToolConfig.Tools[0];
+ Assert.Equal("generate_response", tool.ToolSpec.Name);
+ Assert.Equal("A person object", tool.ToolSpec.Description);
+
+ // Verify schema structure matches input
+ var schema = tool.ToolSpec.InputSchema.Json;
+ Assert.True(schema.IsDictionary());
+ var schemaDict = schema.AsDictionary();
+
+ Assert.Equal("object", schemaDict["type"].AsString());
+ Assert.True(schemaDict.ContainsKey("properties"));
+
+ var properties = schemaDict["properties"].AsDictionary();
+ Assert.True(properties.ContainsKey("name"));
+ Assert.True(properties.ContainsKey("age"));
+ Assert.Equal("string", properties["name"].AsDictionary()["type"].AsString());
+ Assert.Equal("number", properties["age"].AsDictionary()["type"].AsString());
+
+ Assert.True(schemaDict.ContainsKey("required"));
+ var required = schemaDict["required"].AsList();
+ Assert.Single(required);
+ Assert.Equal("name", required[0].AsString());
+ }
+
+ [Fact]
+ [Trait("UnitTest", "BedrockRuntime")]
+ public async Task ResponseFormat_Json_ModelReturnsToolUse_ExtractsJsonCorrectly()
+ {
+ // Arrange
+ var mock = new MockBedrockRuntime();
+ var client = mock.AsIChatClient("claude-3");
+ var messages = new[] { new ChatMessage(ChatRole.User, "Get weather") };
+ var options = new ChatOptions { ResponseFormat = ChatResponseFormat.Json };
+
+ // Setup mock to return tool use with structured data
+ mock.ResponseFactory = req => new ConverseResponse
+ {
+ Output = new ConverseOutput
+ {
+ Message = new Message
+ {
+ Role = ConversationRole.Assistant,
+ Content = new List
+ {
+ new ContentBlock
+ {
+ ToolUse = new ToolUseBlock
+ {
+ ToolUseId = "test-id",
+ Name = "generate_response",
+ Input = new Document(new Dictionary
+ {
+ ["city"] = new Document("Seattle"),
+ ["temperature"] = new Document(72),
+ ["conditions"] = new Document("sunny")
+ })
+ }
+ }
+ }
+ }
+ },
+ StopReason = new StopReason("tool_use"),
+ Usage = new TokenUsage { InputTokens = 10, OutputTokens = 20, TotalTokens = 30 }
+ };
+
+ // Act
+ var response = await client.GetResponseAsync(messages, options);
+
+ // Assert
+ Assert.NotNull(response);
+ Assert.NotNull(response.Text);
+
+ // Parse the JSON to verify structure
+ var json = JsonDocument.Parse(response.Text);
+ Assert.Equal("Seattle", json.RootElement.GetProperty("city").GetString());
+ Assert.Equal(72, json.RootElement.GetProperty("temperature").GetInt32());
+ Assert.Equal("sunny", json.RootElement.GetProperty("conditions").GetString());
+ }
+
+ #endregion
}