Skip to content

Commit ae26313

Browse files
committed
implements model rewrite and traffic splitting.
1 parent a3b4528 commit ae26313

File tree

2 files changed

+395
-30
lines changed

2 files changed

+395
-30
lines changed

pkg/epp/requestcontrol/director.go

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"fmt"
2424
"math/rand"
2525
"net"
26+
"sort"
2627
"strings"
2728
"time"
2829

@@ -52,6 +53,7 @@ type Datastore interface {
5253
PoolGet() (*v1.InferencePool, error)
5354
ObjectiveGet(objectiveName string) *v1alpha2.InferenceObjective
5455
PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics
56+
RewriteGetAll() []*v1alpha2.InferenceModelRewrite
5557
}
5658

5759
// Scheduler defines the interface required by the Director for scheduling.
@@ -112,34 +114,28 @@ func (d *Director) getInferenceObjective(ctx context.Context, reqCtx *handlers.R
112114
return infObjective
113115
}
114116

115-
// resolveTargetModel is a helper to update reqCtx with target model based on request.
116-
func (d *Director) resolveTargetModel(reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
117+
// HandleRequest orchestrates the request lifecycle.
118+
// It always returns the requestContext even in the error case, as the request context is used in error handling.
119+
func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
120+
logger := log.FromContext(ctx)
121+
122+
// Parse Request, Resolve Target Models, and Determine Parameters
117123
requestBodyMap := reqCtx.Request.Body
118124
var ok bool
119125
reqCtx.IncomingModelName, ok = requestBodyMap["model"].(string)
126+
120127
if !ok {
121128
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request body"}
122129
}
123130
if reqCtx.TargetModelName == "" {
124131
// Default to incoming model name
125132
reqCtx.TargetModelName = reqCtx.IncomingModelName
126133
}
127-
reqCtx.Request.Body["model"] = reqCtx.TargetModelName
128-
return reqCtx, nil
129-
}
130134

131-
// HandleRequest orchestrates the request lifecycle.
132-
// It always returns the requestContext even in the error case, as the request context is used in error handling.
133-
func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
134-
logger := log.FromContext(ctx)
135+
d.applyWeightedModelRewrite(reqCtx)
135136

136-
// Resolve target model and update req context.
137-
reqCtx, err := d.resolveTargetModel(reqCtx)
138-
if err != nil {
139-
return reqCtx, err
140-
}
137+
reqCtx.Request.Body["model"] = reqCtx.TargetModelName
141138

142-
// Parse request body.
143139
requestBody, err := requtil.ExtractRequestBody(reqCtx.Request.Body)
144140
if err != nil {
145141
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Errorf("failed to extract request data: %w", err).Error()}
@@ -200,6 +196,56 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
200196
return reqCtx, nil
201197
}
202198

199+
func (d *Director) applyWeightedModelRewrite(reqCtx *handlers.RequestContext) {
200+
rewrites := d.datastore.RewriteGetAll()
201+
if len(rewrites) == 0 {
202+
return
203+
}
204+
205+
sort.Slice(rewrites, func(i, j int) bool {
206+
return rewrites[i].CreationTimestamp.Before(&rewrites[j].CreationTimestamp)
207+
})
208+
209+
for _, rewrite := range rewrites {
210+
for _, rule := range rewrite.Spec.Rules {
211+
for _, match := range rule.Matches {
212+
if match.Model != nil && match.Model.Value == reqCtx.IncomingModelName {
213+
reqCtx.TargetModelName = d.selectWeightedModel(rule.Targets)
214+
return
215+
}
216+
}
217+
}
218+
}
219+
}
220+
221+
func (d *Director) selectWeightedModel(models []v1alpha2.TargetModel) string {
222+
if len(models) == 0 {
223+
return ""
224+
}
225+
226+
var totalWeight int32
227+
for _, model := range models {
228+
totalWeight += model.Weight
229+
}
230+
231+
if totalWeight == 0 {
232+
// If total weight is 0, distribute evenly
233+
return models[rand.Intn(len(models))].ModelRewrite
234+
}
235+
236+
randomNum := rand.Intn(int(totalWeight))
237+
var currentWeight int32
238+
for _, model := range models {
239+
currentWeight += model.Weight
240+
if randomNum < int(currentWeight) {
241+
return model.ModelRewrite
242+
}
243+
}
244+
245+
// Should not happen
246+
return models[len(models)-1].ModelRewrite
247+
}
248+
203249
// getCandidatePodsForScheduling gets the list of relevant endpoints for the scheduling cycle from the datastore.
204250
// according to EPP protocol, if "x-gateway-destination-endpoint-subset" is set on the request metadata and specifies
205251
// a subset of endpoints, only these endpoints will be considered as candidates for the scheduler.

0 commit comments

Comments
 (0)