Skip to content
311 changes: 218 additions & 93 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"

"github.com/openai/openai-go/v2"
oaissestream "github.com/openai/openai-go/v2/packages/ssestream"
Expand All @@ -47,8 +48,10 @@ var (
antSingleInjectedTool []byte
//go:embed fixtures/anthropic/fallthrough.txtar
antFallthrough []byte
//go:embed fixtures/anthropic/error.txtar
antErr []byte
//go:embed fixtures/anthropic/stream_error.txtar
antMidStreamErr []byte
//go:embed fixtures/anthropic/non_stream_error.txtar
antNonStreamErr []byte

//go:embed fixtures/openai/simple.txtar
oaiSimple []byte
Expand All @@ -58,8 +61,10 @@ var (
oaiSingleInjectedTool []byte
//go:embed fixtures/openai/fallthrough.txtar
oaiFallthrough []byte
//go:embed fixtures/openai/error.txtar
oaiErr []byte
//go:embed fixtures/openai/stream_error.txtar
oaiMidStreamErr []byte
//go:embed fixtures/openai/non_stream_error.txtar
oaiNonStreamErr []byte
)

const (
Expand Down Expand Up @@ -676,11 +681,11 @@ func TestFallthrough(t *testing.T) {
t.FailNow()
}

receivedHeaders = &r.Header

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(respBody)

receivedHeaders = &r.Header
}))
t.Cleanup(upstream.Close)

Expand Down Expand Up @@ -1009,48 +1014,147 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
func TestErrorHandling(t *testing.T) {
t.Parallel()

cases := []struct {
name string
fixture []byte
createRequestFunc createRequestFunc
configureFunc configureFunc
responseHandlerFn func(streaming bool, resp *http.Response)
}{
{
name: aibridge.ProviderAnthropic,
fixture: antErr,
createRequestFunc: createAnthropicMessagesReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr)
// Tests that errors which occur *before* a streaming response begins, or in non-streaming requests, are handled as expected.
t.Run("non-stream error", func(t *testing.T) {
cases := []struct {
name string
fixture []byte
createRequestFunc createRequestFunc
configureFunc configureFunc
responseHandlerFn func(resp *http.Response)
}{
{
name: aibridge.ProviderAnthropic,
fixture: antNonStreamErr,
createRequestFunc: createAnthropicMessagesReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr)
},
responseHandlerFn: func(resp *http.Response) {
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "error", gjson.GetBytes(body, "type").Str)
require.Equal(t, "invalid_request_error", gjson.GetBytes(body, "error.type").Str)
require.Contains(t, gjson.GetBytes(body, "error.message").Str, "prompt is too long")
},
},
responseHandlerFn: func(streaming bool, resp *http.Response) {
if streaming {
{
name: aibridge.ProviderOpenAI,
fixture: oaiNonStreamErr,
createRequestFunc: createOpenAIChatCompletionsReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
},
responseHandlerFn: func(resp *http.Response) {
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "context_length_exceeded", gjson.GetBytes(body, "error.code").Str)
require.Equal(t, "invalid_request_error", gjson.GetBytes(body, "error.type").Str)
require.Contains(t, gjson.GetBytes(body, "error.message").Str, "Input tokens exceed the configured limit")
},
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

for _, streaming := range []bool{true, false} {
t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)

arc := txtar.Parse(tc.fixture)
t.Logf("%s: %s", t.Name(), arc.Comment)

files := filesMap(arc)
require.Len(t, files, 3)
require.Contains(t, files, fixtureRequest)
require.Contains(t, files, fixtureStreamingResponse)
require.Contains(t, files, fixtureNonStreamingResponse)

reqBody := files[fixtureRequest]
// Add the stream param to the request.
newBody, err := setJSON(reqBody, "stream", streaming)
require.NoError(t, err)
reqBody = newBody

// Setup mock server.
mockResp := files[fixtureStreamingResponse]
if !streaming {
mockResp = files[fixtureNonStreamingResponse]
}
mockSrv := newMockHTTPReflector(ctx, t, mockResp)
t.Cleanup(mockSrv.Close)

recorderClient := &mockRecorderClient{}

b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil))
require.NoError(t, err)

// Invoke request to mocked API via aibridge.
bridgeSrv := httptest.NewUnstartedServer(b)
bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
return aibridge.AsActor(ctx, userID, nil)
}
bridgeSrv.Start()
t.Cleanup(bridgeSrv.Close)

req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody)
resp, err := http.DefaultClient.Do(req)
t.Cleanup(func() { _ = resp.Body.Close() })
require.NoError(t, err)

tc.responseHandlerFn(resp)
recorderClient.verifyAllInterceptionsEnded(t)
})
}
})
}
})

