Skip to content

Commit e21b3fe

Browse files
committed
Add streaming response process.
1 parent 565a1ce commit e21b3fe

File tree

7 files changed

+368
-60
lines changed

7 files changed

+368
-60
lines changed

pkg/epp/handlers/response.go

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,17 @@ limitations under the License.
1717
package handlers
1818

1919
import (
20+
"bytes"
2021
"context"
2122
"encoding/json"
22-
"fmt"
2323
"strings"
2424

2525
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
2626
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
2727
"sigs.k8s.io/controller-runtime/pkg/log"
2828

2929
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
30+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
3031
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3132
)
3233

@@ -36,49 +37,50 @@ const (
3637
)
3738

3839
// HandleResponseBody always returns the requestContext even in the error case, as the request context is used in error handling.
39-
func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, response map[string]any) (*RequestContext, error) {
40+
func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, body []byte) (*RequestContext, error) {
4041
logger := log.FromContext(ctx)
41-
responseBytes, err := json.Marshal(response)
42+
llmResponse, err := types.NewLLMResponseFromBytes(body)
4243
if err != nil {
43-
return reqCtx, fmt.Errorf("error marshalling responseBody - %w", err)
44+
logger.Error(err, "failed to create LLMResponse from bytes")
45+
return reqCtx, err
4446
}
45-
if response["usage"] != nil {
46-
usg := response["usage"].(map[string]any)
47-
usage := Usage{
48-
PromptTokens: int(usg["prompt_tokens"].(float64)),
49-
CompletionTokens: int(usg["completion_tokens"].(float64)),
50-
TotalTokens: int(usg["total_tokens"].(float64)),
51-
}
47+
reqCtx.SchedulingResponse = llmResponse
48+
if usage := reqCtx.SchedulingResponse.Usage(); usage != nil {
5249
reqCtx.Usage = usage
53-
logger.V(logutil.VERBOSE).Info("Response generated", "usage", reqCtx.Usage)
50+
logger.V(logutil.VERBOSE).Info("Response generated", "usage", usage)
5451
}
55-
reqCtx.ResponseSize = len(responseBytes)
52+
reqCtx.ResponseSize = len(body)
5653
// ResponseComplete indicates that the response is complete. In the non-streaming
5754
// case, it is set to true once the response is processed; in the
5855
// streaming case, it is set to true once the last chunk is processed.
5956
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/178)
6057
// will add the processing for streaming case.
6158
reqCtx.ResponseComplete = true
6259

63-
reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true)
60+
reqCtx.respBodyResp = generateResponseBodyResponses(body, true)
6461

6562
return s.director.HandleResponseBodyComplete(ctx, reqCtx)
6663
}
6764

6865
// HandleResponseBodyModelStreaming handles the streaming response when the modelServer is streaming.
69-
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) {
66+
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, streamBody []byte) {
7067
logger := log.FromContext(ctx)
7168
_, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx, logger)
7269
if err != nil {
7370
logger.Error(err, "error in HandleResponseBodyStreaming")
7471
}
75-
if strings.Contains(responseText, streamingEndMsg) {
72+
if bytes.Contains(streamBody, []byte(streamingEndMsg)) {
7673
reqCtx.ResponseComplete = true
77-
resp := parseRespForUsage(ctx, responseText)
78-
reqCtx.Usage = resp.Usage
79-
metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.PromptTokens)
80-
metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.CompletionTokens)
81-
_, err := s.director.HandleResponseBodyComplete(ctx, reqCtx)
74+
resp, err := types.NewLLMResponseFromStream(streamBody)
75+
if err != nil {
76+
logger.Error(err, "error in converting stream response to LLMResponse.")
77+
}
78+
if usage := resp.Usage(); usage != nil {
79+
reqCtx.Usage = usage
80+
metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, usage.PromptTokens)
81+
metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, usage.CompletionTokens)
82+
}
83+
_, err = s.director.HandleResponseBodyComplete(ctx, reqCtx)
8284
if err != nil {
8385
logger.Error(err, "error in HandleResponseBodyComplete")
8486
}

