Skip to content

Commit 8f06b16

Browse files
committed
feat: added support for o1 models
1 parent 1348ce4 commit 8f06b16

File tree

6 files changed

+53
-21
lines changed

6 files changed

+53
-21
lines changed

src/ax/ai/balance.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ export class AxBalancer implements AxAIService {
8181
return this.currentService.getModelConfig();
8282
}
8383

84-
getFeatures() {
85-
return this.currentService.getFeatures();
84+
getFeatures(model?: string) {
85+
return this.currentService.getFeatures(model);
8686
}
8787

8888
async chat(

src/ax/ai/base.ts

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ export interface AxBaseAIArgs {
3737
modelInfo: Readonly<AxModelInfo[]>;
3838
models: Readonly<{ model: string; embedModel?: string }>;
3939
options?: Readonly<AxAIServiceOptions>;
40-
supportFor: AxBaseAIFeatures;
40+
supportFor: AxBaseAIFeatures | ((model: string) => AxBaseAIFeatures);
4141
modelMap?: AxAIModelMap;
4242
}
4343

@@ -95,7 +95,9 @@ export class AxBaseAI<
9595
protected apiURL: string;
9696
protected name: string;
9797
protected headers: Record<string, string>;
98-
protected supportFor: AxBaseAIFeatures;
98+
protected supportFor:
99+
| AxBaseAIFeatures
100+
| ((model: string) => AxBaseAIFeatures);
99101

100102
constructor({
101103
name,
@@ -197,8 +199,10 @@ export class AxBaseAI<
197199
return this.name;
198200
}
199201

200-
getFeatures(): AxBaseAIFeatures {
201-
return this.supportFor;
202+
getFeatures(model?: string): AxBaseAIFeatures {
203+
return typeof this.supportFor === 'function'
204+
? this.supportFor(model ?? this.models.model)
205+
: this.supportFor;
202206
}
203207

204208
getModelConfig(): AxModelConfig {
@@ -211,7 +215,7 @@ export class AxBaseAI<
211215
): Promise<AxChatResponse | ReadableStream<AxChatResponse>> {
212216
const model = req.model
213217
? this.modelMap?.[req.model] ?? req.model
214-
: this.models.model;
218+
: this.modelMap?.[this.models.model] ?? this.models.model;
215219

216220
if (this.tracer) {
217221
const mc = this.getModelConfig();
@@ -366,7 +370,7 @@ export class AxBaseAI<
366370
): Promise<AxEmbedResponse> {
367371
const embedModel = req.embedModel
368372
? this.modelMap?.[req.embedModel] ?? req.embedModel
369-
: this.models.embedModel;
373+
: this.modelMap?.[this.models.embedModel ?? ''] ?? this.models.embedModel;
370374

371375
if (!embedModel) {
372376
throw new Error('No embed model defined');

src/ax/ai/openai/api.ts

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import {
77
import type {
88
AxAIPromptConfig,
99
AxAIServiceOptions,
10-
AxChatRequest,
1110
AxChatResponse,
1211
AxChatResponseResult,
1312
AxEmbedResponse,
@@ -101,7 +100,11 @@ export class AxAIOpenAI extends AxBaseAI<
101100
embedModel: _config.embedModel as string
102101
},
103102
options,
104-
supportFor: { functions: true, streaming: true },
103+
supportFor: (model: string) => {
104+
return isO1Model(model)
105+
? { functions: false, streaming: false }
106+
: { functions: true, streaming: true };
107+
},
105108
modelMap
106109
});
107110
this.config = _config;
@@ -147,6 +150,10 @@ export class AxAIOpenAI extends AxBaseAI<
147150
}
148151
}));
149152

153+
if (tools && isO1Model(model)) {
154+
throw new Error('Functions are not supported for O1 models');
155+
}
156+
150157
const toolsChoice =
151158
!req.functionCall && req.functions && req.functions.length > 0
152159
? 'auto'
@@ -159,6 +166,10 @@ export class AxAIOpenAI extends AxBaseAI<
159166

160167
const stream = req.modelConfig?.stream ?? this.config.stream;
161168

169+
if (stream && isO1Model(model)) {
170+
throw new Error('Streaming is not supported for O1 models');
171+
}
172+
162173
const reqValue: AxAIOpenAIChatRequest = {
163174
model,
164175
messages,
@@ -355,9 +366,15 @@ const mapFinishReason = (
355366
};
356367

357368
function createMessages(
358-
req: Readonly<AxChatRequest>
369+
req: Readonly<AxInternalChatRequest>
359370
): AxAIOpenAIChatRequest['messages'] {
360371
return req.chatPrompt.map((msg) => {
372+
if (msg.role === 'system' && isO1Model(req.model)) {
373+
msg = {
374+
role: 'user',
375+
content: msg.content
376+
};
377+
}
361378
switch (msg.role) {
362379
case 'system':
363380
return { role: 'system' as const, content: msg.content };
@@ -412,3 +429,8 @@ function createMessages(
412429
}
413430
});
414431
}
432+
433+
const isO1Model = (model: string): boolean =>
434+
[AxAIOpenAIModel.O1Mini, AxAIOpenAIModel.O1Preview].includes(
435+
model as AxAIOpenAIModel
436+
);

src/ax/ai/types.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ export interface AxAIService {
185185
getModelInfo(): Readonly<AxModelInfoWithProvider>;
186186
getEmbedModelInfo(): Readonly<AxModelInfoWithProvider> | undefined;
187187
getModelConfig(): Readonly<AxModelConfig>;
188-
getFeatures(): { functions: boolean; streaming: boolean };
188+
getFeatures(model?: string): { functions: boolean; streaming: boolean };
189189
getModelMap(): AxAIModelMap | undefined;
190190

191191
chat(

src/ax/ai/wrap.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ export class AxAI implements AxAIService {
142142
return this.ai.getModelConfig();
143143
}
144144

145-
getFeatures(): { functions: boolean; streaming: boolean } {
146-
return this.ai.getFeatures();
145+
getFeatures(model?: string): { functions: boolean; streaming: boolean } {
146+
return this.ai.getFeatures(model);
147147
}
148148

149149
getModelMap(): AxAIModelMap | undefined {

src/ax/dsp/generate.ts

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ export type AxGenerateResult<OUT extends AxGenOut> = OUT & {
6464

6565
export interface AxResponseHandlerArgs<T> {
6666
ai: Readonly<AxAIService>;
67+
model?: string;
6768
sig: Readonly<AxSignature>;
6869
res: T;
6970
usageInfo: { ai: string; model: string };
@@ -110,10 +111,10 @@ export class AxGen<
110111
}
111112
}
112113

113-
private updateSigForFunctions = (ai: AxAIService) => {
114+
private updateSigForFunctions = (ai: AxAIService, model?: string) => {
114115
// AI supports function calling natively so
115116
// no need to add fields for function call
116-
if (ai.getFeatures().functions) {
117+
if (ai.getFeatures(model).functions) {
117118
return;
118119
}
119120

@@ -239,6 +240,7 @@ export class AxGen<
239240
if (res instanceof ReadableStream) {
240241
return (await this.processSteamingResponse({
241242
ai,
243+
model,
242244
sig,
243245
res,
244246
usageInfo,
@@ -250,6 +252,7 @@ export class AxGen<
250252

251253
return (await this.processResponse({
252254
ai,
255+
model,
253256
sig,
254257
res,
255258
usageInfo,
@@ -262,6 +265,7 @@ export class AxGen<
262265
private async processSteamingResponse({
263266
ai,
264267
sig,
268+
model,
265269
res,
266270
usageInfo,
267271
mem,
@@ -313,7 +317,7 @@ export class AxGen<
313317
}
314318
}
315319

316-
const funcs = parseFunctions(ai, functionCalls, values);
320+
const funcs = parseFunctions(ai, functionCalls, values, model);
317321
if (funcs) {
318322
await this.processFunctions(ai, funcs, mem, sessionId, traceId);
319323
}
@@ -372,7 +376,7 @@ export class AxGen<
372376
const maxRetries = options?.maxRetries ?? this.options?.maxRetries ?? 5;
373377
const maxSteps = options?.maxSteps ?? this.options?.maxSteps ?? 10;
374378
const mem = options?.mem ?? this.options?.mem ?? new AxMemory();
375-
const canStream = ai.getFeatures().streaming;
379+
const canStream = ai.getFeatures(options?.model).streaming;
376380

377381
let err: ValidationError | AxAssertionError | undefined;
378382

@@ -453,7 +457,8 @@ export class AxGen<
453457
values: IN,
454458
options?: Readonly<AxProgramForwardOptions>
455459
): Promise<OUT> {
456-
const sig = this.updateSigForFunctions(ai) ?? this.signature;
460+
const sig =
461+
this.updateSigForFunctions(ai, options?.model) ?? this.signature;
457462

458463
const tracer = this.options?.tracer ?? options?.tracer;
459464

@@ -515,12 +520,13 @@ export class AxGen<
515520
function parseFunctions(
516521
ai: Readonly<AxAIService>,
517522
functionCalls: Readonly<AxChatResponseResult['functionCalls']>,
518-
values: Record<string, unknown>
523+
values: Record<string, unknown>,
524+
model?: string
519525
): AxChatResponseFunctionCall[] | undefined {
520526
if (!functionCalls || functionCalls.length === 0) {
521527
return;
522528
}
523-
if (ai.getFeatures().functions) {
529+
if (ai.getFeatures(model).functions) {
524530
const funcs: AxChatResponseFunctionCall[] = functionCalls.map((f) => ({
525531
id: f.id,
526532
name: f.function.name,

0 commit comments

Comments
 (0)