From 44e259304aaf8f01074fa4f85d464b0df709bb14 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Thu, 13 Nov 2025 08:51:00 +0200 Subject: [PATCH 01/10] chore: renaming fixtures and making tests more specific Signed-off-by: Danny Kopping --- bridge_integration_test.go | 160 ++++++++---------- .../{error.txtar => stream_error.txtar} | 12 +- .../{error.txtar => stream_error.txtar} | 10 +- 3 files changed, 72 insertions(+), 110 deletions(-) rename fixtures/anthropic/{error.txtar => stream_error.txtar} (85%) rename fixtures/openai/{error.txtar => stream_error.txtar} (89%) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 5bf519e..06e8a76 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -47,8 +47,8 @@ var ( antSingleInjectedTool []byte //go:embed fixtures/anthropic/fallthrough.txtar antFallthrough []byte - //go:embed fixtures/anthropic/error.txtar - antErr []byte + //go:embed fixtures/anthropic/stream_error.txtar + antMidstreamErr []byte //go:embed fixtures/openai/simple.txtar oaiSimple []byte @@ -58,8 +58,8 @@ var ( oaiSingleInjectedTool []byte //go:embed fixtures/openai/fallthrough.txtar oaiFallthrough []byte - //go:embed fixtures/openai/error.txtar - oaiErr []byte + //go:embed fixtures/openai/stream_error.txtar + oaiMidstreamErr []byte ) const ( @@ -1009,23 +1009,23 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu func TestErrorHandling(t *testing.T) { t.Parallel() - cases := []struct { - name string - fixture []byte - createRequestFunc createRequestFunc - configureFunc configureFunc - responseHandlerFn func(streaming bool, resp *http.Response) - }{ - { - name: aibridge.ProviderAnthropic, - fixture: antErr, - createRequestFunc: createAnthropicMessagesReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr) - }, - responseHandlerFn: func(streaming bool, resp *http.Response) { - if streaming { + t.Run("mid-stream error", func(t *testing.T) { + cases := []struct { + name string + fixture []byte + createRequestFunc createRequestFunc + configureFunc configureFunc + responseHandlerFn func(resp *http.Response) + }{ + { + name: aibridge.ProviderAnthropic, + fixture: antMidstreamErr, + createRequestFunc: createAnthropicMessagesReq, + configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr) + }, + responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. 
require.Equal(t, http.StatusOK, resp.StatusCode) @@ -1033,24 +1033,17 @@ func TestErrorHandling(t *testing.T) { require.NoError(t, sp.Parse(resp.Body)) require.Len(t, sp.EventsByType("error"), 1) require.Contains(t, sp.EventsByType("error")[0].Data, "Overloaded") - } else { - require.Equal(t, resp.StatusCode, http.StatusInternalServerError) - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Contains(t, string(body), "Overloaded") - } - }, - }, - { - name: aibridge.ProviderOpenAI, - fixture: oaiErr, - createRequestFunc: createOpenAIChatCompletionsReq, - configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr) + }, }, - responseHandlerFn: func(streaming bool, resp *http.Response) { - if streaming { + { + name: aibridge.ProviderOpenAI, + fixture: oaiMidstreamErr, + createRequestFunc: createOpenAIChatCompletionsReq, + configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr) + }, + responseHandlerFn: func(resp *http.Response) { // Server responds first with 200 OK then starts streaming. require.Equal(t, http.StatusOK, resp.StatusCode) @@ -1063,72 +1056,55 @@ func TestErrorHandling(t *testing.T) { errEvent := sp.MessageEvents()[len(sp.MessageEvents())-2] // Last event is termination marker ("[DONE]"). require.NotEmpty(t, errEvent) require.Contains(t, errEvent.Data, "The server had an error while processing your request. Sorry about that!") - } else { - require.Equal(t, resp.StatusCode, http.StatusInternalServerError) - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Contains(t, string(body), "The server had an error while processing your request. Sorry about that") - } + }, }, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - for _, streaming := range []bool{true, false} { - t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { - t.Parallel() + } - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - arc := txtar.Parse(tc.fixture) - t.Logf("%s: %s", t.Name(), arc.Comment) + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) - files := filesMap(arc) - require.Len(t, files, 3) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureNonStreamingResponse) + arc := txtar.Parse(tc.fixture) + t.Logf("%s: %s", t.Name(), arc.Comment) - reqBody := files[fixtureRequest] + files := filesMap(arc) + require.Len(t, files, 2) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureStreamingResponse) - // Add the stream param to the request. 
- newBody, err := setJSON(reqBody, "stream", streaming) - require.NoError(t, err) - reqBody = newBody + reqBody := files[fixtureRequest] - // Setup mock server. - mockSrv := newMockServer(ctx, t, files, nil) - mockSrv.statusCode = http.StatusInternalServerError - t.Cleanup(mockSrv.Close) + // Setup mock server. + mockSrv := newMockServer(ctx, t, files, nil) + mockSrv.statusCode = http.StatusInternalServerError + t.Cleanup(mockSrv.Close) - recorderClient := &mockRecorderClient{} + recorderClient := &mockRecorderClient{} - b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil)) - require.NoError(t, err) + b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil)) + require.NoError(t, err) - // Invoke request to mocked API via aibridge. - bridgeSrv := httptest.NewUnstartedServer(b) - bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, userID, nil) - } - bridgeSrv.Start() - t.Cleanup(bridgeSrv.Close) + // Invoke request to mocked API via aibridge. + bridgeSrv := httptest.NewUnstartedServer(b) + bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, userID, nil) + } + bridgeSrv.Start() + t.Cleanup(bridgeSrv.Close) - req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody) - resp, err := http.DefaultClient.Do(req) - t.Cleanup(func() { _ = resp.Body.Close() }) - require.NoError(t, err) + req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody) + resp, err := http.DefaultClient.Do(req) + t.Cleanup(func() { _ = resp.Body.Close() }) + require.NoError(t, err) - tc.responseHandlerFn(streaming, resp) - recorderClient.verifyAllInterceptionsEnded(t) - }) - } - }) - } + tc.responseHandlerFn(resp) + recorderClient.verifyAllInterceptionsEnded(t) + }) + } + }) } // TestStableRequestEncoding validates that a given intercepted request and a diff --git a/fixtures/anthropic/error.txtar b/fixtures/anthropic/stream_error.txtar similarity index 85% rename from fixtures/anthropic/error.txtar rename to fixtures/anthropic/stream_error.txtar index 81ed89d..8b63444 100644 --- a/fixtures/anthropic/error.txtar +++ b/fixtures/anthropic/stream_error.txtar @@ -15,7 +15,8 @@ Simple request + error. } ], "model": "claude-sonnet-4-0", - "temperature": 1 + "temperature": 1, + "stream": true } -- streaming -- @@ -31,12 +32,3 @@ data: {"type": "ping"} event: error data: {"type": "error", "error": {"type": "api_error", "message": "Overloaded"}} --- non-streaming -- -{ - "type": "error", - "error": { - "type": "api_error", - "message": "Overloaded" - }, - "request_id": null -} \ No newline at end of file diff --git a/fixtures/openai/error.txtar b/fixtures/openai/stream_error.txtar similarity index 89% rename from fixtures/openai/error.txtar rename to fixtures/openai/stream_error.txtar index 8e9efae..678800b 100644 --- a/fixtures/openai/error.txtar +++ b/fixtures/openai/stream_error.txtar @@ -8,7 +8,8 @@ Simple request + error. "content": "how many angels can dance on the head of a pin\n" } ], - "model": "gpt-4.1" + "model": "gpt-4.1", + "stream": true } -- streaming -- @@ -22,10 +23,3 @@ data: {"id":"chatcmpl-BwoiPTGRbKkY5rncfaM0s9KtWrq5N","object":"chat.completion.c data: {"error": {"message": "The server had an error while processing your request. Sorry about that!", "type": "server_error"}} --- non-streaming -- -{ - "error": { - "message": "The server had an error while processing your request. 
Sorry about that!", - "type": "server_error" - } -} \ No newline at end of file From 1570e2aa21f415b16dd99fc564d8d0b287b73c44 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Thu, 13 Nov 2025 11:42:52 +0200 Subject: [PATCH 02/10] chore: add test for upstream errs that occur before stream starts (anthropic only) Signed-off-by: Danny Kopping --- bridge_integration_test.go | 138 +++++++++++++++++++++- fixtures/anthropic/non_stream_error.txtar | 49 ++++++++ go.mod | 2 +- intercept_anthropic_messages_base.go | 26 ++++ intercept_anthropic_messages_blocking.go | 10 +- intercept_anthropic_messages_streaming.go | 61 +++++----- streaming.go | 30 ++++- 7 files changed, 276 insertions(+), 40 deletions(-) create mode 100644 fixtures/anthropic/non_stream_error.txtar diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 06e8a76..83d4029 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -30,6 +30,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" "github.com/openai/openai-go/v2" oaissestream "github.com/openai/openai-go/v2/packages/ssestream" @@ -48,7 +49,9 @@ var ( //go:embed fixtures/anthropic/fallthrough.txtar antFallthrough []byte //go:embed fixtures/anthropic/stream_error.txtar - antMidstreamErr []byte + antMidStreamErr []byte + //go:embed fixtures/anthropic/non_stream_error.txtar + antNonStreamErr []byte //go:embed fixtures/openai/simple.txtar oaiSimple []byte @@ -59,7 +62,7 @@ var ( //go:embed fixtures/openai/fallthrough.txtar oaiFallthrough []byte //go:embed fixtures/openai/stream_error.txtar - oaiMidstreamErr []byte + oaiMidStreamErr []byte ) const ( @@ -1009,6 +1012,95 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu func TestErrorHandling(t *testing.T) { t.Parallel() + // Tests that errors which occur *before* a streaming response begins, or in non-streaming requests, are handled as expected. 
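+	// These fixtures embed complete HTTP responses (status line, headers and body),
+	// which newMockHTTPReflector replays verbatim; the upstream's status code and
+	// JSON error body should therefore reach the caller unchanged.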
+	t.Run("non-stream error", func(t *testing.T) {
+		cases := []struct {
+			name              string
+			fixture           []byte
+			createRequestFunc createRequestFunc
+			configureFunc     configureFunc
+			responseHandlerFn func(resp *http.Response)
+		}{
+			{
+				name:              aibridge.ProviderAnthropic,
+				fixture:           antNonStreamErr,
+				createRequestFunc: createAnthropicMessagesReq,
+				configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
+					logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
+					return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr)
+				},
+				responseHandlerFn: func(resp *http.Response) {
+					require.Equal(t, http.StatusBadRequest, resp.StatusCode)
+					body, err := io.ReadAll(resp.Body)
+					require.NoError(t, err)
+					require.Equal(t, "error", gjson.GetBytes(body, "type").Str)
+					require.Equal(t, "invalid_request_error", gjson.GetBytes(body, "error.type").Str)
+					require.Contains(t, gjson.GetBytes(body, "error.message").Str, "prompt is too long")
+				},
+			},
+		}
+
+		for _, tc := range cases {
+			t.Run(tc.name, func(t *testing.T) {
+				t.Parallel()
+
+				for _, streaming := range []bool{true, false} {
+					t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) {
+						t.Parallel()
+
+						ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
+						t.Cleanup(cancel)
+
+						arc := txtar.Parse(tc.fixture)
+						t.Logf("%s: %s", t.Name(), arc.Comment)
+
+						files := filesMap(arc)
+						require.Len(t, files, 3)
+						require.Contains(t, files, fixtureRequest)
+						require.Contains(t, files, fixtureStreamingResponse)
+						require.Contains(t, files, fixtureNonStreamingResponse)
+
+						reqBody := files[fixtureRequest]
+						// Add the stream param to the request.
+						newBody, err := setJSON(reqBody, "stream", streaming)
+						require.NoError(t, err)
+						reqBody = newBody
+
+						// Setup mock server.
+						mockResp := files[fixtureStreamingResponse]
+						if !streaming {
+							mockResp = files[fixtureNonStreamingResponse]
+						}
+						mockSrv := newMockHTTPReflector(ctx, t, mockResp)
+						t.Cleanup(mockSrv.Close)
+
+						recorderClient := &mockRecorderClient{}
+
+						b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil))
+						require.NoError(t, err)
+
+						// Invoke request to mocked API via aibridge.
+						bridgeSrv := httptest.NewUnstartedServer(b)
+						bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
+							return aibridge.AsActor(ctx, userID, nil)
+						}
+						bridgeSrv.Start()
+						t.Cleanup(bridgeSrv.Close)
+
+						req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody)
+						resp, err := http.DefaultClient.Do(req)
+						// Check the error before registering the cleanup: resp is nil if Do failed.
+						require.NoError(t, err)
+						t.Cleanup(func() { _ = resp.Body.Close() })
+
+						tc.responseHandlerFn(resp)
+						recorderClient.verifyAllInterceptionsEnded(t)
+					})
+				}
+			})
+		}
+	})
+
+	// Tests that errors which occur *during* a streaming response are handled as expected. 
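+	// By this point the 200 status and SSE headers are already on the wire, so the
+	// failure can only be surfaced as a terminal error event within the stream itself.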
t.Run("mid-stream error", func(t *testing.T) { cases := []struct { name string @@ -1019,7 +1111,7 @@ func TestErrorHandling(t *testing.T) { }{ { name: aibridge.ProviderAnthropic, - fixture: antMidstreamErr, + fixture: antMidStreamErr, createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) @@ -1037,7 +1129,7 @@ func TestErrorHandling(t *testing.T) { }, { name: aibridge.ProviderOpenAI, - fixture: oaiMidstreamErr, + fixture: oaiMidStreamErr, createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) @@ -1273,6 +1365,44 @@ func createOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte) return req } +type mockHTTPReflector struct { + *httptest.Server +} + +func newMockHTTPReflector(ctx context.Context, t *testing.T, resp []byte) *mockHTTPReflector { + ref := &mockHTTPReflector{} + + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mock, err := http.ReadResponse(bufio.NewReader(bytes.NewBuffer(resp)), r) + require.NoError(t, err) + defer mock.Body.Close() + + // Copy headers from the mocked response. + for key, values := range mock.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + + // Write the status code. + w.WriteHeader(mock.StatusCode) + + // Copy the body. + _, err = io.Copy(w, mock.Body) + require.NoError(t, err) + })) + srv.Config.BaseContext = func(_ net.Listener) context.Context { + return ctx + } + + srv.Start() + t.Cleanup(srv.Close) + + ref.Server = srv + return ref +} + +// TODO: replace this with mockHTTPReflector. type mockServer struct { *httptest.Server diff --git a/fixtures/anthropic/non_stream_error.txtar b/fixtures/anthropic/non_stream_error.txtar new file mode 100644 index 0000000..9ef332a --- /dev/null +++ b/fixtures/anthropic/non_stream_error.txtar @@ -0,0 +1,49 @@ +Simple request + error which occurs before streaming begins (where applicable). 
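+The streaming/non-streaming sections below are complete HTTP responses (status line,
+headers and body) which the test's reflector server replays verbatim.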
+ +-- request -- +{ + "max_tokens": 8192, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "yo" + } + ] + } + ], + "model": "claude-sonnet-4-0", + "temperature": 1 +} + +-- streaming -- +HTTP/2.0 400 Bad Request +Content-Length: 164 +Content-Type: application/json +Date: Thu, 13 Nov 2025 05:03:44 GMT +Request-Id: req_011CV5Jab6gR3ZNs9Sj6apiD +Server: cloudflare +Strict-Transport-Security: max-age=31536000; includeSubDomains; preload +X-Envoy-Upstream-Service-Time: 984 +X-Robots-Tag: none +X-Should-Retry: false + +{"type":"error","error":{"type":"invalid_request_error","message":"prompt is too long: 205429 tokens > 200000 maximum"},"request_id":"req_011CV5Jab6gR3ZNs9Sj6apiD"} + + +-- non-streaming -- +HTTP/2.0 400 Bad Request +Content-Length: 164 +Content-Type: application/json +Date: Thu, 13 Nov 2025 05:03:44 GMT +Request-Id: req_011CV5Jab6gR3ZNs9Sj6apiD +Server: cloudflare +Strict-Transport-Security: max-age=31536000; includeSubDomains; preload +X-Envoy-Upstream-Service-Time: 984 +X-Robots-Tag: none +X-Should-Retry: false + +{"type":"error","error":{"type":"invalid_request_error","message":"prompt is too long: 205429 tokens > 200000 maximum"},"request_id":"req_011CV5Jab6gR3ZNs9Sj6apiD"} + diff --git a/go.mod b/go.mod index a0e1ffa..827a224 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/mark3labs/mcp-go v0.38.0 github.com/stretchr/testify v1.10.0 - github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 go.uber.org/goleak v1.3.0 go.uber.org/mock v0.6.0 diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go index 8459618..d58b8c5 100644 --- a/intercept_anthropic_messages_base.go +++ b/intercept_anthropic_messages_base.go @@ -2,6 +2,7 @@ package aibridge import ( "context" + "encoding/json" "fmt" "net/http" "net/url" @@ -173,6 +174,31 @@ func (i *AnthropicMessagesInterceptionBase) augmentRequestForBedrock() { i.req.MessageNewParams.Model = anthropic.Model(i.Model()) } +// writeUpstreamError marshals and writes a given error. +func (i *AnthropicMessagesInterceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *AnthropicErrorResponse) { + if antErr == nil { + return + } + + w.WriteHeader(antErr.StatusCode) + out, err := json.Marshal(antErr) + if err != nil { + i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", slog.F("%+v", antErr))) + // Response has to match expected format. + // See https://docs.claude.com/en/api/errors#error-shapes. + _, _ = w.Write([]byte(fmt.Sprintf(`{ + "type":"error", + "error": { + "type": "error", + "message":"error marshaling upstream error" + }, + "request_id": "%s", +}`, i.ID().String()))) + } else { + _, _ = w.Write(out) + } +} + // redirectTransport is an HTTP RoundTripper that redirects requests to a different endpoint. // This is useful for testing when we need to redirect AWS Bedrock requests to a mock server. 
type redirectTransport struct { diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index cc2daaf..f978113 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -76,19 +76,17 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr resp, err = client.Messages.New(ctx, messages) if err != nil { if isConnError(err) { - logger.Warn(ctx, "upstream connection closed", slog.Error(err)) + // Can't write a response, just error out. return fmt.Errorf("upstream connection closed: %w", err) } - logger.Warn(ctx, "anthropic API error", slog.Error(err)) if antErr := getAnthropicErrorResponse(err); antErr != nil { - http.Error(w, antErr.Error(), antErr.StatusCode) - return fmt.Errorf("api error: %w", err) + i.writeUpstreamError(w, antErr) + return fmt.Errorf("anthropic API error: %w", err) } - logger.Warn(ctx, "upstream API error", slog.Error(err)) http.Error(w, "internal error", http.StatusInternalServerError) - return fmt.Errorf("upstream API error: %w", err) + return fmt.Errorf("internal error: %w", err) } if prompt != nil { diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index 403d1e6..f3cbf91 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -97,7 +97,6 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW // events will either terminate when shutdown after interaction with upstream completes, or when streamCtx is done. events := newEventStream(streamCtx, logger.Named("sse-sender"), i.pingPayload()) - go events.run(w, r) defer func() { _ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes. }() @@ -126,6 +125,9 @@ newStream: pendingToolCalls := make(map[string]string) for stream.Next() { + // Only start the event stream if the upstream starts streaming (as opposed to erroring out prematurely). + go events.run(w, r) + event := stream.Current() if err := message.Accumulate(event); err != nil { logger.Warn(ctx, "failed to accumulate streaming events", slog.Error(err), slog.F("event", event), slog.F("msg", message.RawJSON())) @@ -414,35 +416,40 @@ newStream: prompt = nil } - // Check if the stream encountered any errors. - if streamErr := stream.Err(); streamErr != nil { - if isUnrecoverableError(streamErr) { - logger.Debug(ctx, "stream terminated", slog.Error(streamErr)) - // We can't reflect an error back if there's a connection error or the request context was canceled. - } else if antErr := getAnthropicErrorResponse(streamErr); antErr != nil { - logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr)) - interceptionErr = fmt.Errorf("stream error: %w", antErr) - } else { - logger.Warn(ctx, "unknown error", slog.Error(streamErr)) - // Unfortunately, the Anthropic SDK does not support parsing errors received in the stream - // into known types (i.e. [shared.OverloadedError]). - // See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174 - // All it does is wrap the payload in an error - which is all we can return, currently. - interceptionErr = newAnthropicErr(fmt.Errorf("unknown stream error: %w", streamErr)) + if events.isRunning() { + // Check if the stream encountered any errors. 
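+		// If the run loop is active, the 200 and SSE headers are already on the wire,
+		// so failures must be relayed as in-stream error events; the else branch below
+		// can still write a plain HTTP error response instead.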
+ if streamErr := stream.Err(); streamErr != nil { + if isUnrecoverableError(streamErr) { + logger.Debug(ctx, "stream terminated", slog.Error(streamErr)) + // We can't reflect an error back if there's a connection error or the request context was canceled. + } else if antErr := getAnthropicErrorResponse(streamErr); antErr != nil { + logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr)) + interceptionErr = antErr + } else { + logger.Warn(ctx, "unknown error", slog.Error(streamErr)) + // Unfortunately, the Anthropic SDK does not support parsing errors received in the stream + // into known types (i.e. [shared.OverloadedError]). + // See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174 + // All it does is wrap the payload in an error - which is all we can return, currently. + interceptionErr = newAnthropicErr(fmt.Errorf("unknown stream error: %w", streamErr)) + } + } else if lastErr != nil { + // Otherwise check if any logical errors occurred during processing. + logger.Warn(ctx, "stream failed", slog.Error(lastErr)) + interceptionErr = newAnthropicErr(fmt.Errorf("processing error: %w", lastErr)) } - } else if lastErr != nil { - // Otherwise check if any logical errors occurred during processing. - logger.Warn(ctx, "stream failed", slog.Error(lastErr)) - interceptionErr = newAnthropicErr(fmt.Errorf("processing error: %w", lastErr)) - } - if interceptionErr != nil { - payload, err := i.marshal(interceptionErr) - if err != nil { - logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr))) - } else if err := events.Send(streamCtx, payload); err != nil { - logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload)) + if interceptionErr != nil { + payload, err := i.marshal(interceptionErr) + if err != nil { + logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr))) + } else if err := events.Send(streamCtx, payload); err != nil { + logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload)) + } } + } else { + // Stream has not started yet; write to response if present. + i.writeUpstreamError(w, getAnthropicErrorResponse(stream.Err())) } shutdownCtx, shutdownCancel := context.WithTimeout(ctx, time.Second*30) diff --git a/streaming.go b/streaming.go index e6fe72d..15950df 100644 --- a/streaming.go +++ b/streaming.go @@ -9,6 +9,7 @@ import ( "net/http" "strings" "sync" + "sync/atomic" "syscall" "time" @@ -27,6 +28,8 @@ type eventStream struct { pingPayload []byte + running atomic.Bool + closeOnce sync.Once shutdownOnce sync.Once eventsCh chan event @@ -50,8 +53,18 @@ func newEventStream(ctx context.Context, logger slog.Logger, pingPayload []byte) // run handles sending Server-Sent Event to the client. func (s *eventStream) run(w http.ResponseWriter, r *http.Request) { - // Signal completion on exit so senders don't block indefinitely after closure. - defer close(s.doneCh) + // Only one instance is allowed to run. + if s.running.Load() { + return + } + + s.running.Store(true) + defer func() { + // Signal completion on exit so senders don't block indefinitely after closure. + close(s.doneCh) + + s.running.Store(false) + }() ctx := r.Context() @@ -147,6 +160,10 @@ func (s *eventStream) sendRaw(ctx context.Context, payload []byte) error { // Shutdown gracefully shuts down the stream, sending any supplementary events downstream if required. 
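+// Shutdown closes the internal events channel, so a Send issued afterwards would
+// panic on a closed channel.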
// ONLY call this once all events have been submitted. func (s *eventStream) Shutdown(shutdownCtx context.Context) error { + defer func() { + s.running.Store(false) + }() + s.shutdownOnce.Do(func() { s.logger.Debug(shutdownCtx, "shutdown initiated", slog.F("outstanding_events", len(s.eventsCh))) @@ -155,6 +172,11 @@ func (s *eventStream) Shutdown(shutdownCtx context.Context) error { close(s.eventsCh) }) + // TODO: consider the safety of this approach. + if !s.running.Load() { + return nil + } + select { case <-shutdownCtx.Done(): // If shutdownCtx completes, shutdown likely exceeded its timeout. @@ -166,6 +188,10 @@ func (s *eventStream) Shutdown(shutdownCtx context.Context) error { } } +func (s *eventStream) isRunning() bool { + return s.running.Load() +} + // isConnError checks if an error is related to client disconnection or context cancellation. func isConnError(err error) bool { if err == nil { From 2b5ba45e355fe94ea4b9049b9fc6fe532c8a07fa Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Thu, 13 Nov 2025 14:10:29 +0200 Subject: [PATCH 03/10] chore: implement openai error handling Signed-off-by: Danny Kopping --- bridge_integration_test.go | 19 +++++++ fixtures/anthropic/non_stream_error.txtar | 14 ----- fixtures/openai/non_stream_error.txtar | 43 ++++++++++++++++ intercept_anthropic_messages_base.go | 2 + intercept_openai_chat_base.go | 35 +++++++++++++ intercept_openai_chat_blocking.go | 12 ++--- intercept_openai_chat_streaming.go | 63 +++++++++++++---------- openai.go | 50 ++++++------------ provider_openai.go | 10 ---- 9 files changed, 154 insertions(+), 94 deletions(-) create mode 100644 fixtures/openai/non_stream_error.txtar diff --git a/bridge_integration_test.go b/bridge_integration_test.go index 83d4029..b9b2745 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -63,6 +63,8 @@ var ( oaiFallthrough []byte //go:embed fixtures/openai/stream_error.txtar oaiMidStreamErr []byte + //go:embed fixtures/openai/non_stream_error.txtar + oaiNonStreamErr []byte ) const ( @@ -1038,6 +1040,23 @@ func TestErrorHandling(t *testing.T) { require.Contains(t, gjson.GetBytes(body, "error.message").Str, "prompt is too long") }, }, + { + name: aibridge.ProviderOpenAI, + fixture: oaiNonStreamErr, + createRequestFunc: createOpenAIChatCompletionsReq, + configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr) + }, + responseHandlerFn: func(resp *http.Response) { + require.Equal(t, http.StatusBadRequest, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "context_length_exceeded", gjson.GetBytes(body, "error.code").Str) + require.Equal(t, "invalid_request_error", gjson.GetBytes(body, "error.type").Str) + require.Contains(t, gjson.GetBytes(body, "error.message").Str, "Input tokens exceed the configured limit") + }, + }, } for _, tc := range cases { diff --git a/fixtures/anthropic/non_stream_error.txtar b/fixtures/anthropic/non_stream_error.txtar index 9ef332a..76a9347 100644 --- a/fixtures/anthropic/non_stream_error.txtar +++ b/fixtures/anthropic/non_stream_error.txtar @@ -22,13 +22,6 @@ Simple request + error which occurs before streaming begins (where applicable). 
HTTP/2.0 400 Bad Request Content-Length: 164 Content-Type: application/json -Date: Thu, 13 Nov 2025 05:03:44 GMT -Request-Id: req_011CV5Jab6gR3ZNs9Sj6apiD -Server: cloudflare -Strict-Transport-Security: max-age=31536000; includeSubDomains; preload -X-Envoy-Upstream-Service-Time: 984 -X-Robots-Tag: none -X-Should-Retry: false {"type":"error","error":{"type":"invalid_request_error","message":"prompt is too long: 205429 tokens > 200000 maximum"},"request_id":"req_011CV5Jab6gR3ZNs9Sj6apiD"} @@ -37,13 +30,6 @@ X-Should-Retry: false HTTP/2.0 400 Bad Request Content-Length: 164 Content-Type: application/json -Date: Thu, 13 Nov 2025 05:03:44 GMT -Request-Id: req_011CV5Jab6gR3ZNs9Sj6apiD -Server: cloudflare -Strict-Transport-Security: max-age=31536000; includeSubDomains; preload -X-Envoy-Upstream-Service-Time: 984 -X-Robots-Tag: none -X-Should-Retry: false {"type":"error","error":{"type":"invalid_request_error","message":"prompt is too long: 205429 tokens > 200000 maximum"},"request_id":"req_011CV5Jab6gR3ZNs9Sj6apiD"} diff --git a/fixtures/openai/non_stream_error.txtar b/fixtures/openai/non_stream_error.txtar new file mode 100644 index 0000000..e84ce09 --- /dev/null +++ b/fixtures/openai/non_stream_error.txtar @@ -0,0 +1,43 @@ +Simple request + error which occurs before streaming begins (where applicable). + +-- request -- +{ + "messages": [ + { + "role": "user", + "content": "how many angels can dance on the head of a pin\n" + } + ], + "model": "gpt-4.1", + "stream": true +} + +-- streaming -- +HTTP/2.0 400 Bad Request +Content-Length: 281 +Content-Type: application/json + +{ + "error": { + "message": "Input tokens exceed the configured limit of 272000 tokens. Your messages resulted in 3148588 tokens. Please reduce the length of the messages.", + "type": "invalid_request_error", + "param": "messages", + "code": "context_length_exceeded" + } +} + + +-- non-streaming -- +HTTP/2.0 400 Bad Request +Content-Length: 281 +Content-Type: application/json + +{ + "error": { + "message": "Input tokens exceed the configured limit of 272000 tokens. Your messages resulted in 3148588 tokens. 
Please reduce the length of the messages.",
+    "type": "invalid_request_error",
+    "param": "messages",
+    "code": "context_length_exceeded"
+  }
+}
+
diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go
index d58b8c5..0215d6c 100644
--- a/intercept_anthropic_messages_base.go
+++ b/intercept_anthropic_messages_base.go
@@ -181,6 +181,8 @@ func (i *AnthropicMessagesInterceptionBase) writeUpstreamError(w http.ResponseWr
 	}
 
 	w.WriteHeader(antErr.StatusCode)
+	w.Header().Set("Content-Type", "application/json")
+
 	out, err := json.Marshal(antErr)
 	if err != nil {
 		i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", slog.F("%+v", antErr)))
diff --git a/intercept_openai_chat_base.go b/intercept_openai_chat_base.go
index 36b8ff0..1fa3a8b 100644
--- a/intercept_openai_chat_base.go
+++ b/intercept_openai_chat_base.go
@@ -3,11 +3,13 @@ package aibridge
 import (
 	"context"
 	"encoding/json"
+	"net/http"
 	"strings"
 
 	"github.com/coder/aibridge/mcp"
 	"github.com/google/uuid"
 	"github.com/openai/openai-go/v2"
+	"github.com/openai/openai-go/v2/option"
 	"github.com/openai/openai-go/v2/shared"
 
 	"cdr.dev/slog"
@@ -24,6 +26,14 @@ type OpenAIChatInterceptionBase struct {
 	mcpProxy mcp.ServerProxier
 }
 
+func (i *OpenAIChatInterceptionBase) newOpenAIClient(baseURL, key string) openai.Client {
+	var opts []option.RequestOption
+	opts = append(opts, option.WithAPIKey(key))
+	opts = append(opts, option.WithBaseURL(baseURL))
+
+	return openai.NewClient(opts...)
+}
+
 func (i *OpenAIChatInterceptionBase) ID() uuid.UUID {
 	return i.id
 }
@@ -92,3 +102,28 @@ func (i *OpenAIChatInterceptionBase) unmarshalArgs(in string) (args ToolArgs) {
 
 	return args
 }
+
+// writeUpstreamError marshals and writes a given error.
+func (i *OpenAIChatInterceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *OpenAIErrorResponse) {
+	if oaiErr == nil {
+		return
+	}
+
+	w.WriteHeader(oaiErr.StatusCode)
+	w.Header().Set("Content-Type", "application/json")
+
+	out, err := json.Marshal(oaiErr)
+	if err != nil {
+		i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", slog.F("%+v", oaiErr)))
+		// Response has to match expected format.
+		_, _ = w.Write([]byte(`{
+	"error": {
+		"type": "error",
+		"message": "error marshaling upstream error",
+		"code": "server_error"
+	}
+}`))
+	} else {
+		_, _ = w.Write(out)
+	}
+}
diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go
index 3b1fa7e..4c019a6 100644
--- a/intercept_openai_chat_blocking.go
+++ b/intercept_openai_chat_blocking.go
@@ -3,7 +3,6 @@ package aibridge
 import (
 	"bytes"
 	"encoding/json"
-	"errors"
 	"fmt"
 	"net/http"
 	"strings"
@@ -42,7 +41,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
 	}
 
 	ctx := r.Context()
-	client := newOpenAIClient(i.baseURL, i.key)
+	client := i.newOpenAIClient(i.baseURL, i.key)
 	logger := i.logger.With(slog.F("model", i.req.Model))
 
 	var (
@@ -184,18 +183,15 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
 		}
 	}
 
-	// TODO: these probably have to be formatted as JSON errs? 
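+	// Recognised upstream API errors are now rendered in OpenAI's JSON error shape via writeUpstreamError below.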
if err != nil { if isConnError(err) { http.Error(w, err.Error(), http.StatusInternalServerError) return fmt.Errorf("upstream connection closed: %w", err) } - logger.Warn(ctx, "openai API error", slog.Error(err)) - var apierr *openai.Error - if errors.As(err, &apierr) { - http.Error(w, apierr.Message, apierr.StatusCode) - return fmt.Errorf("api error: %w", apierr) + if apiErr := getOpenAIErrorResponse(err); apiErr != nil { + i.writeUpstreamError(w, apiErr) + return fmt.Errorf("openai API error: %w", err) } http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index 0c5f554..50c6bbd 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -65,7 +65,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. - client := newOpenAIClient(i.baseURL, i.key) + client := i.newOpenAIClient(i.baseURL, i.key) logger := i.logger.With(slog.F("model", i.req.Model)) streamCtx, streamCancel := context.WithCancelCause(ctx) @@ -73,7 +73,6 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, // events will either terminate when shutdown after interaction with upstream completes, or when streamCtx is done. events := newEventStream(streamCtx, logger.Named("sse-sender"), nil) - go events.run(w, r) defer func() { _ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes. }() @@ -106,6 +105,9 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, var toolCall *openai.FinishedChatCompletionToolCall for stream.Next() { + // Only start the event stream if the upstream starts streaming (as opposed to erroring out prematurely). + go events.run(w, r) + chunk := stream.Current() canRelay := processor.process(chunk) @@ -172,35 +174,40 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, }) } - // Check if the stream encountered any errors. - if streamErr := stream.Err(); streamErr != nil { - if isUnrecoverableError(streamErr) { - logger.Debug(ctx, "stream terminated", slog.Error(streamErr)) - // We can't reflect an error back if there's a connection error or the request context was canceled. - } else if oaiErr := getOpenAIErrorResponse(streamErr); oaiErr != nil { - logger.Warn(ctx, "openai stream error", slog.Error(streamErr)) - interceptionErr = oaiErr - } else { - logger.Warn(ctx, "unknown error", slog.Error(streamErr)) - // Unfortunately, the OpenAI SDK does not support parsing errors received in the stream - // into known types (i.e. [shared.OverloadedError]). - // See https://github.com/openai/openai-go/blob/v2.7.0/packages/ssestream/ssestream.go#L171 - // All it does is wrap the payload in an error - which is all we can return, currently. - interceptionErr = newOpenAIErr(fmt.Errorf("unknown stream error: %w", streamErr)) + if events.isRunning() { + // Check if the stream encountered any errors. + if streamErr := stream.Err(); streamErr != nil { + if isUnrecoverableError(streamErr) { + logger.Debug(ctx, "stream terminated", slog.Error(streamErr)) + // We can't reflect an error back if there's a connection error or the request context was canceled. 
+ } else if oaiErr := getOpenAIErrorResponse(streamErr); oaiErr != nil { + logger.Warn(ctx, "openai stream error", slog.Error(streamErr)) + interceptionErr = oaiErr + } else { + logger.Warn(ctx, "unknown error", slog.Error(streamErr)) + // Unfortunately, the OpenAI SDK does not support parsing errors received in the stream + // into known types (i.e. [shared.OverloadedError]). + // See https://github.com/openai/openai-go/blob/v2.7.0/packages/ssestream/ssestream.go#L171 + // All it does is wrap the payload in an error - which is all we can return, currently. + interceptionErr = newOpenAIErr(fmt.Errorf("unknown stream error: %w", streamErr)) + } + } else if lastErr != nil { + // Otherwise check if any logical errors occurred during processing. + logger.Warn(ctx, "stream failed", slog.Error(lastErr)) + interceptionErr = newOpenAIErr(fmt.Errorf("processing error: %w", lastErr)) } - } else if lastErr != nil { - // Otherwise check if any logical errors occurred during processing. - logger.Warn(ctx, "stream failed", slog.Error(lastErr)) - interceptionErr = newOpenAIErr(fmt.Errorf("processing error: %w", lastErr)) - } - if interceptionErr != nil { - payload, err := i.marshalErr(interceptionErr) - if err != nil { - logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr))) - } else if err := events.Send(streamCtx, payload); err != nil { - logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload)) + if interceptionErr != nil { + payload, err := i.marshalErr(interceptionErr) + if err != nil { + logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", slog.F("%+v", interceptionErr))) + } else if err := events.Send(streamCtx, payload); err != nil { + logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload)) + } } + } else { + // Stream has not started yet; write to response if present. + i.writeUpstreamError(w, getOpenAIErrorResponse(stream.Err())) } // No tool call, nothing more to do. diff --git a/openai.go b/openai.go index dc3abc8..3a02fb9 100644 --- a/openai.go +++ b/openai.go @@ -1,14 +1,12 @@ package aibridge import ( - "encoding/json" "errors" - "github.com/anthropics/anthropic-sdk-go/shared" - "github.com/anthropics/anthropic-sdk-go/shared/constant" "github.com/coder/aibridge/utils" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/packages/param" + "github.com/openai/openai-go/v2/shared" ) // ChatCompletionNewParamsWrapper exists because the "stream" param is not included in openai.ChatCompletionNewParams. 
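+// (Presumably it embeds openai.ChatCompletionNewParams alongside an explicit stream
+// field, so the flag survives a marshal/unmarshal round-trip through the bridge.)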
@@ -106,57 +104,41 @@ func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 { } func getOpenAIErrorResponse(err error) *OpenAIErrorResponse { - var apierr *openai.Error - if !errors.As(err, &apierr) { + var apiErr *openai.Error + if !errors.As(err, &apiErr) { return nil } - msg := apierr.Error() - typ := string(constant.ValueOf[constant.APIError]()) - - var detail *shared.APIErrorObject - if field, ok := apierr.JSON.ExtraFields["error"]; ok { - _ = json.Unmarshal([]byte(field.Raw()), &detail) - } - if detail != nil { - msg = detail.Message - typ = string(detail.Type) - } - return &OpenAIErrorResponse{ - ErrorResponse: &shared.ErrorResponse{ - Error: shared.ErrorObjectUnion{ - Message: msg, - Type: typ, - }, - Type: constant.ValueOf[constant.Error](), + ErrorObject: &shared.ErrorObject{ + Code: apiErr.Code, + Message: apiErr.Message, + Type: apiErr.Type, }, - StatusCode: apierr.StatusCode, + StatusCode: apiErr.StatusCode, } } var _ error = &OpenAIErrorResponse{} type OpenAIErrorResponse struct { - *shared.ErrorResponse - - StatusCode int `json:"-"` + ErrorObject *shared.ErrorObject `json:"error"` + StatusCode int `json:"-"` } func newOpenAIErr(msg error) *OpenAIErrorResponse { return &OpenAIErrorResponse{ - ErrorResponse: &shared.ErrorResponse{ - Error: shared.ErrorObjectUnion{ - Message: msg.Error(), - Type: "error", - }, + ErrorObject: &shared.ErrorObject{ + Code: "error", + Message: msg.Error(), + Type: "error", }, } } func (a *OpenAIErrorResponse) Error() string { - if a.ErrorResponse == nil { + if a.ErrorObject == nil { return "" } - return a.ErrorResponse.Error.Message + return a.ErrorObject.Message } diff --git a/provider_openai.go b/provider_openai.go index 3a3db45..0fc31a6 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -8,8 +8,6 @@ import ( "os" "github.com/google/uuid" - "github.com/openai/openai-go/v2" - "github.com/openai/openai-go/v2/option" ) var _ Provider = &OpenAIProvider{} @@ -100,11 +98,3 @@ func (p *OpenAIProvider) InjectAuthHeader(headers *http.Header) { headers.Set(p.AuthHeader(), "Bearer "+p.key) } - -func newOpenAIClient(baseURL, key string) openai.Client { - var opts []option.RequestOption - opts = append(opts, option.WithAPIKey(key)) - opts = append(opts, option.WithBaseURL(baseURL)) - - return openai.NewClient(opts...) 
-} From 453992c97dae97d5077ff467bf1fb4cee0bdb85b Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Thu, 13 Nov 2025 14:28:06 +0200 Subject: [PATCH 04/10] chore: self-review Signed-off-by: Danny Kopping --- intercept_anthropic_messages_base.go | 4 ++-- intercept_anthropic_messages_streaming.go | 6 +++++- intercept_openai_chat_base.go | 2 +- intercept_openai_chat_streaming.go | 6 +++++- streaming.go | 9 ++------- 5 files changed, 15 insertions(+), 12 deletions(-) diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go index 0215d6c..2367933 100644 --- a/intercept_anthropic_messages_base.go +++ b/intercept_anthropic_messages_base.go @@ -180,8 +180,8 @@ func (i *AnthropicMessagesInterceptionBase) writeUpstreamError(w http.ResponseWr return } - w.WriteHeader(antErr.StatusCode) w.Header().Set("Content-Type", "application/json") + w.WriteHeader(antErr.StatusCode) out, err := json.Marshal(antErr) if err != nil { @@ -194,7 +194,7 @@ func (i *AnthropicMessagesInterceptionBase) writeUpstreamError(w http.ResponseWr "type": "error", "message":"error marshaling upstream error" }, - "request_id": "%s", + "request_id": "%s" }`, i.ID().String()))) } else { _, _ = w.Write(out) diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index f3cbf91..b191701 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -8,6 +8,7 @@ import ( "fmt" "net/http" "strings" + "sync" "time" "github.com/anthropics/anthropic-sdk-go" @@ -96,6 +97,7 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW } // events will either terminate when shutdown after interaction with upstream completes, or when streamCtx is done. + var runOnce sync.Once events := newEventStream(streamCtx, logger.Named("sse-sender"), i.pingPayload()) defer func() { _ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes. @@ -126,7 +128,9 @@ newStream: for stream.Next() { // Only start the event stream if the upstream starts streaming (as opposed to erroring out prematurely). - go events.run(w, r) + runOnce.Do(func() { + go events.run(w, r) + }) event := stream.Current() if err := message.Accumulate(event); err != nil { diff --git a/intercept_openai_chat_base.go b/intercept_openai_chat_base.go index 1fa3a8b..44ef582 100644 --- a/intercept_openai_chat_base.go +++ b/intercept_openai_chat_base.go @@ -109,8 +109,8 @@ func (i *OpenAIChatInterceptionBase) writeUpstreamError(w http.ResponseWriter, o return } - w.WriteHeader(oaiErr.StatusCode) w.Header().Set("Content-Type", "application/json") + w.WriteHeader(oaiErr.StatusCode) out, err := json.Marshal(oaiErr) if err != nil { diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index 50c6bbd..68a0f34 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -8,6 +8,7 @@ import ( "fmt" "net/http" "strings" + "sync" "time" "github.com/coder/aibridge/mcp" @@ -72,6 +73,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, defer streamCancel(errors.New("deferred")) // events will either terminate when shutdown after interaction with upstream completes, or when streamCtx is done. + var runOnce sync.Once events := newEventStream(streamCtx, logger.Named("sse-sender"), nil) defer func() { _ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes. 
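+		// Calling Shutdown more than once is safe: the channel close is guarded by shutdownOnce.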
@@ -106,7 +108,9 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, for stream.Next() { // Only start the event stream if the upstream starts streaming (as opposed to erroring out prematurely). - go events.run(w, r) + runOnce.Do(func() { + go events.run(w, r) + }) chunk := stream.Current() diff --git a/streaming.go b/streaming.go index 15950df..ac8dda6 100644 --- a/streaming.go +++ b/streaming.go @@ -54,11 +54,11 @@ func newEventStream(ctx context.Context, logger slog.Logger, pingPayload []byte) // run handles sending Server-Sent Event to the client. func (s *eventStream) run(w http.ResponseWriter, r *http.Request) { // Only one instance is allowed to run. - if s.running.Load() { + if swapped := s.running.CompareAndSwap(false, true); !swapped { + // Value has not changed; instance is already running. return } - s.running.Store(true) defer func() { // Signal completion on exit so senders don't block indefinitely after closure. close(s.doneCh) @@ -160,10 +160,6 @@ func (s *eventStream) sendRaw(ctx context.Context, payload []byte) error { // Shutdown gracefully shuts down the stream, sending any supplementary events downstream if required. // ONLY call this once all events have been submitted. func (s *eventStream) Shutdown(shutdownCtx context.Context) error { - defer func() { - s.running.Store(false) - }() - s.shutdownOnce.Do(func() { s.logger.Debug(shutdownCtx, "shutdown initiated", slog.F("outstanding_events", len(s.eventsCh))) @@ -172,7 +168,6 @@ func (s *eventStream) Shutdown(shutdownCtx context.Context) error { close(s.eventsCh) }) - // TODO: consider the safety of this approach. if !s.running.Load() { return nil } From f1eebe18e59693037de355ee9617ada8979c7264 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Thu, 13 Nov 2025 16:49:23 +0200 Subject: [PATCH 05/10] chore: refactor away from atomic to compound mutex + bool Signed-off-by: Danny Kopping --- intercept_anthropic_messages_streaming.go | 2 +- intercept_openai_chat_streaming.go | 2 +- streaming.go | 63 ++++++++++++++++------- 3 files changed, 46 insertions(+), 21 deletions(-) diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index b191701..949f160 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -129,7 +129,7 @@ newStream: for stream.Next() { // Only start the event stream if the upstream starts streaming (as opposed to erroring out prematurely). runOnce.Do(func() { - go events.run(w, r) + go events.start(w, r) }) event := stream.Current() diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index 68a0f34..c1521b8 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -109,7 +109,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, for stream.Next() { // Only start the event stream if the upstream starts streaming (as opposed to erroring out prematurely). runOnce.Do(func() { - go events.run(w, r) + go events.start(w, r) }) chunk := stream.Current() diff --git a/streaming.go b/streaming.go index ac8dda6..e8f523b 100644 --- a/streaming.go +++ b/streaming.go @@ -9,7 +9,6 @@ import ( "net/http" "strings" "sync" - "sync/atomic" "syscall" "time" @@ -28,13 +27,15 @@ type eventStream struct { pingPayload []byte - running atomic.Bool - closeOnce sync.Once shutdownOnce sync.Once + doneOnce sync.Once eventsCh chan event - // doneCh is closed when the run loop exits. 
+ // startedMu protects the started flag. + startedMu sync.Mutex + started bool + // doneCh is closed when the start loop exits. doneCh chan struct{} } @@ -51,19 +52,23 @@ func newEventStream(ctx context.Context, logger slog.Logger, pingPayload []byte) } } -// run handles sending Server-Sent Event to the client. -func (s *eventStream) run(w http.ResponseWriter, r *http.Request) { - // Only one instance is allowed to run. - if swapped := s.running.CompareAndSwap(false, true); !swapped { - // Value has not changed; instance is already running. +// start handles sending Server-Sent Event to the client. +func (s *eventStream) start(w http.ResponseWriter, r *http.Request) { + // Atomically signal that streaming has started + s.startedMu.Lock() + if s.started { + // Another goroutine is already running. + s.startedMu.Unlock() return } + s.started = true + s.startedMu.Unlock() defer func() { // Signal completion on exit so senders don't block indefinitely after closure. - close(s.doneCh) - - s.running.Store(false) + s.doneOnce.Do(func() { + close(s.doneCh) + }) }() ctx := r.Context() @@ -160,31 +165,51 @@ func (s *eventStream) sendRaw(ctx context.Context, payload []byte) error { // Shutdown gracefully shuts down the stream, sending any supplementary events downstream if required. // ONLY call this once all events have been submitted. func (s *eventStream) Shutdown(shutdownCtx context.Context) error { + var shutdownErr error + s.shutdownOnce.Do(func() { s.logger.Debug(shutdownCtx, "shutdown initiated", slog.F("outstanding_events", len(s.eventsCh))) - // Now it is safe to close the events channel; the run loop will exit + // Now it is safe to close the events channel; the start() loop will exit // after draining remaining events and receivers will stop ranging. close(s.eventsCh) }) - if !s.running.Load() { + // Atomically check if start() was called and close doneCh if it wasn't. + s.startedMu.Lock() + if !s.started { + // start() was never called, close doneCh ourselves so we don't block forever. + s.doneOnce.Do(func() { + close(s.doneCh) + }) + s.startedMu.Unlock() return nil } + // start() was called (or is about to be called), it will close doneCh when it finishes. + s.startedMu.Unlock() + // Wait for start() to complete. We must ALWAYS wait for doneCh to prevent + // races with http.ResponseWriter cleanup, even if contexts are cancelled. select { case <-shutdownCtx.Done(): - // If shutdownCtx completes, shutdown likely exceeded its timeout. - return fmt.Errorf("shutdown ended prematurely with %d outstanding events: %w", len(s.eventsCh), shutdownCtx.Err()) + shutdownErr = fmt.Errorf("shutdown timeout with %d outstanding events: %w", len(s.eventsCh), shutdownCtx.Err()) case <-s.ctx.Done(): - return fmt.Errorf("shutdown ended prematurely with %d outstanding events: %w", len(s.eventsCh), s.ctx.Err()) + shutdownErr = fmt.Errorf("shutdown cancelled with %d outstanding events: %w", len(s.eventsCh), s.ctx.Err()) case <-s.doneCh: - return nil + // Goroutine has finished. + return shutdownErr } + + // If we got here due to context cancellation/timeout, we MUST still wait for doneCh + // to ensure the goroutine has stopped using the ResponseWriter before the HTTP handler returns. + <-s.doneCh + return shutdownErr } func (s *eventStream) isRunning() bool { - return s.running.Load() + s.startedMu.Lock() + defer s.startedMu.Unlock() + return s.started } // isConnError checks if an error is related to client disconnection or context cancellation. 
From d73b3bffa318de6f27a7cc9ad03cebcefd2506ee Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Thu, 13 Nov 2025 16:49:35 +0200 Subject: [PATCH 06/10] chore: drive-by flake fix Signed-off-by: Danny Kopping --- bridge_integration_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index b9b2745..2b39325 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -673,7 +673,7 @@ func TestFallthrough(t *testing.T) { files := filesMap(arc) require.Contains(t, files, fixtureResponse) - var receivedHeaders *http.Header + var receivedHeaders atomic.Pointer[http.Header] respBody := files[fixtureResponse] upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/models" { @@ -685,7 +685,7 @@ func TestFallthrough(t *testing.T) { w.WriteHeader(http.StatusOK) _, _ = w.Write(respBody) - receivedHeaders = &r.Header + receivedHeaders.Store(&r.Header) })) t.Cleanup(upstream.Close) @@ -710,8 +710,8 @@ func TestFallthrough(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) // Ensure that the API key was sent. - require.NotNil(t, receivedHeaders) - require.Contains(t, receivedHeaders.Get(provider.AuthHeader()), apiKey) + require.NotNil(t, receivedHeaders.Load()) + require.Contains(t, receivedHeaders.Load().Get(provider.AuthHeader()), apiKey) gotBytes, err := io.ReadAll(resp.Body) require.NoError(t, err) From 2e492f23c0afa4adecc15143705a87388e2e556c Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Fri, 14 Nov 2025 09:19:51 +0200 Subject: [PATCH 07/10] chore: simplify approach Signed-off-by: Danny Kopping --- intercept_anthropic_messages_streaming.go | 10 +-- intercept_openai_chat_streaming.go | 10 +-- streaming.go | 99 +++++++++-------------- 3 files changed, 42 insertions(+), 77 deletions(-) diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index 949f160..dada47b 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -8,7 +8,6 @@ import ( "fmt" "net/http" "strings" - "sync" "time" "github.com/anthropics/anthropic-sdk-go" @@ -97,8 +96,8 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW } // events will either terminate when shutdown after interaction with upstream completes, or when streamCtx is done. - var runOnce sync.Once events := newEventStream(streamCtx, logger.Named("sse-sender"), i.pingPayload()) + go events.start(w, r) defer func() { _ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes. }() @@ -127,11 +126,6 @@ newStream: pendingToolCalls := make(map[string]string) for stream.Next() { - // Only start the event stream if the upstream starts streaming (as opposed to erroring out prematurely). - runOnce.Do(func() { - go events.start(w, r) - }) - event := stream.Current() if err := message.Accumulate(event); err != nil { logger.Warn(ctx, "failed to accumulate streaming events", slog.Error(err), slog.F("event", event), slog.F("msg", message.RawJSON())) @@ -420,7 +414,7 @@ newStream: prompt = nil } - if events.isRunning() { + if events.hasInitiated() { // Check if the stream encountered any errors. 
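+		// hasInitiated reports whether the first upstream event has been relayed, i.e.
+		// whether the SSE headers have already been written.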
 		if streamErr := stream.Err(); streamErr != nil {
 			if isUnrecoverableError(streamErr) {
diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go
index c1521b8..b89d0fb 100644
--- a/intercept_openai_chat_streaming.go
+++ b/intercept_openai_chat_streaming.go
@@ -8,7 +8,6 @@ import (
 	"fmt"
 	"net/http"
 	"strings"
-	"sync"
 	"time"
 
 	"github.com/coder/aibridge/mcp"
@@ -73,8 +72,8 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
 	defer streamCancel(errors.New("deferred"))
 
 	// events will terminate either when Shutdown is called after the upstream interaction completes, or when streamCtx is done.
-	var runOnce sync.Once
 	events := newEventStream(streamCtx, logger.Named("sse-sender"), nil)
+	go events.start(w, r)
 	defer func() {
 		_ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shut down after the stream completes.
 	}()
@@ -107,11 +106,6 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
 	var toolCall *openai.FinishedChatCompletionToolCall
 
 	for stream.Next() {
-		// Only start the event stream if the upstream starts streaming (as opposed to erroring out prematurely).
-		runOnce.Do(func() {
-			go events.start(w, r)
-		})
-
 		chunk := stream.Current()
 
 		canRelay := processor.process(chunk)
@@ -178,7 +172,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
 		})
 	}
 
-	if events.isRunning() {
+	if events.hasInitiated() {
 		// Check if the stream encountered any errors.
 		if streamErr := stream.Err(); streamErr != nil {
 			if isUnrecoverableError(streamErr) {
diff --git a/streaming.go b/streaming.go
index e8f523b..6e55773 100644
--- a/streaming.go
+++ b/streaming.go
@@ -9,6 +9,7 @@ import (
 	"net/http"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"syscall"
 	"time"
 
@@ -27,14 +28,13 @@ type eventStream struct {
 
 	pingPayload []byte
 
+	initiated atomic.Bool
+
+	initiateOnce sync.Once
 	closeOnce    sync.Once
 	shutdownOnce sync.Once
-	doneOnce     sync.Once
 
 	eventsCh chan event
-	// startedMu protects the started flag.
-	startedMu sync.Mutex
-	started   bool
 	// doneCh is closed when the start loop exits.
 	doneCh chan struct{}
 }
@@ -54,39 +54,15 @@ func newEventStream(ctx context.Context, logger slog.Logger, pingPayload []byte
 
 // start handles sending Server-Sent Events to the client.
 func (s *eventStream) start(w http.ResponseWriter, r *http.Request) {
-	// Atomically signal that streaming has started.
-	s.startedMu.Lock()
-	if s.started {
-		// Another goroutine is already running.
-		s.startedMu.Unlock()
-		return
-	}
-	s.started = true
-	s.startedMu.Unlock()
-
-	defer func() {
-		// Signal completion on exit so senders don't block indefinitely after closure.
-		s.doneOnce.Do(func() {
-			close(s.doneCh)
-		})
-	}()
+	// Signal completion on exit so senders don't block indefinitely after closure.
+	defer close(s.doneCh)
 
 	ctx := r.Context()
 
-	w.Header().Set("Content-Type", "text/event-stream")
-	w.Header().Set("Cache-Control", "no-cache")
-	w.Header().Set("Connection", "keep-alive")
-	w.Header().Set("X-Accel-Buffering", "no")
-
-	// Send initial flush to ensure connection is established.
-	if err := flush(w); err != nil {
-		http.Error(w, err.Error(), http.StatusInternalServerError)
-		return
-	}
-
 	// Send periodic pings to keep connections alive.
 	// The upstream provider may also send their own pings, but we can't rely on this.
-	tick := time.NewTicker(pingInterval)
+	tick := time.NewTicker(time.Nanosecond)
+	tick.Stop() // Ticker will start after stream initiation.
 	defer tick.Stop()
 
 	for {
@@ -101,10 +77,30 @@ func (s *eventStream) start(w http.ResponseWriter, r *http.Request) {
 		case <-ctx.Done():
 			s.logger.Debug(ctx, "request context canceled", slog.Error(ctx.Err()))
 			return
-		case ev, open = <-s.eventsCh:
+		case ev, open = <-s.eventsCh: // Once closed, the buffered channel drains any buffered values before reporting closed.
 			if !open {
 				return
 			}
+
+			// Initiate the stream once the first event is received.
+			s.initiateOnce.Do(func() {
+				s.initiated.Store(true)
+
+				// Send headers for the Server-Sent Events stream.
+				w.Header().Set("Content-Type", "text/event-stream")
+				w.Header().Set("Cache-Control", "no-cache")
+				w.Header().Set("Connection", "keep-alive")
+				w.Header().Set("X-Accel-Buffering", "no")
+
+				// Send initial flush to ensure connection is established.
+				if err := flush(w); err != nil {
+					http.Error(w, err.Error(), http.StatusInternalServerError)
+					return
+				}
+
+				// Start ping ticker.
+				tick.Reset(pingInterval)
+			})
 		case <-tick.C:
 			ev = s.pingPayload
 			if ev == nil {
@@ -165,8 +161,6 @@ func (s *eventStream) sendRaw(ctx context.Context, payload []byte) error {
 // Shutdown gracefully shuts down the stream, sending any supplementary events downstream if required.
 // ONLY call this once all events have been submitted.
 func (s *eventStream) Shutdown(shutdownCtx context.Context) error {
-	var shutdownErr error
-
 	s.shutdownOnce.Do(func() {
 		s.logger.Debug(shutdownCtx, "shutdown initiated", slog.F("outstanding_events", len(s.eventsCh)))
 
@@ -175,41 +169,24 @@ func (s *eventStream) Shutdown(shutdownCtx context.Context) error {
 		close(s.eventsCh)
 	})
 
-	// Atomically check if start() was called and close doneCh if it wasn't.
-	s.startedMu.Lock()
-	if !s.started {
-		// start() was never called; close doneCh ourselves so we don't block forever.
-		s.doneOnce.Do(func() {
-			close(s.doneCh)
-		})
-		s.startedMu.Unlock()
-		return nil
-	}
-	// start() was called (or is about to be called); it will close doneCh when it finishes.
-	s.startedMu.Unlock()
-
-	// Wait for start() to complete. We must ALWAYS wait for doneCh to prevent
-	// races with http.ResponseWriter cleanup, even if contexts are canceled.
+	var err error
 	select {
 	case <-shutdownCtx.Done():
-		shutdownErr = fmt.Errorf("shutdown timeout with %d outstanding events: %w", len(s.eventsCh), shutdownCtx.Err())
+		// If shutdownCtx completes, shutdown likely exceeded its timeout.
+		err = fmt.Errorf("shutdown ended prematurely with %d outstanding events: %w", len(s.eventsCh), shutdownCtx.Err())
 	case <-s.ctx.Done():
-		shutdownErr = fmt.Errorf("shutdown cancelled with %d outstanding events: %w", len(s.eventsCh), s.ctx.Err())
+		err = fmt.Errorf("shutdown ended prematurely with %d outstanding events: %w", len(s.eventsCh), s.ctx.Err())
 	case <-s.doneCh:
-		// Goroutine has finished.
-		return shutdownErr
+		return nil
 	}
 
-	// If we got here due to context cancellation/timeout, we MUST still wait for doneCh
-	// to ensure the goroutine has stopped using the ResponseWriter before the HTTP handler returns.
+	// Even if the context is canceled, we need to wait for start() to complete.
 	<-s.doneCh
-	return shutdownErr
+	return err
 }
 
-func (s *eventStream) isRunning() bool {
-	s.startedMu.Lock()
-	defer s.startedMu.Unlock()
-	return s.started
+func (s *eventStream) hasInitiated() bool {
+	return s.initiated.Load()
 }
 
 // isConnError checks if an error is related to client disconnection or context cancellation.
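
The stopped-ticker trick above generalizes: construct the ticker, stop it immediately, and Reset it only when the triggering condition first occurs, so no pings fire before the stream has initiated. A minimal runnable sketch of the pattern; the interval, channel, and printed messages are invented for illustration and are not part of the patch:

    package main

    import (
    	"fmt"
    	"sync"
    	"time"
    )

    func main() {
    	events := make(chan string, 4)
    	var once sync.Once

    	// Ticker is created stopped; it only starts once the first event
    	// arrives, mirroring how pings begin only after SSE initiation.
    	tick := time.NewTicker(time.Nanosecond)
    	tick.Stop()
    	defer tick.Stop()

    	go func() {
    		events <- "first"
    		events <- "second"
    		time.Sleep(250 * time.Millisecond)
    		close(events)
    	}()

    	for {
    		select {
    		case ev, open := <-events:
    			if !open {
    				return
    			}
    			once.Do(func() {
    				fmt.Println("initiated; starting pings")
    				tick.Reset(100 * time.Millisecond)
    			})
    			fmt.Println("event:", ev)
    		case <-tick.C:
    			fmt.Println("ping")
    		}
    	}
    }
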
From 7e07b1d48460cd0d60c9c404ac7b3b06b2b4ce2a Mon Sep 17 00:00:00 2001
From: Danny Kopping
Date: Fri, 14 Nov 2025 09:56:21 +0200
Subject: [PATCH 08/10] chore: fix flake due to order of operations

Signed-off-by: Danny Kopping
---
 bridge_integration_test.go | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/bridge_integration_test.go b/bridge_integration_test.go
index 2b39325..8768e0d 100644
--- a/bridge_integration_test.go
+++ b/bridge_integration_test.go
@@ -709,13 +709,13 @@ func TestFallthrough(t *testing.T) {
 
 			require.Equal(t, http.StatusOK, resp.StatusCode)
 
+			gotBytes, err := io.ReadAll(resp.Body)
+			require.NoError(t, err)
+
 			// Ensure that the API key was sent.
 			require.NotNil(t, receivedHeaders.Load())
 			require.Contains(t, receivedHeaders.Load().Get(provider.AuthHeader()), apiKey)
 
-			gotBytes, err := io.ReadAll(resp.Body)
-			require.NoError(t, err)
-
 			// Compare JSON bodies for semantic equality.
 			var got any
 			var exp any

From d94d8466ac389fb9469381431aef198e688cd338 Mon Sep 17 00:00:00 2001
From: Danny Kopping
Date: Fri, 14 Nov 2025 09:58:10 +0200
Subject: [PATCH 09/10] chore: fix race

Signed-off-by: Danny Kopping
---
 intercept_anthropic_messages_streaming.go |  2 +-
 intercept_openai_chat_streaming.go        |  2 +-
 streaming.go                              | 18 ++++++++++++++----
 3 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go
index dada47b..15bb6d8 100644
--- a/intercept_anthropic_messages_streaming.go
+++ b/intercept_anthropic_messages_streaming.go
@@ -414,7 +414,7 @@ newStream:
 		prompt = nil
 	}
 
-	if events.hasInitiated() {
+	if events.isStreaming() {
 		// Check if the stream encountered any errors.
 		if streamErr := stream.Err(); streamErr != nil {
 			if isUnrecoverableError(streamErr) {
diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go
index b89d0fb..cc1a64a 100644
--- a/intercept_openai_chat_streaming.go
+++ b/intercept_openai_chat_streaming.go
@@ -172,7 +172,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
 		})
 	}
 
-	if events.hasInitiated() {
+	if events.isStreaming() {
 		// Check if the stream encountered any errors.
 		if streamErr := stream.Err(); streamErr != nil {
 			if isUnrecoverableError(streamErr) {
diff --git a/streaming.go b/streaming.go
index 6e55773..a3216b5 100644
--- a/streaming.go
+++ b/streaming.go
@@ -28,9 +28,9 @@ type eventStream struct {
 
 	pingPayload []byte
 
-	initiated atomic.Bool
-
+	initiated    atomic.Bool
 	initiateOnce sync.Once
+
 	closeOnce    sync.Once
 	shutdownOnce sync.Once
 
 	eventsCh chan event
@@ -79,14 +79,22 @@ func (s *eventStream) start(w http.ResponseWriter, r *http.Request) {
 			return
 		case ev, open = <-s.eventsCh: // Once closed, the buffered channel drains any buffered values before reporting closed.
 			if !open {
+				s.logger.Debug(ctx, "events channel closed")
 				return
 			}
 
 			// Initiate the stream once the first event is received.
 			s.initiateOnce.Do(func() {
 				s.initiated.Store(true)
+				s.logger.Debug(ctx, "stream initiated")
 
 				// Send headers for the Server-Sent Events stream.
+				//
+				// We only send these once an event is processed because an error can occur in the upstream
+				// request prior to the stream starting, in which case the SSE headers are inappropriate to
+				// send to the client.
+				//
+				// See use of isStreaming().
 				w.Header().Set("Content-Type", "text/event-stream")
 				w.Header().Set("Cache-Control", "no-cache")
 				w.Header().Set("Connection", "keep-alive")
 				w.Header().Set("X-Accel-Buffering", "no")
@@ -185,8 +193,10 @@ func (s *eventStream) Shutdown(shutdownCtx context.Context) error {
 	return err
 }
 
-func (s *eventStream) hasInitiated() bool {
-	return s.initiated.Load()
+// isStreaming reports whether the stream has been initiated, or whether
+// buffered events will initiate it once they are processed.
+func (s *eventStream) isStreaming() bool {
+	return s.initiated.Load() || len(s.eventsCh) > 0
 }
 
 // isConnError checks if an error is related to client disconnection or context cancellation.

From 74ab7583572bd14f3b97538814fe045fdf9f631b Mon Sep 17 00:00:00 2001
From: Danny Kopping
Date: Mon, 17 Nov 2025 14:11:03 +0200
Subject: [PATCH 10/10] chore: headers drive-by race fix

Signed-off-by: Danny Kopping
---
 bridge_integration_test.go | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/bridge_integration_test.go b/bridge_integration_test.go
index 8768e0d..e570dd9 100644
--- a/bridge_integration_test.go
+++ b/bridge_integration_test.go
@@ -673,7 +673,7 @@ func TestFallthrough(t *testing.T) {
 			files := filesMap(arc)
 			require.Contains(t, files, fixtureResponse)
 
-			var receivedHeaders atomic.Pointer[http.Header]
+			var receivedHeaders *http.Header
 			respBody := files[fixtureResponse]
 			upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 				if r.URL.Path != "/v1/models" {
@@ -681,11 +681,11 @@ func TestFallthrough(t *testing.T) {
 					t.FailNow()
 				}
 
+				receivedHeaders = &r.Header
+
 				w.Header().Set("Content-Type", "application/json")
 				w.WriteHeader(http.StatusOK)
 				_, _ = w.Write(respBody)
-
-				receivedHeaders.Store(&r.Header)
 			}))
 			t.Cleanup(upstream.Close)
 
@@ -709,13 +709,13 @@ func TestFallthrough(t *testing.T) {
 
 			require.Equal(t, http.StatusOK, resp.StatusCode)
 
+			// Ensure that the API key was sent.
+			require.NotNil(t, receivedHeaders)
+			require.Contains(t, receivedHeaders.Get(provider.AuthHeader()), apiKey)
+
 			gotBytes, err := io.ReadAll(resp.Body)
 			require.NoError(t, err)
 
-			// Ensure that the API key was sent.
-			require.NotNil(t, receivedHeaders.Load())
-			require.Contains(t, receivedHeaders.Load().Get(provider.AuthHeader()), apiKey)
-
 			// Compare JSON bodies for semantic equality.
 			var got any
 			var exp any
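
The last fix relies on ordering through the HTTP exchange itself: the handler stores the request headers before it writes the response, so by the time the client holds a response, the capture must already have happened and no further synchronization is needed. A standalone sketch of that pattern follows; the server, header name, and body are all invented for illustration and are not this test's fixtures:

    package main

    import (
    	"fmt"
    	"io"
    	"net/http"
    	"net/http/httptest"
    )

    func main() {
    	var captured *http.Header

    	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    		// Capture before writing the response: the client cannot observe
    		// the response until this assignment has executed.
    		captured = &r.Header

    		w.Header().Set("Content-Type", "application/json")
    		w.WriteHeader(http.StatusOK)
    		_, _ = w.Write([]byte(`{"ok":true}`))
    	}))
    	defer srv.Close()

    	req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
    	if err != nil {
    		panic(err)
    	}
    	req.Header.Set("X-Api-Key", "test-key") // illustrative header, not the real auth header

    	resp, err := http.DefaultClient.Do(req)
    	if err != nil {
    		panic(err)
    	}
    	defer resp.Body.Close()
    	body, _ := io.ReadAll(resp.Body)

    	// Safe to read captured now: Do has returned, so the handler passed
    	// the capture on its way to writing the response.
    	fmt.Println("status:", resp.StatusCode)
    	fmt.Println("api key seen by server:", captured.Get("X-Api-Key"))
    	fmt.Println("body:", string(body))
    }
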