@@ -23,6 +23,7 @@ import (
2323 "fmt"
2424 "math/rand"
2525 "net"
26+ "sort"
2627 "strings"
2728 "time"
2829
@@ -52,6 +53,7 @@ type Datastore interface {
5253 PoolGet () (* v1.InferencePool , error )
5354 ObjectiveGet (objectiveName string ) * v1alpha2.InferenceObjective
5455 PodList (predicate func (backendmetrics.PodMetrics ) bool ) []backendmetrics.PodMetrics
56+ RewriteGetAll () []* v1alpha2.InferenceModelRewrite
5557}
5658
5759// Scheduler defines the interface required by the Director for scheduling.
@@ -112,34 +114,28 @@ func (d *Director) getInferenceObjective(ctx context.Context, reqCtx *handlers.R
112114 return infObjective
113115}
114116
115- // resolveTargetModel is a helper to update reqCtx with target model based on request.
116- func (d * Director ) resolveTargetModel (reqCtx * handlers.RequestContext ) (* handlers.RequestContext , error ) {
117+ // HandleRequest orchestrates the request lifecycle.
118+ // It always returns the requestContext even in the error case, as the request context is used in error handling.
119+ func (d * Director ) HandleRequest (ctx context.Context , reqCtx * handlers.RequestContext ) (* handlers.RequestContext , error ) {
120+ logger := log .FromContext (ctx )
121+
122+ // Parse Request, Resolve Target Models, and Determine Parameters
117123 requestBodyMap := reqCtx .Request .Body
118124 var ok bool
119125 reqCtx .IncomingModelName , ok = requestBodyMap ["model" ].(string )
126+
120127 if ! ok {
121128 return reqCtx , errutil.Error {Code : errutil .BadRequest , Msg : "model not found in request body" }
122129 }
123130 if reqCtx .TargetModelName == "" {
124131 // Default to incoming model name
125132 reqCtx .TargetModelName = reqCtx .IncomingModelName
126133 }
127- reqCtx .Request .Body ["model" ] = reqCtx .TargetModelName
128- return reqCtx , nil
129- }
130134
131- // HandleRequest orchestrates the request lifecycle.
132- // It always returns the requestContext even in the error case, as the request context is used in error handling.
133- func (d * Director ) HandleRequest (ctx context.Context , reqCtx * handlers.RequestContext ) (* handlers.RequestContext , error ) {
134- logger := log .FromContext (ctx )
135+ d .applyWeightedModelRewrite (reqCtx )
135136
136- // Resolve target model and update req context.
137- reqCtx , err := d .resolveTargetModel (reqCtx )
138- if err != nil {
139- return reqCtx , err
140- }
137+ reqCtx .Request .Body ["model" ] = reqCtx .TargetModelName
141138
142- // Parse request body.
143139 requestBody , err := requtil .ExtractRequestBody (reqCtx .Request .Body )
144140 if err != nil {
145141 return reqCtx , errutil.Error {Code : errutil .BadRequest , Msg : fmt .Errorf ("failed to extract request data: %w" , err ).Error ()}
@@ -200,6 +196,56 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
200196 return reqCtx , nil
201197}
202198
199+ func (d * Director ) applyWeightedModelRewrite (reqCtx * handlers.RequestContext ) {
200+ rewrites := d .datastore .RewriteGetAll ()
201+ if len (rewrites ) == 0 {
202+ return
203+ }
204+
205+ sort .Slice (rewrites , func (i , j int ) bool {
206+ return rewrites [i ].CreationTimestamp .Before (& rewrites [j ].CreationTimestamp )
207+ })
208+
209+ for _ , rewrite := range rewrites {
210+ for _ , rule := range rewrite .Spec .Rules {
211+ for _ , match := range rule .Matches {
212+ if match .Model != nil && match .Model .Value == reqCtx .IncomingModelName {
213+ reqCtx .TargetModelName = d .selectWeightedModel (rule .Targets )
214+ return
215+ }
216+ }
217+ }
218+ }
219+ }
220+
221+ func (d * Director ) selectWeightedModel (models []v1alpha2.TargetModel ) string {
222+ if len (models ) == 0 {
223+ return ""
224+ }
225+
226+ var totalWeight int32
227+ for _ , model := range models {
228+ totalWeight += model .Weight
229+ }
230+
231+ if totalWeight == 0 {
232+ // If total weight is 0, distribute evenly
233+ return models [rand .Intn (len (models ))].ModelRewrite
234+ }
235+
236+ randomNum := rand .Intn (int (totalWeight ))
237+ var currentWeight int32
238+ for _ , model := range models {
239+ currentWeight += model .Weight
240+ if randomNum < int (currentWeight ) {
241+ return model .ModelRewrite
242+ }
243+ }
244+
245+ // Should not happen
246+ return models [len (models )- 1 ].ModelRewrite
247+ }
248+
203249// getCandidatePodsForScheduling gets the list of relevant endpoints for the scheduling cycle from the datastore.
204250// according to EPP protocol, if "x-gateway-destination-endpoint-subset" is set on the request metadata and specifies
205251// a subset of endpoints, only these endpoints will be considered as candidates for the scheduler.
0 commit comments