Skip to content

Commit b17e331

Browse files
Dev AgentQinYuuuu
authored andcommitted
fix request for embedding
1 parent ee485e6 commit b17e331

File tree

4 files changed

+189
-18
lines changed

4 files changed

+189
-18
lines changed

aigateway/handler/openai.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -346,8 +346,15 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) {
346346
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
347347
return
348348
}
349-
if req.Input == "" || req.Model == "" {
350-
c.JSON(http.StatusBadRequest, gin.H{"error": "Model and input cannot be empty"})
349+
if req.Model == "" {
350+
c.JSON(http.StatusBadRequest, gin.H{"error": "Model cannot be empty"})
351+
return
352+
}
353+
if req.Input.OfString.String() == "" &&
354+
len(req.Input.OfArrayOfStrings) == 0 &&
355+
len(req.Input.OfArrayOfTokenArrays) == 0 &&
356+
len(req.Input.OfArrayOfTokens) == 0 {
357+
c.JSON(http.StatusBadRequest, gin.H{"error": "Input cannot be empty"})
351358
return
352359
}
353360
modelID := req.Model
@@ -415,7 +422,9 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) {
415422
ImageID: model.ImageID,
416423
})
417424
w := NewResponseWriterWrapperEmbedding(c.Writer, tokenCounter)
418-
tokenCounter.Input(req.Input)
425+
if req.Input.OfString.String() != "" {
426+
tokenCounter.Input(req.Input.OfString.Value)
427+
}
419428

