diff --git a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs index 0479bc878232..4b3d423ac106 100644 --- a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs +++ b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs @@ -18,8 +18,10 @@ using Amazon.Runtime.Internal.Util; using Microsoft.Extensions.AI; using System; +using System.Buffers; using System.Collections.Generic; using System.Diagnostics; +using System.IO; using System.Linq; using System.Runtime.CompilerServices; using System.Text; @@ -35,6 +37,13 @@ internal sealed partial class BedrockChatClient : IChatClient /// A default logger to use. private static readonly ILogger DefaultLogger = Logger.GetLogger(typeof(BedrockChatClient)); + /// The name used for the synthetic tool that enforces response format. + private const string ResponseFormatToolName = "generate_response"; + /// The description used for the synthetic tool that enforces response format. + private const string ResponseFormatToolDescription = "Generate response in specified format"; + /// Maximum nesting depth for Document to JSON conversion to prevent stack overflow. + private const int MaxDocumentNestingDepth = 100; + /// The wrapped instance. private readonly IAmazonBedrockRuntime _runtime; /// Default model ID to use when no model is specified in the request. @@ -42,11 +51,7 @@ internal sealed partial class BedrockChatClient : IChatClient /// Metadata describing the chat client. private readonly ChatClientMetadata _metadata; - /// - /// Initializes a new instance of the class. - /// - /// The instance to wrap. - /// Model ID to use as the default when no model ID is specified in a request. + /// Initializes a new instance of the class. public BedrockChatClient(IAmazonBedrockRuntime runtime, string? 
defaultModelId) { Debug.Assert(runtime is not null); @@ -79,7 +84,34 @@ public async Task GetResponseAsync( request.InferenceConfig = CreateInferenceConfiguration(request.InferenceConfig, options); request.AdditionalModelRequestFields = CreateAdditionalModelRequestFields(request.AdditionalModelRequestFields, options); - var response = await _runtime.ConverseAsync(request, cancellationToken).ConfigureAwait(false); + // Execute the request with proper error handling for ResponseFormat scenarios + ConverseResponse response; + try + { + response = await _runtime.ConverseAsync(request, cancellationToken).ConfigureAwait(false); + } + catch (AmazonBedrockRuntimeException ex) when (options?.ResponseFormat is ChatResponseFormatJson) + { + // Check if this is a ToolChoice validation error (model doesn't support it) + bool isToolChoiceNotSupported = + ex.ErrorCode == "ValidationException" && + (ex.Message.IndexOf("toolChoice", StringComparison.OrdinalIgnoreCase) >= 0 || + ex.Message.IndexOf("tool_choice", StringComparison.OrdinalIgnoreCase) >= 0 || + ex.Message.IndexOf("ToolChoice", StringComparison.OrdinalIgnoreCase) >= 0); + + if (isToolChoiceNotSupported) + { + // Provide a more helpful error message when ToolChoice fails due to model limitations + throw new NotSupportedException( + $"The model '{request.ModelId}' does not support ResponseFormat. " + + $"ResponseFormat requires ToolChoice support, which is only available in Claude 3+ and Mistral Large models. 
" + + $"See: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html", + ex); + } + + // Re-throw other exceptions as-is + throw; + } ChatMessage result = new() { @@ -89,6 +121,50 @@ public async Task GetResponseAsync( MessageId = Guid.NewGuid().ToString("N"), }; + // Check if ResponseFormat was used and extract structured content + bool usingResponseFormat = options?.ResponseFormat is ChatResponseFormatJson; + if (usingResponseFormat) + { + var structuredContent = ExtractResponseFormatContent(response.Output?.Message); + if (structuredContent is not null) + { + // Replace the content with the extracted JSON as a TextContent + result.Contents.Add(new TextContent(structuredContent) { RawRepresentation = response.Output?.Message }); + + // Skip normal content processing since we've extracted the structured response + if (DocumentToDictionary(response.AdditionalModelResponseFields) is { } responseFieldsDict) + { + result.AdditionalProperties = new(responseFieldsDict); + } + + return new(result) + { + CreatedAt = result.CreatedAt, + FinishReason = response.StopReason is not null ? GetChatFinishReason(response.StopReason) : null, + Usage = response.Usage is TokenUsage tokenUsage ? CreateUsageDetails(tokenUsage) : null, + RawRepresentation = response, + }; + } + else + { + // User requested structured output but didn't get it - this is a contract violation + var errorMessage = string.Format( + "ResponseFormat was specified but model did not return expected tool use. ModelId: {0}, StopReason: {1}", + request.ModelId, + response.StopReason?.Value ?? "unknown"); + + DefaultLogger.Error(new InvalidOperationException(errorMessage), errorMessage); + + // Always throw when ResponseFormat was requested but not fulfilled + throw new InvalidOperationException( + $"Model '{request.ModelId}' did not return structured output as requested. 
" + + $"This may indicate the model refused to follow the tool use instruction, " + + $"the schema was too complex, or the prompt conflicted with the requirement. " + + $"StopReason: {response.StopReason?.Value ?? "unknown"}"); + } + } + + // Normal content processing when not using ResponseFormat or extraction failed if (response.Output?.Message?.Content is { } contents) { foreach (var content in contents) @@ -182,6 +258,14 @@ public async IAsyncEnumerable GetStreamingResponseAsync( throw new ArgumentNullException(nameof(messages)); } + // Check if ResponseFormat is set - not supported for streaming yet + if (options?.ResponseFormat is ChatResponseFormatJson) + { + throw new NotSupportedException( + "ResponseFormat is not yet supported for streaming responses with Amazon Bedrock. " + + "Please use GetResponseAsync for structured output."); + } + ConverseStreamRequest request = options?.RawRepresentationFactory?.Invoke(this) as ConverseStreamRequest ?? new(); request.ModelId ??= options?.ModelId ?? _modelId; request.Messages = CreateMessages(request.Messages, messages); @@ -792,7 +876,11 @@ private static Document ToDocument(JsonElement json) } } - /// Creates an from the specified options. + /// Creates a from the specified options. + /// + /// When ResponseFormat is specified, creates a synthetic tool to enforce structured output. + /// This conflicts with user-provided tools as Bedrock only supports a single ToolChoice value. + /// private static ToolConfiguration? CreateToolConfig(ToolConfiguration? toolConfig, ChatOptions? options) { if (options?.Tools is { Count: > 0 } tools) @@ -855,6 +943,56 @@ private static Document ToDocument(JsonElement json) } } + // Handle ResponseFormat by creating a synthetic tool + if (options?.ResponseFormat is ChatResponseFormatJson jsonFormat) + { + // Check for conflict with user-provided tools + if (toolConfig?.Tools?.Count > 0) + { + throw new ArgumentException( + "ResponseFormat cannot be used with Tools in Amazon Bedrock. 
/// <summary>
/// Finds the synthetic response-format tool call in a model message and returns its input as JSON text.
/// </summary>
/// <param name="message">The message returned by the model; may be <see langword="null"/>.</param>
/// <returns>
/// The JSON text of the first tool-use block whose name equals <c>ResponseFormatToolName</c> and whose
/// input is not the default <see cref="Document"/>; <see langword="null"/> when the message has no
/// content or no such block exists.
/// </returns>
private static string? ExtractResponseFormatContent(Message? message)
{
    if (message?.Content is not { } blocks)
    {
        return null;
    }

    foreach (var block in blocks)
    {
        var call = block.ToolUse;

        // Skip anything that is not a populated call to the synthetic tool.
        if (call is null || call.Name != ResponseFormatToolName || call.Input == default)
        {
            continue;
        }

        // First match wins: serialize the tool input Document back into JSON text.
        return DocumentToJsonString(call.Input);
    }

    return null;
}
/// <summary>Serializes a <see cref="Document"/> to compact (non-indented) JSON text.</summary>
/// <param name="document">The document to serialize.</param>
/// <returns>The JSON representation of <paramref name="document"/>.</returns>
/// <exception cref="InvalidOperationException">
/// The document nests deeper than <c>MaxDocumentNestingDepth</c>.
/// </exception>
private static string DocumentToJsonString(Document document)
{
    using var stream = new MemoryStream();
    using (var writer = new Utf8JsonWriter(stream, new JsonWriterOptions { Indented = false }))
    {
        WriteDocumentAsJson(writer, document);
    } // Disposing the writer here flushes all pending bytes into the stream before it is read.

    // Decode directly from the stream's backing buffer. ToArray() would allocate a second
    // full copy of the payload only to have it decoded and discarded.
    return Encoding.UTF8.GetString(stream.GetBuffer(), 0, checked((int)stream.Length));
}

/// <summary>Recursively writes a <see cref="Document"/> to <paramref name="writer"/> as JSON.</summary>
/// <param name="writer">The writer that receives the JSON tokens.</param>
/// <param name="document">The document (or sub-document) to write.</param>
/// <param name="depth">Current nesting level; 0 for the root, incremented for each dictionary or list level.</param>
/// <exception cref="InvalidOperationException">
/// <paramref name="depth"/> exceeds <c>MaxDocumentNestingDepth</c>.
/// </exception>
private static void WriteDocumentAsJson(Utf8JsonWriter writer, Document document, int depth = 0)
{
    // Bound the recursion so a pathologically nested payload cannot overflow the stack.
    if (depth > MaxDocumentNestingDepth)
    {
        throw new InvalidOperationException(
            $"Document nesting depth exceeds maximum of {MaxDocumentNestingDepth}. " +
            $"This may indicate a circular reference or excessively nested data structure.");
    }

    if (document.IsBool())
    {
        writer.WriteBooleanValue(document.AsBool());
    }
    else if (document.IsInt())
    {
        writer.WriteNumberValue(document.AsInt());
    }
    else if (document.IsLong())
    {
        writer.WriteNumberValue(document.AsLong());
    }
    else if (document.IsDouble())
    {
        writer.WriteNumberValue(document.AsDouble());
    }
    else if (document.IsString())
    {
        writer.WriteStringValue(document.AsString());
    }
    else if (document.IsDictionary())
    {
        writer.WriteStartObject();
        foreach (var kvp in document.AsDictionary())
        {
            writer.WritePropertyName(kvp.Key);
            WriteDocumentAsJson(writer, kvp.Value, depth + 1);
        }
        writer.WriteEndObject();
    }
    else if (document.IsList())
    {
        writer.WriteStartArray();
        foreach (var item in document.AsList())
        {
            WriteDocumentAsJson(writer, item, depth + 1);
        }
        writer.WriteEndArray();
    }
    else
    {
        // Anything that matched none of the checks above is emitted as JSON null.
        writer.WriteNullValue();
    }
}
/// <summary>
/// Test double for <see cref="IAmazonBedrockRuntime"/> that records the requests it receives and
/// returns a configurable response (or failure) from <see cref="ConverseAsync"/>.
/// </summary>
internal sealed class MockBedrockRuntime : IAmazonBedrockRuntime
{
    /// <summary>The last request passed to <see cref="ConverseAsync"/>; null until it is called.</summary>
    public ConverseRequest CapturedRequest { get; private set; }

    /// <summary>The last request passed to <see cref="ConverseStreamAsync"/>; null until it is called.</summary>
    public ConverseStreamRequest CapturedStreamRequest { get; private set; }

    /// <summary>Optional factory producing the response for a given request.</summary>
    public Func<ConverseRequest, ConverseResponse> ResponseFactory { get; set; }

    /// <summary>Optional failure to surface from <see cref="ConverseAsync"/>; takes precedence over <see cref="ResponseFactory"/>.</summary>
    public Exception ExceptionToThrow { get; set; }

    /// <summary>
    /// Records <paramref name="request"/> and completes with, in priority order: the configured
    /// exception, the factory's response, or a canned single-text-block assistant reply.
    /// </summary>
    public Task<ConverseResponse> ConverseAsync(ConverseRequest request, CancellationToken cancellationToken = default)
    {
        CapturedRequest = request;

        if (ExceptionToThrow != null)
        {
            // Report the failure through the returned task so it is observed at the await
            // rather than thrown synchronously from a Task-returning method.
            return Task.FromException<ConverseResponse>(ExceptionToThrow);
        }

        if (ResponseFactory != null)
        {
            return Task.FromResult(ResponseFactory(request));
        }

        // Default response: plain text, no tool use.
        return Task.FromResult(new ConverseResponse
        {
            Output = new ConverseOutput
            {
                Message = new Message
                {
                    Role = ConversationRole.Assistant,
                    Content = new List<ContentBlock>
                    {
                        new ContentBlock { Text = "Default response" }
                    }
                }
            },
            StopReason = new StopReason("end_turn")
        });
    }

    /// <summary>Records <paramref name="request"/>; streaming itself is not implemented by this mock.</summary>
    public Task<ConverseStreamResponse> ConverseStreamAsync(ConverseStreamRequest request, CancellationToken cancellationToken = default)
    {
        CapturedStreamRequest = request;
        return Task.FromException<ConverseStreamResponse>(
            new NotImplementedException("Stream testing not implemented in this mock"));
    }

    public void Dispose() { }

    // Unused interface members - all throw NotImplementedException
    public IBedrockRuntimePaginatorFactory Paginators => throw new NotImplementedException();
    public Amazon.Runtime.IClientConfig Config => throw new NotImplementedException();

    // Sync methods
    public ApplyGuardrailResponse ApplyGuardrail(ApplyGuardrailRequest request) => throw new NotImplementedException();
    public ConverseResponse Converse(ConverseRequest request) => throw new NotImplementedException();
    public ConverseStreamResponse ConverseStream(ConverseStreamRequest request) => throw new NotImplementedException();
    public CountTokensResponse CountTokens(CountTokensRequest request) => throw new NotImplementedException();
    public GetAsyncInvokeResponse GetAsyncInvoke(GetAsyncInvokeRequest request) => throw new NotImplementedException();
    public InvokeModelResponse InvokeModel(InvokeModelRequest request) => throw new NotImplementedException();
    public InvokeModelWithResponseStreamResponse InvokeModelWithResponseStream(InvokeModelWithResponseStreamRequest request) => throw new NotImplementedException();
    public ListAsyncInvokesResponse ListAsyncInvokes(ListAsyncInvokesRequest request) => throw new NotImplementedException();
    public StartAsyncInvokeResponse StartAsyncInvoke(StartAsyncInvokeRequest request) => throw new NotImplementedException();

    // Async methods
    public Task<ApplyGuardrailResponse> ApplyGuardrailAsync(ApplyGuardrailRequest request, CancellationToken cancellationToken = default) => throw new NotImplementedException();
    public Task<CountTokensResponse> CountTokensAsync(CountTokensRequest request, CancellationToken cancellationToken = default) => throw new NotImplementedException();
    public Task<GetAsyncInvokeResponse> GetAsyncInvokeAsync(GetAsyncInvokeRequest request, CancellationToken cancellationToken = default) => throw new NotImplementedException();
    public Task<InvokeModelResponse> InvokeModelAsync(InvokeModelRequest request, CancellationToken cancellationToken = default) => throw new NotImplementedException();
    public Task<InvokeModelWithResponseStreamResponse> InvokeModelWithResponseStreamAsync(InvokeModelWithResponseStreamRequest request, CancellationToken cancellationToken = default) => throw new NotImplementedException();
    public Task<ListAsyncInvokesResponse> ListAsyncInvokesAsync(ListAsyncInvokesRequest request, CancellationToken cancellationToken = default) => throw new NotImplementedException();
    public Task<StartAsyncInvokeResponse> StartAsyncInvokeAsync(StartAsyncInvokeRequest request, CancellationToken cancellationToken = default) => throw new NotImplementedException();

    // Endpoint determination
    public Amazon.Runtime.Endpoints.Endpoint DetermineServiceOperationEndpoint(Amazon.Runtime.AmazonWebServiceRequest request) => throw new NotImplementedException();
}

/// <summary>Minimal <see cref="AIFunctionDeclaration"/> whose name, description and JSON schema are fixed at construction.</summary>
internal sealed class TestAIFunction : AIFunctionDeclaration
{
    public TestAIFunction(string name, string description, JsonElement jsonSchema)
    {
        Name = name;
        Description = description;
        JsonSchema = jsonSchema;
    }

    public override string Name { get; }
    public override string Description { get; }
    public override JsonElement JsonSchema { get; }
}
/// <summary>
/// Verifies that a JSON-schema response format is sent to Bedrock as a synthetic tool named
/// "generate_response" carrying the caller's schema and description.
/// </summary>
[Fact]
[Trait("UnitTest", "BedrockRuntime")]
public async Task ResponseFormat_Json_WithSchema_CreatesSyntheticToolWithCorrectSchema()
{
    // Arrange
    var mock = new MockBedrockRuntime();
    var client = mock.AsIChatClient("claude-3");
    var messages = new[] { new ChatMessage(ChatRole.User, "Test") };

    var schemaJson = """
        {
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "number" }
            },
            "required": ["name"]
        }
        """;

    // Clone the root so the element stays valid independently of the pooled JsonDocument.
    using var schemaDocument = JsonDocument.Parse(schemaJson);
    var options = new ChatOptions
    {
        ResponseFormat = ChatResponseFormat.ForJsonSchema(
            schemaDocument.RootElement.Clone(),
            schemaName: "PersonSchema",
            schemaDescription: "A person object")
    };

    // Act
    // The mock's default reply is plain text with no tool use, which the client rejects with
    // InvalidOperationException. Asserting that exact type (instead of a bare catch) keeps
    // unrelated failures while building the request from being silently swallowed.
    await Assert.ThrowsAsync<InvalidOperationException>(
        () => client.GetResponseAsync(messages, options));

    // Assert
    var tool = mock.CapturedRequest.ToolConfig.Tools[0];
    Assert.Equal("generate_response", tool.ToolSpec.Name);
    Assert.Equal("A person object", tool.ToolSpec.Description);

    // Verify schema structure matches input
    var schema = tool.ToolSpec.InputSchema.Json;
    Assert.True(schema.IsDictionary());
    var schemaDict = schema.AsDictionary();

    Assert.Equal("object", schemaDict["type"].AsString());
    Assert.True(schemaDict.ContainsKey("properties"));

    var properties = schemaDict["properties"].AsDictionary();
    Assert.True(properties.ContainsKey("name"));
    Assert.True(properties.ContainsKey("age"));
    Assert.Equal("string", properties["name"].AsDictionary()["type"].AsString());
    Assert.Equal("number", properties["age"].AsDictionary()["type"].AsString());

    Assert.True(schemaDict.ContainsKey("required"));
    var required = schemaDict["required"].AsList();
    Assert.Single(required);
    Assert.Equal("name", required[0].AsString());
}

/// <summary>
/// Verifies that when the model answers with the synthetic tool call, the tool input is
/// surfaced as the response text in JSON form.
/// </summary>
[Fact]
[Trait("UnitTest", "BedrockRuntime")]
public async Task ResponseFormat_Json_ModelReturnsToolUse_ExtractsJsonCorrectly()
{
    // Arrange
    var mock = new MockBedrockRuntime();
    var client = mock.AsIChatClient("claude-3");
    var messages = new[] { new ChatMessage(ChatRole.User, "Get weather") };
    var options = new ChatOptions { ResponseFormat = ChatResponseFormat.Json };

    // Setup mock to return tool use with structured data
    mock.ResponseFactory = req => new ConverseResponse
    {
        Output = new ConverseOutput
        {
            Message = new Message
            {
                Role = ConversationRole.Assistant,
                Content = new List<ContentBlock>
                {
                    new ContentBlock
                    {
                        ToolUse = new ToolUseBlock
                        {
                            ToolUseId = "test-id",
                            Name = "generate_response",
                            Input = new Document(new Dictionary<string, Document>
                            {
                                ["city"] = new Document("Seattle"),
                                ["temperature"] = new Document(72),
                                ["conditions"] = new Document("sunny")
                            })
                        }
                    }
                }
            }
        },
        StopReason = new StopReason("tool_use"),
        Usage = new TokenUsage { InputTokens = 10, OutputTokens = 20, TotalTokens = 30 }
    };

    // Act
    var response = await client.GetResponseAsync(messages, options);

    // Assert
    Assert.NotNull(response);
    Assert.NotNull(response.Text);

    // Parse the JSON to verify structure; dispose the document to return its pooled buffers.
    using var json = JsonDocument.Parse(response.Text);
    Assert.Equal("Seattle", json.RootElement.GetProperty("city").GetString());
    Assert.Equal(72, json.RootElement.GetProperty("temperature").GetInt32());
    Assert.Equal("sunny", json.RootElement.GetProperty("conditions").GetString());
}