// Tests that errors which occur *during* a streaming response are handled as expected.
t.Run("mid-stream error", func(t *testing.T) {
cases := []struct {
name string
fixture []byte
createRequestFunc createRequestFunc
configureFunc configureFunc
responseHandlerFn func(resp *http.Response)
}{
{
name: aibridge.ProviderAnthropic,
fixture: antMidStreamErr,
createRequestFunc: createAnthropicMessagesReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, srvProxyMgr)
},
responseHandlerFn: func(resp *http.Response) {
// Server responds first with 200 OK then starts streaming.
require.Equal(t, http.StatusOK, resp.StatusCode)

sp := aibridge.NewSSEParser()
require.NoError(t, sp.Parse(resp.Body))
require.Len(t, sp.EventsByType("error"), 1)
require.Contains(t, sp.EventsByType("error")[0].Data, "Overloaded")
} else {
require.Equal(t, resp.StatusCode, http.StatusInternalServerError)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Contains(t, string(body), "Overloaded")
}
},
},
},
{
name: aibridge.ProviderOpenAI,
fixture: oaiErr,
createRequestFunc: createOpenAIChatCompletionsReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
},
responseHandlerFn: func(streaming bool, resp *http.Response) {
if streaming {
{
name: aibridge.ProviderOpenAI,
fixture: oaiMidStreamErr,
createRequestFunc: createOpenAIChatCompletionsReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
},
responseHandlerFn: func(resp *http.Response) {
// Server responds first with 200 OK then starts streaming.
require.Equal(t, http.StatusOK, resp.StatusCode)

Expand All @@ -1063,72 +1167,55 @@ func TestErrorHandling(t *testing.T) {
errEvent := sp.MessageEvents()[len(sp.MessageEvents())-2] // Last event is termination marker ("[DONE]").
require.NotEmpty(t, errEvent)
require.Contains(t, errEvent.Data, "The server had an error while processing your request. Sorry about that!")
} else {
require.Equal(t, resp.StatusCode, http.StatusInternalServerError)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Contains(t, string(body), "The server had an error while processing your request. Sorry about that")
}
},
},
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

for _, streaming := range []bool{true, false} {
t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) {
t.Parallel()
}

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

arc := txtar.Parse(tc.fixture)
t.Logf("%s: %s", t.Name(), arc.Comment)
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)

files := filesMap(arc)
require.Len(t, files, 3)
require.Contains(t, files, fixtureRequest)
require.Contains(t, files, fixtureStreamingResponse)
require.Contains(t, files, fixtureNonStreamingResponse)
arc := txtar.Parse(tc.fixture)
t.Logf("%s: %s", t.Name(), arc.Comment)

reqBody := files[fixtureRequest]
files := filesMap(arc)
require.Len(t, files, 2)
require.Contains(t, files, fixtureRequest)
require.Contains(t, files, fixtureStreamingResponse)

// Add the stream param to the request.
newBody, err := setJSON(reqBody, "stream", streaming)
require.NoError(t, err)
reqBody = newBody
reqBody := files[fixtureRequest]

// Setup mock server.
mockSrv := newMockServer(ctx, t, files, nil)
mockSrv.statusCode = http.StatusInternalServerError
t.Cleanup(mockSrv.Close)
// Setup mock server.
mockSrv := newMockServer(ctx, t, files, nil)
mockSrv.statusCode = http.StatusInternalServerError
t.Cleanup(mockSrv.Close)

recorderClient := &mockRecorderClient{}
recorderClient := &mockRecorderClient{}

b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil))
require.NoError(t, err)
b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil))
require.NoError(t, err)

// Invoke request to mocked API via aibridge.
bridgeSrv := httptest.NewUnstartedServer(b)
bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
return aibridge.AsActor(ctx, userID, nil)
}
bridgeSrv.Start()
t.Cleanup(bridgeSrv.Close)
// Invoke request to mocked API via aibridge.
bridgeSrv := httptest.NewUnstartedServer(b)
bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
return aibridge.AsActor(ctx, userID, nil)
}
bridgeSrv.Start()
t.Cleanup(bridgeSrv.Close)

req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody)
resp, err := http.DefaultClient.Do(req)
t.Cleanup(func() { _ = resp.Body.Close() })
require.NoError(t, err)
req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody)
resp, err := http.DefaultClient.Do(req)
t.Cleanup(func() { _ = resp.Body.Close() })
require.NoError(t, err)

tc.responseHandlerFn(streaming, resp)
recorderClient.verifyAllInterceptionsEnded(t)
})
}
})
}
tc.responseHandlerFn(resp)
recorderClient.verifyAllInterceptionsEnded(t)
})
}
})
}

// TestStableRequestEncoding validates that a given intercepted request and a
Expand Down Expand Up @@ -1297,6 +1384,44 @@ func createOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte)
return req
}

type mockHTTPReflector struct {
*httptest.Server
}

func newMockHTTPReflector(ctx context.Context, t *testing.T, resp []byte) *mockHTTPReflector {
ref := &mockHTTPReflector{}

srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mock, err := http.ReadResponse(bufio.NewReader(bytes.NewBuffer(resp)), r)
require.NoError(t, err)
defer mock.Body.Close()

// Copy headers from the mocked response.
for key, values := range mock.Header {
for _, value := range values {
w.Header().Add(key, value)
}
}

// Write the status code.
w.WriteHeader(mock.StatusCode)

// Copy the body.
_, err = io.Copy(w, mock.Body)
require.NoError(t, err)
}))
srv.Config.BaseContext = func(_ net.Listener) context.Context {
return ctx
}

srv.Start()
t.Cleanup(srv.Close)

ref.Server = srv
return ref
}

// TODO: replace this with mockHTTPReflector.
type mockServer struct {
*httptest.Server

Expand Down
Loading