420429
rp.ServeHTTP(w, c.Request, "", host)
421430
go func() {

aigateway/handler/openai_test.go

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) {
403403
tester, c, w := setupTest(t)
404404
// Empty Input
405405
embeddingReq := EmbeddingRequest{
406-
Model: "model1:svc1",
407-
Input: "",
406+
EmbeddingNewParams: openai.EmbeddingNewParams{
407+
Model: "model1:svc1",
408+
Input: openai.EmbeddingNewParamsInputUnion{
409+
OfArrayOfStrings: []string{},
410+
},
411+
},
408412
}
409413
body, _ := json.Marshal(embeddingReq)
410414
c.Request.Method = http.MethodPost
@@ -424,8 +428,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) {
424428
httpbase.SetCurrentUser(c, "testuser")
425429
httpbase.SetCurrentUserUUID(c, "testuuid")
426430
embeddingReq = EmbeddingRequest{
427-
Model: "",
428-
Input: "test input",
431+
EmbeddingNewParams: openai.EmbeddingNewParams{
432+
Model: "",
433+
Input: openai.EmbeddingNewParamsInputUnion{
434+
OfArrayOfStrings: []string{"test input"},
435+
},
436+
},
429437
}
430438
body, _ = json.Marshal(embeddingReq)
431439
c.Request.Body = io.NopCloser(bytes.NewReader(body))
@@ -438,8 +446,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) {
438446
t.Run("model not found", func(t *testing.T) {
439447
tester, c, w := setupTest(t)
440448
embeddingReq := EmbeddingRequest{
441-
Model: "nonexistent:svc",
442-
Input: "test input",
449+
EmbeddingNewParams: openai.EmbeddingNewParams{
450+
Model: "nonexistent:svc",
451+
Input: openai.EmbeddingNewParamsInputUnion{
452+
OfArrayOfStrings: []string{"test input"},
453+
},
454+
},
443455
}
444456
body, _ := json.Marshal(embeddingReq)
445457
c.Request.Method = http.MethodPost
@@ -455,8 +467,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) {
455467
t.Run("get model error", func(t *testing.T) {
456468
tester, c, w := setupTest(t)
457469
embeddingReq := EmbeddingRequest{
458-
Model: "model1:svc1",
459-
Input: "test input",
470+
EmbeddingNewParams: openai.EmbeddingNewParams{
471+
Model: "model1:svc1",
472+
Input: openai.EmbeddingNewParamsInputUnion{
473+
OfArrayOfStrings: []string{"test input"},
474+
},
475+
},
460476
}
461477
body, _ := json.Marshal(embeddingReq)
462478
c.Request.Method = http.MethodPost
@@ -472,8 +488,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) {
472488
t.Run("model not running", func(t *testing.T) {
473489
tester, c, w := setupTest(t)
474490
embeddingReq := EmbeddingRequest{
475-
Model: "model1:svc1",
476-
Input: "test input",
491+
EmbeddingNewParams: openai.EmbeddingNewParams{
492+
Model: "model1:svc1",
493+
Input: openai.EmbeddingNewParamsInputUnion{
494+
OfArrayOfStrings: []string{"test input"},
495+
},
496+
},
477497
}
478498
body, _ := json.Marshal(embeddingReq)
479499
c.Request.Method = http.MethodPost
@@ -503,8 +523,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) {
503523
t.Run("model without svc name", func(t *testing.T) {
504524
tester, c, _ := setupTest(t)
505525
embeddingReq := EmbeddingRequest{
506-
Model: "model1",
507-
Input: "test input",
526+
EmbeddingNewParams: openai.EmbeddingNewParams{
527+
Model: "model1",
528+
Input: openai.EmbeddingNewParamsInputUnion{
529+
OfArrayOfStrings: []string{"test input"},
530+
},
531+
},
508532
}
509533
body, _ := json.Marshal(embeddingReq)
510534
c.Request.Method = http.MethodPost

aigateway/handler/requests.go

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,78 @@ type StreamOptions struct {
119119

120120
// EmbeddingRequest represents an embedding request structure
121121
type EmbeddingRequest struct {
122-
Input string `json:"input"` // Input text content
123-
Model string `json:"model"` // Model name used (e.g., "text-embedding-ada-002")
124-
EncodingFormat string `json:"encoding_format,omitempty"` // Encoding format (e.g., "float")
122+
openai.EmbeddingNewParams
123+
// RawJSON stores all unknown fields during unmarshaling
124+
RawJSON json.RawMessage `json:"-"`
125+
}
126+
127+
func (r *EmbeddingRequest) UnmarshalJSON(data []byte) error {
128+
// Create a temporary struct to hold the known fields
129+
type TempEmbeddingRequest EmbeddingRequest
130+
131+
// First, unmarshal into the temporary struct
132+
var temp TempEmbeddingRequest
133+
if err := json.Unmarshal(data, &temp); err != nil {
134+
return err
135+
}
136+
137+
// Then, unmarshal into a map to get all fields
138+
var allFields map[string]json.RawMessage
139+
if err := json.Unmarshal(data, &allFields); err != nil {
140+
return err
141+
}
142+
143+
// Remove known fields from the map
144+
delete(allFields, "model")
145+
delete(allFields, "input")
146+
delete(allFields, "encoding_format")
147+
148+
// If there are any unknown fields left, marshal them into RawJSON
149+
var rawJSON []byte
150+
var err error
151+
if len(allFields) > 0 {
152+
rawJSON, err = json.Marshal(allFields)
153+
if err != nil {
154+
return err
155+
}
156+
}
157+
158+
// Assign the temporary struct to the original and set RawJSON
159+
*r = EmbeddingRequest(temp)
160+
r.RawJSON = rawJSON
161+
return nil
162+
}
163+
164+
func (r EmbeddingRequest) MarshalJSON() ([]byte, error) {
165+
// First, marshal the known fields
166+
type TempEmbeddingRequest EmbeddingRequest
167+
data, err := json.Marshal(TempEmbeddingRequest(r))
168+
if err != nil {
169+
return nil, err
170+
}
171+
172+
// If there are no raw JSON fields, just return the known fields
173+
if len(r.RawJSON) == 0 {
174+
return data, nil
175+
}
176+
177+
// Parse the known fields back into a map
178+
var knownFields map[string]json.RawMessage
179+
if err := json.Unmarshal(data, &knownFields); err != nil {
180+
return nil, err
181+
}
182+
183+
// Parse the raw JSON fields into a map
184+
var rawFields map[string]json.RawMessage
185+
if err := json.Unmarshal(r.RawJSON, &rawFields); err != nil {
186+
return nil, err
187+
}
188+
189+
// Merge the raw fields into the known fields
190+
for k, v := range rawFields {
191+
knownFields[k] = v
192+
}
193+
194+
// Marshal the merged map back into JSON
195+
return json.Marshal(knownFields)
125196
}

aigateway/handler/requests_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,70 @@ func TestChatCompletionRequest_EmptyRawJSON(t *testing.T) {
162162
// RawJSON should be empty
163163
assert.Empty(t, req4Unmarshaled.RawJSON)
164164
}
165+
166+
func TestEmbeddingRequest_MarshalUnmarshal(t *testing.T) {
167+
// Test case 1: Only known fields
168+
req1 := &EmbeddingRequest{
169+
EmbeddingNewParams: openai.EmbeddingNewParams{
170+
Input: openai.EmbeddingNewParamsInputUnion{
171+
OfArrayOfStrings: []string{"Hello, world!"},
172+
},
173+
Model: "text-embedding-ada-002",
174+
},
175+
}
176+
177+
// Marshal to JSON
178+
data1, err := json.Marshal(req1)
179+
assert.NoError(t, err)
180+
181+
// Unmarshal back
182+
var req1Unmarshaled EmbeddingRequest
183+
err = json.Unmarshal(data1, &req1Unmarshaled)
184+
assert.NoError(t, err)
185+
186+
// Verify fields
187+
assert.Equal(t, req1.Model, req1Unmarshaled.Model)
188+
assert.Equal(t, len(req1.Input.OfArrayOfStrings), len(req1Unmarshaled.Input.OfArrayOfStrings))
189+
assert.Empty(t, req1Unmarshaled.RawJSON)
190+
}
191+
192+
func TestEmbeddingRequest_UnknownFields(t *testing.T) {
193+
// Test case 2: With unknown fields
194+
jsonWithUnknown := `{
195+
"model": "text-embedding-ada-002",
196+
"input": ["Hello, world!"],
197+
"unknown_field": "unknown_value",
198+
"another_unknown": 12345
199+
}`
200+
201+
// Unmarshal
202+
var req2 EmbeddingRequest
203+
err := json.Unmarshal([]byte(jsonWithUnknown), &req2)
204+
assert.NoError(t, err)
205+
206+
// Verify known fields
207+
assert.Equal(t, "text-embedding-ada-002", req2.Model)
208+
assert.Equal(t, 1, len(req2.Input.OfArrayOfStrings))
209+
210+
// Verify unknown fields are stored in RawJSON
211+
assert.NotEmpty(t, req2.RawJSON)
212+
213+
// Marshal back and verify unknown fields are preserved
214+
data2, err := json.Marshal(req2)
215+
assert.NoError(t, err)
216+
217+
// Unmarshal into map to check all fields
218+
var resultMap map[string]interface{}
219+
err = json.Unmarshal(data2, &resultMap)
220+
assert.NoError(t, err)
221+
222+
// Check known fields
223+
assert.Equal(t, "text-embedding-ada-002", resultMap["model"])
224+
inputArray, ok := resultMap["input"].([]interface{})
225+
assert.True(t, ok)
226+
assert.Equal(t, "Hello, world!", inputArray[0])
227+
228+
// Check unknown fields
229+
assert.Equal(t, "unknown_value", resultMap["unknown_field"])
230+
assert.Equal(t, 12345.0, resultMap["another_unknown"])
231+
}

0 commit comments

Comments
 (0)