@@ -64,6 +64,7 @@ export type AxGenerateResult<OUT extends AxGenOut> = OUT & {
6464
6565export interface AxResponseHandlerArgs < T > {
6666 ai : Readonly < AxAIService > ;
67+ model ?: string ;
6768 sig : Readonly < AxSignature > ;
6869 res : T ;
6970 usageInfo : { ai : string ; model : string } ;
@@ -110,10 +111,10 @@ export class AxGen<
110111 }
111112 }
112113
113- private updateSigForFunctions = ( ai : AxAIService ) => {
114+ private updateSigForFunctions = ( ai : AxAIService , model ?: string ) => {
114115 // AI supports function calling natively so
115116 // no need to add fields for function call
116- if ( ai . getFeatures ( ) . functions ) {
117+ if ( ai . getFeatures ( model ) . functions ) {
117118 return ;
118119 }
119120
@@ -239,6 +240,7 @@ export class AxGen<
239240 if ( res instanceof ReadableStream ) {
240241 return ( await this . processSteamingResponse ( {
241242 ai,
243+ model,
242244 sig,
243245 res,
244246 usageInfo,
@@ -250,6 +252,7 @@ export class AxGen<
250252
251253 return ( await this . processResponse ( {
252254 ai,
255+ model,
253256 sig,
254257 res,
255258 usageInfo,
@@ -262,6 +265,7 @@ export class AxGen<
262265 private async processSteamingResponse ( {
263266 ai,
264267 sig,
268+ model,
265269 res,
266270 usageInfo,
267271 mem,
@@ -313,7 +317,7 @@ export class AxGen<
313317 }
314318 }
315319
316- const funcs = parseFunctions ( ai , functionCalls , values ) ;
320+ const funcs = parseFunctions ( ai , functionCalls , values , model ) ;
317321 if ( funcs ) {
318322 await this . processFunctions ( ai , funcs , mem , sessionId , traceId ) ;
319323 }
@@ -372,7 +376,7 @@ export class AxGen<
372376 const maxRetries = options ?. maxRetries ?? this . options ?. maxRetries ?? 5 ;
373377 const maxSteps = options ?. maxSteps ?? this . options ?. maxSteps ?? 10 ;
374378 const mem = options ?. mem ?? this . options ?. mem ?? new AxMemory ( ) ;
375- const canStream = ai . getFeatures ( ) . streaming ;
379+ const canStream = ai . getFeatures ( options ?. model ) . streaming ;
376380
377381 let err : ValidationError | AxAssertionError | undefined ;
378382
@@ -453,7 +457,8 @@ export class AxGen<
453457 values : IN ,
454458 options ?: Readonly < AxProgramForwardOptions >
455459 ) : Promise < OUT > {
456- const sig = this . updateSigForFunctions ( ai ) ?? this . signature ;
460+ const sig =
461+ this . updateSigForFunctions ( ai , options ?. model ) ?? this . signature ;
457462
458463 const tracer = this . options ?. tracer ?? options ?. tracer ;
459464
@@ -515,12 +520,13 @@ export class AxGen<
515520function parseFunctions (
516521 ai : Readonly < AxAIService > ,
517522 functionCalls : Readonly < AxChatResponseResult [ 'functionCalls' ] > ,
518- values : Record < string , unknown >
523+ values : Record < string , unknown > ,
524+ model ?: string
519525) : AxChatResponseFunctionCall [ ] | undefined {
520526 if ( ! functionCalls || functionCalls . length === 0 ) {
521527 return ;
522528 }
523- if ( ai . getFeatures ( ) . functions ) {
529+ if ( ai . getFeatures ( model ) . functions ) {
524530 const funcs : AxChatResponseFunctionCall [ ] = functionCalls . map ( ( f ) => ( {
525531 id : f . id ,
526532 name : f . function . name ,
0 commit comments