From ee485e60471de550ded3b8ef6965fc6ecc499522 Mon Sep 17 00:00:00 2001 From: Dev Agent Date: Thu, 18 Dec 2025 08:12:30 +0000 Subject: [PATCH 1/2] style: embedding token counter use interface instead of struct --- .../aigateway/token/mock_CounterFactory.go | 12 +- .../token/mock_EmbeddingTokenCounter.go | 163 ++++++++++++++++++ aigateway/handler/openai.go | 2 +- .../response_writer_wrapper_embedding.go | 26 +-- .../response_writer_wrapper_embedding_test.go | 87 +++++++++- aigateway/token/embedding_token_counter.go | 20 ++- .../token/embedding_token_counter_test.go | 129 ++++++++++++++ aigateway/token/token_counter.go | 10 +- 8 files changed, 403 insertions(+), 46 deletions(-) create mode 100644 _mocks/opencsg.com/csghub-server/aigateway/token/mock_EmbeddingTokenCounter.go create mode 100644 aigateway/token/embedding_token_counter_test.go diff --git a/_mocks/opencsg.com/csghub-server/aigateway/token/mock_CounterFactory.go b/_mocks/opencsg.com/csghub-server/aigateway/token/mock_CounterFactory.go index 194d218c..383f5a86 100644 --- a/_mocks/opencsg.com/csghub-server/aigateway/token/mock_CounterFactory.go +++ b/_mocks/opencsg.com/csghub-server/aigateway/token/mock_CounterFactory.go @@ -69,19 +69,19 @@ func (_c *MockCounterFactory_NewChat_Call) RunAndReturn(run func(token.CreatePar } // NewEmbedding provides a mock function with given fields: param -func (_m *MockCounterFactory) NewEmbedding(param token.CreateParam) *token.EmbeddingTokenCounter { +func (_m *MockCounterFactory) NewEmbedding(param token.CreateParam) token.EmbeddingTokenCounter { ret := _m.Called(param) if len(ret) == 0 { panic("no return value specified for NewEmbedding") } - var r0 *token.EmbeddingTokenCounter - if rf, ok := ret.Get(0).(func(token.CreateParam) *token.EmbeddingTokenCounter); ok { + var r0 token.EmbeddingTokenCounter + if rf, ok := ret.Get(0).(func(token.CreateParam) token.EmbeddingTokenCounter); ok { r0 = rf(param) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*token.EmbeddingTokenCounter) + r0 = ret.Get(0).(token.EmbeddingTokenCounter) } } @@ -106,12 +106,12 @@ func (_c *MockCounterFactory_NewEmbedding_Call) Run(run func(param token.CreateP return _c } -func (_c *MockCounterFactory_NewEmbedding_Call) Return(_a0 *token.EmbeddingTokenCounter) *MockCounterFactory_NewEmbedding_Call { +func (_c *MockCounterFactory_NewEmbedding_Call) Return(_a0 token.EmbeddingTokenCounter) *MockCounterFactory_NewEmbedding_Call { _c.Call.Return(_a0) return _c } -func (_c *MockCounterFactory_NewEmbedding_Call) RunAndReturn(run func(token.CreateParam) *token.EmbeddingTokenCounter) *MockCounterFactory_NewEmbedding_Call { +func (_c *MockCounterFactory_NewEmbedding_Call) RunAndReturn(run func(token.CreateParam) token.EmbeddingTokenCounter) *MockCounterFactory_NewEmbedding_Call { _c.Call.Return(run) return _c } diff --git a/_mocks/opencsg.com/csghub-server/aigateway/token/mock_EmbeddingTokenCounter.go b/_mocks/opencsg.com/csghub-server/aigateway/token/mock_EmbeddingTokenCounter.go new file mode 100644 index 00000000..54918848 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/aigateway/token/mock_EmbeddingTokenCounter.go @@ -0,0 +1,163 @@ +// Code generated by mockery v2.53.0. DO NOT EDIT. 
+ +package token + +import ( + context "context" + + openai "github.com/openai/openai-go/v3" + mock "github.com/stretchr/testify/mock" + + token "opencsg.com/csghub-server/aigateway/token" +) + +// MockEmbeddingTokenCounter is an autogenerated mock type for the EmbeddingTokenCounter type +type MockEmbeddingTokenCounter struct { + mock.Mock +} + +type MockEmbeddingTokenCounter_Expecter struct { + mock *mock.Mock +} + +func (_m *MockEmbeddingTokenCounter) EXPECT() *MockEmbeddingTokenCounter_Expecter { + return &MockEmbeddingTokenCounter_Expecter{mock: &_m.Mock} +} + +// Embedding provides a mock function with given fields: resp +func (_m *MockEmbeddingTokenCounter) Embedding(resp openai.CreateEmbeddingResponseUsage) { + _m.Called(resp) +} + +// MockEmbeddingTokenCounter_Embedding_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Embedding' +type MockEmbeddingTokenCounter_Embedding_Call struct { + *mock.Call +} + +// Embedding is a helper method to define mock.On call +// - resp openai.CreateEmbeddingResponseUsage +func (_e *MockEmbeddingTokenCounter_Expecter) Embedding(resp interface{}) *MockEmbeddingTokenCounter_Embedding_Call { + return &MockEmbeddingTokenCounter_Embedding_Call{Call: _e.mock.On("Embedding", resp)} +} + +func (_c *MockEmbeddingTokenCounter_Embedding_Call) Run(run func(resp openai.CreateEmbeddingResponseUsage)) *MockEmbeddingTokenCounter_Embedding_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(openai.CreateEmbeddingResponseUsage)) + }) + return _c +} + +func (_c *MockEmbeddingTokenCounter_Embedding_Call) Return() *MockEmbeddingTokenCounter_Embedding_Call { + _c.Call.Return() + return _c +} + +func (_c *MockEmbeddingTokenCounter_Embedding_Call) RunAndReturn(run func(openai.CreateEmbeddingResponseUsage)) *MockEmbeddingTokenCounter_Embedding_Call { + _c.Run(run) + return _c +} + +// Input provides a mock function with given fields: input +func (_m *MockEmbeddingTokenCounter) Input(input string) { + _m.Called(input) +} + +// MockEmbeddingTokenCounter_Input_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Input' +type MockEmbeddingTokenCounter_Input_Call struct { + *mock.Call +} + +// Input is a helper method to define mock.On call +// - input string +func (_e *MockEmbeddingTokenCounter_Expecter) Input(input interface{}) *MockEmbeddingTokenCounter_Input_Call { + return &MockEmbeddingTokenCounter_Input_Call{Call: _e.mock.On("Input", input)} +} + +func (_c *MockEmbeddingTokenCounter_Input_Call) Run(run func(input string)) *MockEmbeddingTokenCounter_Input_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockEmbeddingTokenCounter_Input_Call) Return() *MockEmbeddingTokenCounter_Input_Call { + _c.Call.Return() + return _c +} + +func (_c *MockEmbeddingTokenCounter_Input_Call) RunAndReturn(run func(string)) *MockEmbeddingTokenCounter_Input_Call { + _c.Run(run) + return _c +} + +// Usage provides a mock function with given fields: c +func (_m *MockEmbeddingTokenCounter) Usage(c context.Context) (*token.Usage, error) { + ret := _m.Called(c) + + if len(ret) == 0 { + panic("no return value specified for Usage") + } + + var r0 *token.Usage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*token.Usage, error)); ok { + return rf(c) + } + if rf, ok := ret.Get(0).(func(context.Context) *token.Usage); ok { + r0 = rf(c) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*token.Usage) + } + } + + if rf, ok := 
ret.Get(1).(func(context.Context) error); ok { + r1 = rf(c) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockEmbeddingTokenCounter_Usage_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Usage' +type MockEmbeddingTokenCounter_Usage_Call struct { + *mock.Call +} + +// Usage is a helper method to define mock.On call +// - c context.Context +func (_e *MockEmbeddingTokenCounter_Expecter) Usage(c interface{}) *MockEmbeddingTokenCounter_Usage_Call { + return &MockEmbeddingTokenCounter_Usage_Call{Call: _e.mock.On("Usage", c)} +} + +func (_c *MockEmbeddingTokenCounter_Usage_Call) Run(run func(c context.Context)) *MockEmbeddingTokenCounter_Usage_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockEmbeddingTokenCounter_Usage_Call) Return(_a0 *token.Usage, _a1 error) *MockEmbeddingTokenCounter_Usage_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockEmbeddingTokenCounter_Usage_Call) RunAndReturn(run func(context.Context) (*token.Usage, error)) *MockEmbeddingTokenCounter_Usage_Call { + _c.Call.Return(run) + return _c +} + +// NewMockEmbeddingTokenCounter creates a new instance of MockEmbeddingTokenCounter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockEmbeddingTokenCounter(t interface { + mock.TestingT + Cleanup(func()) +}) *MockEmbeddingTokenCounter { + mock := &MockEmbeddingTokenCounter{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/aigateway/handler/openai.go b/aigateway/handler/openai.go index 336cf99e..ef9d2b5e 100644 --- a/aigateway/handler/openai.go +++ b/aigateway/handler/openai.go @@ -407,7 +407,6 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) { slog.InfoContext(c, "proxy embedding request to model endpoint", slog.Any("target", target), slog.Any("host", host), slog.Any("user", username), slog.Any("model_id", modelID)) rp, _ := proxy.NewReverseProxy(target) - w := NewResponseWriterWrapperEmbedding(c.Writer) tokenCounter := h.tokenCounterFactory.NewEmbedding(token.CreateParam{ Endpoint: target, @@ -415,6 +414,7 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) { Model: modelName, ImageID: model.ImageID, }) + w := NewResponseWriterWrapperEmbedding(c.Writer, tokenCounter) tokenCounter.Input(req.Input) rp.ServeHTTP(w, c.Request, "", host) diff --git a/aigateway/handler/response_writer_wrapper_embedding.go b/aigateway/handler/response_writer_wrapper_embedding.go index a5e34459..0ca016d9 100644 --- a/aigateway/handler/response_writer_wrapper_embedding.go +++ b/aigateway/handler/response_writer_wrapper_embedding.go @@ -10,29 +10,20 @@ import ( "github.com/openai/openai-go/v3" "opencsg.com/csghub-server/aigateway/token" - "opencsg.com/csghub-server/builder/rpc" ) type ResponseWriterWrapperEmbedding struct { internalWritter http.ResponseWriter - modSvcClient rpc.ModerationSvcClient - tokenCounter *token.EmbeddingTokenCounter + tokenCounter token.EmbeddingTokenCounter } -func NewResponseWriterWrapperEmbedding(internalWritter http.ResponseWriter) *ResponseWriterWrapperEmbedding { +func NewResponseWriterWrapperEmbedding(internalWritter http.ResponseWriter, tokenCounter token.EmbeddingTokenCounter) *ResponseWriterWrapperEmbedding { return &ResponseWriterWrapperEmbedding{ internalWritter: internalWritter, + tokenCounter: tokenCounter, } } -func (rw 
*ResponseWriterWrapperEmbedding) WithModeration(modSvcClient rpc.ModerationSvcClient) { - rw.modSvcClient = modSvcClient -} - -func (rw *ResponseWriterWrapperEmbedding) WithTokenCounter(counter *token.EmbeddingTokenCounter) { - rw.tokenCounter = counter -} - func (rw *ResponseWriterWrapperEmbedding) Header() http.Header { return rw.internalWritter.Header() } @@ -66,14 +57,3 @@ func (rw *ResponseWriterWrapperEmbedding) Write(data []byte) (int, error) { return rw.internalWritter.Write(data) } - -//TODO: moderate embedding request and generate sensitive response if needed -// func (rw *ResponseWriterWrapperEmbedding) generateSensitiveResp(originResp openai.CreateEmbeddingResponse) openai.CreateEmbeddingResponse { -// newResp := openai.CreateEmbeddingResponse{ -// Data: nil, -// Model: originResp.Model, -// Object: originResp.Object, -// Usage: originResp.Usage, -// } -// return newResp -// } diff --git a/aigateway/handler/response_writer_wrapper_embedding_test.go b/aigateway/handler/response_writer_wrapper_embedding_test.go index bc0d00b6..49e91075 100644 --- a/aigateway/handler/response_writer_wrapper_embedding_test.go +++ b/aigateway/handler/response_writer_wrapper_embedding_test.go @@ -1,6 +1,8 @@ package handler import ( + "bytes" + "compress/gzip" "context" "encoding/json" "net/http/httptest" @@ -8,6 +10,7 @@ import ( "github.com/openai/openai-go/v3" "github.com/stretchr/testify/assert" + mocktoken "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/aigateway/token" "opencsg.com/csghub-server/aigateway/token" ) @@ -71,12 +74,15 @@ func TestResponseWriterWrapperEmbedding_Write(t *testing.T) { t.Run(tt.name, func(t *testing.T) { // Create a test http.ResponseWriter w := httptest.NewRecorder() - wrapper := NewResponseWriterWrapperEmbedding(w) // Setup token counter if needed + var wrapper *ResponseWriterWrapperEmbedding + var counter *mocktoken.MockEmbeddingTokenCounter = nil if tt.withCounter { - counter := token.NewEmbeddingTokenCounter(nil) - wrapper.WithTokenCounter(counter) + counter = mocktoken.NewMockEmbeddingTokenCounter(t) + wrapper = NewResponseWriterWrapperEmbedding(w, counter) + } else { + wrapper = NewResponseWriterWrapperEmbedding(w, nil) } // Prepare test data @@ -84,7 +90,19 @@ func TestResponseWriterWrapperEmbedding_Write(t *testing.T) { var err error if tt.name == "invalid json data" { data = []byte(`invalid json`) - } else { + } else if tt.withCounter { + data, err = json.Marshal(tt.response) + assert.NoError(t, err) + var expectResp openai.CreateEmbeddingResponse + err = json.Unmarshal(data, &expectResp) + assert.NoError(t, err) + counter.EXPECT().Embedding(expectResp.Usage).Return().Once() + counter.EXPECT().Usage(context.Background()).Return(&token.Usage{ + TotalTokens: expectResp.Usage.TotalTokens, + PromptTokens: expectResp.Usage.PromptTokens, + CompletionTokens: 0, + }, nil).Once() + } else if !tt.withCounter { data, err = json.Marshal(tt.response) assert.NoError(t, err) } @@ -112,3 +130,64 @@ func TestResponseWriterWrapperEmbedding_Write(t *testing.T) { }) } } + +func TestResponseWriterWrapperEmbedding_Write_Gzip(t *testing.T) { + // Create a test http.ResponseWriter + w := httptest.NewRecorder() + + // Setup token counter + counter := mocktoken.NewMockEmbeddingTokenCounter(t) + wrapper := NewResponseWriterWrapperEmbedding(w, counter) + + // Create embedding response + response := openai.CreateEmbeddingResponse{ + Object: "embedding", + Data: []openai.Embedding{ + { + Object: "embedding", + Embedding: []float64{0.1, 0.2, 0.3}, + Index: 0, + }, + }, + Model: 
"text-embedding-ada-002", + Usage: openai.CreateEmbeddingResponseUsage{ + PromptTokens: 10, + TotalTokens: 10, + }, + } + + // Marshal response to JSON + jsonData, err := json.Marshal(response) + assert.NoError(t, err) + + // Create gzip compressed data + var gzippedData bytes.Buffer + gzipWriter := gzip.NewWriter(&gzippedData) + _, err = gzipWriter.Write(jsonData) + assert.NoError(t, err) + gzipWriter.Close() + var expectResp openai.CreateEmbeddingResponse + err = json.Unmarshal(jsonData, &expectResp) + assert.NoError(t, err) + // Set expectations for token counter + counter.EXPECT().Embedding(expectResp.Usage).Return().Once() + counter.EXPECT().Usage(context.Background()).Return(&token.Usage{ + TotalTokens: expectResp.Usage.TotalTokens, + PromptTokens: expectResp.Usage.PromptTokens, + CompletionTokens: 0, + }, nil).Once() + + // Execute Write method with gzipped data + n, err := wrapper.Write(gzippedData.Bytes()) + + // Verify results + assert.NoError(t, err) + assert.Equal(t, len(gzippedData.Bytes()), n) + assert.Equal(t, gzippedData.Bytes(), w.Body.Bytes()) + + // Verify token counter was called with correct usage data + usage, err := wrapper.tokenCounter.Usage(context.Background()) + assert.NoError(t, err) + assert.Equal(t, response.Usage.TotalTokens, usage.TotalTokens) + assert.Equal(t, response.Usage.PromptTokens, usage.PromptTokens) +} diff --git a/aigateway/token/embedding_token_counter.go b/aigateway/token/embedding_token_counter.go index 3349c66f..16271b4f 100644 --- a/aigateway/token/embedding_token_counter.go +++ b/aigateway/token/embedding_token_counter.go @@ -7,31 +7,37 @@ import ( "github.com/openai/openai-go/v3" ) -var _ Counter = (*EmbeddingTokenCounter)(nil) +var _ Counter = (*EmbeddingTokenCounterImpl)(nil) -type EmbeddingTokenCounter struct { +type EmbeddingTokenCounter interface { + Embedding(resp openai.CreateEmbeddingResponseUsage) + Input(input string) + Usage(c context.Context) (*Usage, error) +} + +type EmbeddingTokenCounterImpl struct { input string usage *openai.CreateEmbeddingResponseUsage tokenizer Tokenizer } -func NewEmbeddingTokenCounter(tokenizer Tokenizer) *EmbeddingTokenCounter { - return &EmbeddingTokenCounter{ +func NewEmbeddingTokenCounter(tokenizer Tokenizer) EmbeddingTokenCounter { + return &EmbeddingTokenCounterImpl{ tokenizer: tokenizer, } } // Embedding implements EmbeddingTokenCounter. -func (l *EmbeddingTokenCounter) Embedding(resp openai.CreateEmbeddingResponseUsage) { +func (l *EmbeddingTokenCounterImpl) Embedding(resp openai.CreateEmbeddingResponseUsage) { l.usage = &resp } -func (l *EmbeddingTokenCounter) Input(input string) { +func (l *EmbeddingTokenCounterImpl) Input(input string) { l.input = input } // Usage implements LLMTokenCounter. 
-func (l *EmbeddingTokenCounter) Usage(c context.Context) (*Usage, error) { +func (l *EmbeddingTokenCounterImpl) Usage(c context.Context) (*Usage, error) { if l.usage != nil { return &Usage{ PromptTokens: l.usage.PromptTokens, diff --git a/aigateway/token/embedding_token_counter_test.go b/aigateway/token/embedding_token_counter_test.go new file mode 100644 index 00000000..bd502fe3 --- /dev/null +++ b/aigateway/token/embedding_token_counter_test.go @@ -0,0 +1,129 @@ +package token_test + +import ( + "context" + "testing" + + "github.com/openai/openai-go/v3" + "github.com/stretchr/testify/assert" + mocktoken "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/aigateway/token" + "opencsg.com/csghub-server/aigateway/token" +) + +func TestEmbeddingTokenCounter_Usage_WithEmbeddingResponse(t *testing.T) { + // Test that when embedding response is set, Usage method returns usage from response + tokenizer := mocktoken.NewMockTokenizer(t) + counter := token.NewEmbeddingTokenCounter(tokenizer) + + // Set up embedding response + embeddingUsage := openai.CreateEmbeddingResponseUsage{ + PromptTokens: 10, + TotalTokens: 10, + } + counter.Embedding(embeddingUsage) + + // Call Usage method + usage, err := counter.Usage(context.Background()) + + // Verify results + assert.NoError(t, err) + assert.Equal(t, int64(10), usage.PromptTokens) + assert.Equal(t, int64(10), usage.TotalTokens) + assert.Equal(t, int64(0), usage.CompletionTokens) +} + +func TestEmbeddingTokenCounter_Usage_WithTokenizer(t *testing.T) { + // Test that when embedding response is not set but tokenizer is available, use tokenizer to count tokens + tokenizer := mocktoken.NewMockTokenizer(t) + counter := token.NewEmbeddingTokenCounter(tokenizer) + + // Set up mock tokenizer behavior + inputText := "Hello, world!" + tokenizer.On("EmbeddingEncode", inputText).Return(int64(5), nil) + + // Set input + counter.Input(inputText) + + // Call Usage method + usage, err := counter.Usage(context.Background()) + + // Verify results + assert.NoError(t, err) + assert.Equal(t, int64(5), usage.PromptTokens) + assert.Equal(t, int64(5), usage.TotalTokens) + assert.Equal(t, int64(0), usage.CompletionTokens) + + // Verify tokenizer was called + tokenizer.AssertCalled(t, "EmbeddingEncode", inputText) +} + +func TestEmbeddingTokenCounter_Usage_WithoutTokenizer(t *testing.T) { + // Test that when tokenizer is nil and no embedding response, return error + inputText := "Hello, world!" + counter := token.NewEmbeddingTokenCounter(nil) + + counter.Input(inputText) + + usage, err := counter.Usage(context.Background()) + + assert.Error(t, err) + assert.Nil(t, usage) + assert.Equal(t, "no usage found in embedding response, and tokenizer not set", err.Error()) +} + +func TestEmbeddingTokenCounter_Usage_WithDumyTokenizer(t *testing.T) { + // Test with the existing DumyTokenizer + tokenizer := &token.DumyTokenizer{} + counter := token.NewEmbeddingTokenCounter(tokenizer) + + // Set input + inputText := "Hello, world!" 
+ counter.Input(inputText) + + // Call Usage method + usage, err := counter.Usage(context.Background()) + + // Verify results + assert.NoError(t, err) + // DumyTokenizer counts characters as tokens + assert.Equal(t, int64(len(inputText)), usage.PromptTokens) + assert.Equal(t, int64(len(inputText)), usage.TotalTokens) + assert.Equal(t, int64(0), usage.CompletionTokens) +} + +func TestEmbeddingTokenCounter_Input(t *testing.T) { + // Test that Input method correctly sets the input text + tokenizer := mocktoken.NewMockTokenizer(t) + counter := token.NewEmbeddingTokenCounter(tokenizer) + + // Set input + inputText := "Test input text" + counter.Input(inputText) + + // Verify input is stored correctly by calling Usage and checking tokenizer call + tokenizer.EXPECT().EmbeddingEncode(inputText).Return(int64(5), nil) + usage, err := counter.Usage(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int64(5), usage.PromptTokens) + assert.Equal(t, int64(5), usage.TotalTokens) + assert.Equal(t, int64(0), usage.CompletionTokens) +} + +func TestEmbeddingTokenCounter_Embedding(t *testing.T) { + // Test that Embedding method correctly sets the embedding usage + tokenizer := mocktoken.NewMockTokenizer(t) + counter := token.NewEmbeddingTokenCounter(tokenizer) + + // Set up embedding response + embeddingUsage := openai.CreateEmbeddingResponseUsage{ + PromptTokens: 20, + TotalTokens: 20, + } + counter.Embedding(embeddingUsage) + + // Verify embedding usage is stored correctly by calling Usage + usage, err := counter.Usage(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int64(20), usage.PromptTokens) + assert.Equal(t, int64(20), usage.TotalTokens) +} diff --git a/aigateway/token/token_counter.go b/aigateway/token/token_counter.go index d97b7159..ec4f4017 100644 --- a/aigateway/token/token_counter.go +++ b/aigateway/token/token_counter.go @@ -11,21 +11,21 @@ type CreateParam struct { type CounterFactory interface { NewChat(param CreateParam) ChatTokenCounter - NewEmbedding(param CreateParam) *EmbeddingTokenCounter + NewEmbedding(param CreateParam) EmbeddingTokenCounter } func NewCounterFactory() CounterFactory { - return &tokenizerFactoryImpl{} + return &counterFactoryImpl{} } -type tokenizerFactoryImpl struct{} +type counterFactoryImpl struct{} -func (f *tokenizerFactoryImpl) NewChat(param CreateParam) ChatTokenCounter { +func (f *counterFactoryImpl) NewChat(param CreateParam) ChatTokenCounter { tokenizer := NewTokenizerImpl(param.Endpoint, param.Host, param.Model, param.ImageID) return NewLLMTokenCounter(tokenizer) } -func (f *tokenizerFactoryImpl) NewEmbedding(param CreateParam) *EmbeddingTokenCounter { +func (f *counterFactoryImpl) NewEmbedding(param CreateParam) EmbeddingTokenCounter { tokenizer := NewTokenizerImpl(param.Endpoint, param.Host, param.Model, param.ImageID) return NewEmbeddingTokenCounter(tokenizer) } From b17e331f05251d189a51cc7c18412b74a8aa3d65 Mon Sep 17 00:00:00 2001 From: Dev Agent Date: Thu, 18 Dec 2025 08:49:43 +0000 Subject: [PATCH 2/2] fix request for embedding --- aigateway/handler/openai.go | 15 ++++-- aigateway/handler/openai_test.go | 48 ++++++++++++++----- aigateway/handler/requests.go | 77 ++++++++++++++++++++++++++++-- aigateway/handler/requests_test.go | 67 ++++++++++++++++++++++++++ 4 files changed, 189 insertions(+), 18 deletions(-) diff --git a/aigateway/handler/openai.go b/aigateway/handler/openai.go index ef9d2b5e..f94b0c76 100644 --- a/aigateway/handler/openai.go +++ b/aigateway/handler/openai.go @@ -346,8 +346,15 @@ func (h 
*OpenAIHandlerImpl) Embedding(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - if req.Input == "" || req.Model == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "Model and input cannot be empty"}) + if req.Model == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Model cannot be empty"}) + return + } + if req.Input.OfString.String() == "" && + len(req.Input.OfArrayOfStrings) == 0 && + len(req.Input.OfArrayOfTokenArrays) == 0 && + len(req.Input.OfArrayOfTokens) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "Input cannot be empty"}) return } modelID := req.Model @@ -415,7 +422,9 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) { ImageID: model.ImageID, }) w := NewResponseWriterWrapperEmbedding(c.Writer, tokenCounter) - tokenCounter.Input(req.Input) + if req.Input.OfString.String() != "" { + tokenCounter.Input(req.Input.OfString.Value) + } rp.ServeHTTP(w, c.Request, "", host) go func() { diff --git a/aigateway/handler/openai_test.go b/aigateway/handler/openai_test.go index a26611e1..ffdef798 100644 --- a/aigateway/handler/openai_test.go +++ b/aigateway/handler/openai_test.go @@ -403,8 +403,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) { tester, c, w := setupTest(t) // Empty Input embeddingReq := EmbeddingRequest{ - Model: "model1:svc1", - Input: "", + EmbeddingNewParams: openai.EmbeddingNewParams{ + Model: "model1:svc1", + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{}, + }, + }, } body, _ := json.Marshal(embeddingReq) c.Request.Method = http.MethodPost @@ -424,8 +428,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) { httpbase.SetCurrentUser(c, "testuser") httpbase.SetCurrentUserUUID(c, "testuuid") embeddingReq = EmbeddingRequest{ - Model: "", - Input: "test input", + EmbeddingNewParams: openai.EmbeddingNewParams{ + Model: "", + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"test input"}, + }, + }, } body, _ = json.Marshal(embeddingReq) c.Request.Body = io.NopCloser(bytes.NewReader(body)) @@ -438,8 +446,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) { t.Run("model not found", func(t *testing.T) { tester, c, w := setupTest(t) embeddingReq := EmbeddingRequest{ - Model: "nonexistent:svc", - Input: "test input", + EmbeddingNewParams: openai.EmbeddingNewParams{ + Model: "nonexistent:svc", + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"test input"}, + }, + }, } body, _ := json.Marshal(embeddingReq) c.Request.Method = http.MethodPost @@ -455,8 +467,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) { t.Run("get model error", func(t *testing.T) { tester, c, w := setupTest(t) embeddingReq := EmbeddingRequest{ - Model: "model1:svc1", - Input: "test input", + EmbeddingNewParams: openai.EmbeddingNewParams{ + Model: "model1:svc1", + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"test input"}, + }, + }, } body, _ := json.Marshal(embeddingReq) c.Request.Method = http.MethodPost @@ -472,8 +488,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) { t.Run("model not running", func(t *testing.T) { tester, c, w := setupTest(t) embeddingReq := EmbeddingRequest{ - Model: "model1:svc1", - Input: "test input", + EmbeddingNewParams: openai.EmbeddingNewParams{ + Model: "model1:svc1", + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"test input"}, + }, + }, } body, _ := json.Marshal(embeddingReq) c.Request.Method = http.MethodPost @@ -503,8 +523,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) { 
t.Run("model without svc name", func(t *testing.T) { tester, c, _ := setupTest(t) embeddingReq := EmbeddingRequest{ - Model: "model1", - Input: "test input", + EmbeddingNewParams: openai.EmbeddingNewParams{ + Model: "model1", + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"test input"}, + }, + }, } body, _ := json.Marshal(embeddingReq) c.Request.Method = http.MethodPost diff --git a/aigateway/handler/requests.go b/aigateway/handler/requests.go index 75fc8aef..dedc072c 100644 --- a/aigateway/handler/requests.go +++ b/aigateway/handler/requests.go @@ -119,7 +119,78 @@ type StreamOptions struct { // EmbeddingRequest represents an embedding request structure type EmbeddingRequest struct { - Input string `json:"input"` // Input text content - Model string `json:"model"` // Model name used (e.g., "text-embedding-ada-002") - EncodingFormat string `json:"encoding_format,omitempty"` // Encoding format (e.g., "float") + openai.EmbeddingNewParams + // RawJSON stores all unknown fields during unmarshaling + RawJSON json.RawMessage `json:"-"` +} + +func (r *EmbeddingRequest) UnmarshalJSON(data []byte) error { + // Create a temporary struct to hold the known fields + type TempEmbeddingRequest EmbeddingRequest + + // First, unmarshal into the temporary struct + var temp TempEmbeddingRequest + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + // Then, unmarshal into a map to get all fields + var allFields map[string]json.RawMessage + if err := json.Unmarshal(data, &allFields); err != nil { + return err + } + + // Remove known fields from the map + delete(allFields, "model") + delete(allFields, "input") + delete(allFields, "encoding_format") + + // If there are any unknown fields left, marshal them into RawJSON + var rawJSON []byte + var err error + if len(allFields) > 0 { + rawJSON, err = json.Marshal(allFields) + if err != nil { + return err + } + } + + // Assign the temporary struct to the original and set RawJSON + *r = EmbeddingRequest(temp) + r.RawJSON = rawJSON + return nil +} + +func (r EmbeddingRequest) MarshalJSON() ([]byte, error) { + // First, marshal the known fields + type TempEmbeddingRequest EmbeddingRequest + data, err := json.Marshal(TempEmbeddingRequest(r)) + if err != nil { + return nil, err + } + + // If there are no raw JSON fields, just return the known fields + if len(r.RawJSON) == 0 { + return data, nil + } + + // Parse the known fields back into a map + var knownFields map[string]json.RawMessage + if err := json.Unmarshal(data, &knownFields); err != nil { + return nil, err + } + + // Parse the raw JSON fields into a map + var rawFields map[string]json.RawMessage + if err := json.Unmarshal(r.RawJSON, &rawFields); err != nil { + return nil, err + } + + // Merge the raw fields into the known fields + for k, v := range rawFields { + knownFields[k] = v + } + + // Marshal the merged map back into JSON + return json.Marshal(knownFields) } diff --git a/aigateway/handler/requests_test.go b/aigateway/handler/requests_test.go index 8b6d971c..7ec2f159 100644 --- a/aigateway/handler/requests_test.go +++ b/aigateway/handler/requests_test.go @@ -162,3 +162,70 @@ func TestChatCompletionRequest_EmptyRawJSON(t *testing.T) { // RawJSON should be empty assert.Empty(t, req4Unmarshaled.RawJSON) } + +func TestEmbeddingRequest_MarshalUnmarshal(t *testing.T) { + // Test case 1: Only known fields + req1 := &EmbeddingRequest{ + EmbeddingNewParams: openai.EmbeddingNewParams{ + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"Hello, 
world!"}, + }, + Model: "text-embedding-ada-002", + }, + } + + // Marshal to JSON + data1, err := json.Marshal(req1) + assert.NoError(t, err) + + // Unmarshal back + var req1Unmarshaled EmbeddingRequest + err = json.Unmarshal(data1, &req1Unmarshaled) + assert.NoError(t, err) + + // Verify fields + assert.Equal(t, req1.Model, req1Unmarshaled.Model) + assert.Equal(t, len(req1.Input.OfArrayOfStrings), len(req1Unmarshaled.Input.OfArrayOfStrings)) + assert.Empty(t, req1Unmarshaled.RawJSON) +} + +func TestEmbeddingRequest_UnknownFields(t *testing.T) { + // Test case 2: With unknown fields + jsonWithUnknown := `{ + "model": "text-embedding-ada-002", + "input": ["Hello, world!"], + "unknown_field": "unknown_value", + "another_unknown": 12345 + }` + + // Unmarshal + var req2 EmbeddingRequest + err := json.Unmarshal([]byte(jsonWithUnknown), &req2) + assert.NoError(t, err) + + // Verify known fields + assert.Equal(t, "text-embedding-ada-002", req2.Model) + assert.Equal(t, 1, len(req2.Input.OfArrayOfStrings)) + + // Verify unknown fields are stored in RawJSON + assert.NotEmpty(t, req2.RawJSON) + + // Marshal back and verify unknown fields are preserved + data2, err := json.Marshal(req2) + assert.NoError(t, err) + + // Unmarshal into map to check all fields + var resultMap map[string]interface{} + err = json.Unmarshal(data2, &resultMap) + assert.NoError(t, err) + + // Check known fields + assert.Equal(t, "text-embedding-ada-002", resultMap["model"]) + inputArray, ok := resultMap["input"].([]interface{}) + assert.True(t, ok) + assert.Equal(t, "Hello, world!", inputArray[0]) + + // Check unknown fields + assert.Equal(t, "unknown_value", resultMap["unknown_field"]) + assert.Equal(t, 12345.0, resultMap["another_unknown"]) +}