1+ const fs = require ( "fs" ) ;
2+ const path = require ( "path" ) ;
3+ const { safeJsonParse } = require ( "../../http" ) ;
14const { NativeEmbedder } = require ( "../../EmbeddingEngines/native" ) ;
25const {
36 LLMPerformanceMonitor,
@@ -6,13 +9,16 @@ const {
69 handleDefaultStreamResponseV2,
710} = require ( "../../helpers/chat/responses" ) ;
811
9- function fireworksAiModels ( ) {
10- const { MODELS } = require ( "./models.js" ) ;
11- return MODELS || { } ;
12- }
// Absolute path of the on-disk FireworksAI model cache directory. Prefers
// STORAGE_DIR when configured, otherwise falls back to the repo-local
// storage folder relative to this file.
const cacheFolder = process.env.STORAGE_DIR
  ? path.resolve(process.env.STORAGE_DIR, "models", "fireworks")
  : path.resolve(__dirname, "../../../storage/models/fireworks");
1317
1418class FireworksAiLLM {
1519 constructor ( embedder = null , modelPreference = null ) {
20+ this . className = "FireworksAiLLM" ;
21+
1622 if ( ! process . env . FIREWORKS_AI_LLM_API_KEY )
1723 throw new Error ( "No FireworksAI API key was set." ) ;
1824 const { OpenAI : OpenAIApi } = require ( "openai" ) ;
@@ -29,6 +35,51 @@ class FireworksAiLLM {
2935
3036 this . embedder = ! embedder ? new NativeEmbedder ( ) : embedder ;
3137 this . defaultTemp = 0.7 ;
38+
39+ if ( ! fs . existsSync ( cacheFolder ) )
40+ fs . mkdirSync ( cacheFolder , { recursive : true } ) ;
41+ this . cacheModelPath = path . resolve ( cacheFolder , "models.json" ) ;
42+ this . cacheAtPath = path . resolve ( cacheFolder , ".cached_at" ) ;
43+ }
44+
45+ log ( text , ...args ) {
46+ console . log ( `\x1b[36m[${ this . className } ]\x1b[0m ${ text } ` , ...args ) ;
47+ }
48+
49+ // This checks if the .cached_at file has a timestamp that is more than 1Week (in millis)
50+ // from the current date. If it is, then we will refetch the API so that all the models are up
51+ // to date.
52+ #cacheIsStale( ) {
53+ const MAX_STALE = 6.048e8 ; // 1 Week in MS
54+ if ( ! fs . existsSync ( this . cacheAtPath ) ) return true ;
55+ const now = Number ( new Date ( ) ) ;
56+ const timestampMs = Number ( fs . readFileSync ( this . cacheAtPath ) ) ;
57+ return now - timestampMs > MAX_STALE ;
58+ }
59+
60+ // This function fetches the models from the ApiPie API and caches them locally.
61+ // We do this because the ApiPie API has a lot of models, and we need to get the proper token context window
62+ // for each model and this is a constructor property - so we can really only get it if this cache exists.
63+ // We used to have this as a chore, but given there is an API to get the info - this makes little sense.
64+ // This might slow down the first request, but we need the proper token context window
65+ // for each model and this is a constructor property - so we can really only get it if this cache exists.
66+ async #syncModels( ) {
67+ if ( fs . existsSync ( this . cacheModelPath ) && ! this . #cacheIsStale( ) )
68+ return false ;
69+
70+ this . log (
71+ "Model cache is not present or stale. Fetching from FireworksAI API."
72+ ) ;
73+ await fireworksAiModels ( ) ;
74+ return ;
75+ }
76+
77+ models ( ) {
78+ if ( ! fs . existsSync ( this . cacheModelPath ) ) return { } ;
79+ return safeJsonParse (
80+ fs . readFileSync ( this . cacheModelPath , { encoding : "utf-8" } ) ,
81+ { }
82+ ) ;
3283 }
3384
3485 #appendContext( contextTexts = [ ] ) {
@@ -43,28 +94,31 @@ class FireworksAiLLM {
4394 ) ;
4495 }
4596
46- allModelInformation ( ) {
47- return fireworksAiModels ( ) ;
48- }
49-
5097 streamingEnabled ( ) {
5198 return "streamGetChatCompletion" in this ;
5299 }
53100
54101 static promptWindowLimit ( modelName ) {
55- const availableModels = fireworksAiModels ( ) ;
102+ const cacheModelPath = path . resolve ( cacheFolder , "models.json" ) ;
103+ const availableModels = fs . existsSync ( cacheModelPath )
104+ ? safeJsonParse (
105+ fs . readFileSync ( cacheModelPath , { encoding : "utf-8" } ) ,
106+ { }
107+ )
108+ : { } ;
56109 return availableModels [ modelName ] ?. maxLength || 4096 ;
57110 }
58111
59112 // Ensure the user set a value for the token limit
60113 // and if undefined - assume 4096 window.
61114 promptWindowLimit ( ) {
62- const availableModels = this . allModelInformation ( ) ;
115+ const availableModels = this . models ( ) ;
63116 return availableModels [ this . model ] ?. maxLength || 4096 ;
64117 }
65118
66119 async isValidChatCompletionModel ( model = "" ) {
67- const availableModels = this . allModelInformation ( ) ;
120+ await this . #syncModels( ) ;
121+ const availableModels = this . models ( ) ;
68122 return availableModels . hasOwnProperty ( model ) ;
69123 }
70124
@@ -151,6 +205,63 @@ class FireworksAiLLM {
151205 }
152206}
153207
// Fetches the model catalog from the FireworksAI API and persists it to the
// local cache (models.json plus a .cached_at timestamp). Entries without a
// `context_length` field are not chat models and are skipped. Resolves to the
// map of valid models keyed by model id, or {} on any failure.
async function fireworksAiModels(providedApiKey = null) {
  const apiKey = providedApiKey || process.env.FIREWORKS_AI_LLM_API_KEY || null;
  const { OpenAI: OpenAIApi } = require("openai");
  const client = new OpenAIApi({
    baseURL: "https://api.fireworks.ai/inference/v1",
    apiKey: apiKey,
  });

  try {
    const response = await client.models.list();
    const { data: models = [] } = response;

    const validModels = {};
    for (const model of models) {
      // There are many models - the ones without a context length are not chat models
      if (!model.hasOwnProperty("context_length")) continue;

      validModels[model.id] = {
        id: model.id,
        name: model.id.split("/").pop(),
        organization: model.owned_by,
        subtype: model.type,
        maxLength: model.context_length ?? 4096,
      };
    }

    if (Object.keys(validModels).length === 0) {
      console.log("fireworksAi: No models found");
      return {};
    }

    // Persist the catalog and a fetch timestamp so later constructions can
    // read token windows from disk instead of re-hitting the API.
    if (!fs.existsSync(cacheFolder))
      fs.mkdirSync(cacheFolder, { recursive: true });
    fs.writeFileSync(
      path.resolve(cacheFolder, "models.json"),
      JSON.stringify(validModels),
      { encoding: "utf-8" }
    );
    fs.writeFileSync(
      path.resolve(cacheFolder, ".cached_at"),
      String(Number(new Date())),
      { encoding: "utf-8" }
    );

    return validModels;
  } catch (e) {
    console.error(e);
    return {};
  }
}
264+
// Export the connector class and the catalog fetcher; fireworksAiModels is
// also invoked internally to (re)build the on-disk model cache.
module.exports = {
  FireworksAiLLM,
  fireworksAiModels,
};