pkg/epp/handlers/response_test.go

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ package handlers
1818

1919
import (
2020
"context"
21-
"encoding/json"
2221
"testing"
2322

2423
"github.com/go-logr/logr"
2524
"github.com/google/go-cmp/cmp"
2625

2726
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
27+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
2828
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
2929
)
3030

@@ -53,12 +53,33 @@ const (
5353
}
5454
`
5555

56-
streamingBodyWithoutUsage = `data: {"id":"cmpl-41764c93-f9d2-4f31-be08-3ba04fa25394","object":"text_completion","created":1740002445,"model":"food-review-0","choices":[],"usage":null}
57-
`
56+
streamingBodyWithoutUsage = `
57+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"}}]}
5858
59-
streamingBodyWithUsage = `data: {"id":"cmpl-41764c93-f9d2-4f31-be08-3ba04fa25394","object":"text_completion","created":1740002445,"model":"food-review-0","choices":[],"usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}}
60-
data: [DONE]
61-
`
59+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"}}]}
60+
61+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" world"}}]}
62+
63+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
64+
65+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":null}
66+
67+
data: [DONE]
68+
`
69+
70+
streamingBodyWithUsage = `
71+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"}}]}
72+
73+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"}}]}
74+
75+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" world"}}]}
76+
77+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
78+
79+
data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":7,"total_tokens":12}}
80+
81+
data: [DONE]
82+
`
6283
)
6384

6485
type mockDirector struct{}
@@ -89,13 +110,13 @@ func TestHandleResponseBody(t *testing.T) {
89110
name string
90111
body []byte
91112
reqCtx *RequestContext
92-
want Usage
113+
want *types.Usage
93114
wantErr bool
94115
}{
95116
{
96117
name: "success",
97118
body: []byte(body),
98-
want: Usage{
119+
want: &types.Usage{
99120
PromptTokens: 11,
100121
TotalTokens: 111,
101122
CompletionTokens: 100,
@@ -111,12 +132,7 @@ func TestHandleResponseBody(t *testing.T) {
111132
if reqCtx == nil {
112133
reqCtx = &RequestContext{}
113134
}
114-
var responseMap map[string]any
115-
marshalErr := json.Unmarshal(test.body, &responseMap)
116-
if marshalErr != nil {
117-
t.Error(marshalErr, "Error unmarshaling request body")
118-
}
119-
_, err := server.HandleResponseBody(ctx, reqCtx, responseMap)
135+
_, err := server.HandleResponseBody(ctx, reqCtx, test.body)
120136
if err != nil {
121137
if !test.wantErr {
122138
t.Fatalf("HandleResponseBody returned unexpected error: %v, want %v", err, test.wantErr)
@@ -137,7 +153,7 @@ func TestHandleStreamedResponseBody(t *testing.T) {
137153
name string
138154
body string
139155
reqCtx *RequestContext
140-
want Usage
156+
want *types.Usage
141157
wantErr bool
142158
}{
143159
{
@@ -156,10 +172,10 @@ func TestHandleStreamedResponseBody(t *testing.T) {
156172
modelServerStreaming: true,
157173
},
158174
wantErr: false,
159-
want: Usage{
160-
PromptTokens: 7,
161-
TotalTokens: 17,
162-
CompletionTokens: 10,
175+
want: &types.Usage{
176+
PromptTokens: 5,
177+
TotalTokens: 12,
178+
CompletionTokens: 7,
163179
},
164180
},
165181
}
@@ -172,7 +188,7 @@ func TestHandleStreamedResponseBody(t *testing.T) {
172188
if reqCtx == nil {
173189
reqCtx = &RequestContext{}
174190
}
175-
server.HandleResponseBodyModelStreaming(ctx, reqCtx, test.body)
191+
server.HandleResponseBodyModelStreaming(ctx, reqCtx, []byte(test.body))
176192

177193
if diff := cmp.Diff(test.want, reqCtx.Usage); diff != "" {
178194
t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff)

pkg/epp/handlers/server.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,15 @@ type RequestContext struct {
8585
RequestReceivedTimestamp time.Time
8686
ResponseCompleteTimestamp time.Time
8787
RequestSize int
88-
Usage Usage
88+
Usage *schedulingtypes.Usage
8989
ResponseSize int
9090
ResponseComplete bool
9191
ResponseStatusCode string
9292
RequestRunning bool
9393
Request *Request
9494

95-
SchedulingRequest *schedulingtypes.LLMRequest
95+
SchedulingRequest *schedulingtypes.LLMRequest
96+
SchedulingResponse *schedulingtypes.LLMResponse
9697

9798
RequestState StreamRequestState
9899
modelServerStreaming bool
@@ -115,7 +116,6 @@ type Request struct {
115116
}
116117
type Response struct {
117118
Headers map[string]string
118-
Body []byte
119119
}
120120
type StreamRequestState int
121121

@@ -268,11 +268,10 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
268268
reqCtx.respHeaderResp = s.generateResponseHeaderResponse(reqCtx)
269269

270270
case *extProcPb.ProcessingRequest_ResponseBody:
271+
body = append(body, v.ResponseBody.Body...)
271272
if reqCtx.modelServerStreaming {
272273
// Currently we punt on response parsing if the modelServer is streaming, and we just passthrough.
273-
274-
responseText := string(v.ResponseBody.Body)
275-
s.HandleResponseBodyModelStreaming(ctx, reqCtx, responseText)
274+
s.HandleResponseBodyModelStreaming(ctx, reqCtx, body)
276275
if v.ResponseBody.EndOfStream {
277276
loggerTrace.Info("stream completed")
278277

@@ -283,8 +282,6 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
283282

284283
reqCtx.respBodyResp = generateResponseBodyResponses(v.ResponseBody.Body, v.ResponseBody.EndOfStream)
285284
} else {
286-
body = append(body, v.ResponseBody.Body...)
287-
288285
// Message is buffered, we can read and decode.
289286
if v.ResponseBody.EndOfStream {
290287
loggerTrace.Info("stream completed")
@@ -303,8 +300,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
303300
break
304301
}
305302

306-
reqCtx.Response.Body = body
307-
reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, responseBody)
303+
reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, body)
308304
if responseErr != nil {
309305
if logger.V(logutil.DEBUG).Enabled() {
310306
logger.V(logutil.DEBUG).Error(responseErr, "Failed to process response body", "request", req)

pkg/epp/requestcontrol/director.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,11 @@ func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handl
309309
requestID := reqCtx.Request.Headers[requtil.RequestIdHeaderKey]
310310
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk", requtil.RequestIdHeaderKey, requestID)
311311
logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete")
312-
llmResponse, err := schedulingtypes.NewLLMResponseFromBytes(reqCtx.Response.Body)
313-
if err != nil {
314-
logger.Error(err, "HandleResponseBodyComplete: failed to convert the response to LLMResponse.")
312+
if reqCtx.SchedulingResponse == nil {
313+
err := fmt.Errorf("nil scheduling reponse from reqCtx")
315314
return reqCtx, err
316315
}
317-
d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, llmResponse, reqCtx.TargetPod)
316+
d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, reqCtx.SchedulingResponse, reqCtx.TargetPod)
318317

319318
logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")
320319
return reqCtx, nil

pkg/epp/requestcontrol/director_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -704,9 +704,9 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
704704
},
705705
Response: &handlers.Response{
706706
Headers: map[string]string{"X-Test-Complete-Header": "CompleteValue"},
707-
Body: []byte(chatCompletionJSON),
708707
},
709-
TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
708+
SchedulingResponse: wantLLMResponse,
709+
TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
710710
}
711711

712712
_, err = director.HandleResponseBodyComplete(ctx, reqCtx)

0 commit comments

Comments
 (0)