@@ -433,7 +433,7 @@ def __init__(
433433 min_retry_wait_time = min_retry_wait_time ,
434434 client_class = OpenAI ,
435435 client_args = client_args ,
436- pricing_func = tracking .get_pricing_openai ,
436+ pricing_func = tracking .partial ( tracking . get_pricing_litellm , model_name = model_name ) ,
437437 log_probs = log_probs ,
438438 )
439439
@@ -492,6 +492,7 @@ def __init__(
492492 temperature = 0.5 ,
493493 max_tokens = 100 ,
494494 max_retry = 4 ,
495+ pricing_func = None ,
495496 ):
496497 self .model_name = model_name
497498 self .temperature = temperature
@@ -501,6 +502,22 @@ def __init__(
501502 api_key = api_key or os .getenv ("ANTHROPIC_API_KEY" )
502503 self .client = anthropic .Anthropic (api_key = api_key )
503504
505+ # Get pricing information
506+ if pricing_func :
507+ pricings = pricing_func ()
508+ try :
509+ self .input_cost = float (pricings [model_name ]["prompt" ])
510+ self .output_cost = float (pricings [model_name ]["completion" ])
511+ except KeyError :
512+ logging .warning (
513+ f"Model { model_name } not found in the pricing information, prices are set to 0. Maybe try upgrading langchain_community."
514+ )
515+ self .input_cost = 0.0
516+ self .output_cost = 0.0
517+ else :
518+ self .input_cost = 0.0
519+ self .output_cost = 0.0
520+
504521 def __call__ (self , messages : list [dict ], n_samples : int = 1 , temperature : float = None ) -> dict :
505522 # Convert OpenAI format to Anthropic format
506523 system_message = None
@@ -528,13 +545,29 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
528545
529546 response = self .client .messages .create (** kwargs )
530547
548+ usage = getattr (response , "usage" , {})
549+ new_input_tokens = getattr (usage , "input_tokens" , 0 )
550+ output_tokens = getattr (usage , "output_tokens" , 0 )
551+ cache_read_tokens = getattr (usage , "cache_input_tokens" , 0 )
552+ cache_write_tokens = getattr (usage , "cache_creation_input_tokens" , 0 )
553+ cache_read_cost = (
554+ self .input_cost * tracking .ANTHROPIC_CACHE_PRICING_FACTOR ["cache_read_tokens" ]
555+ )
556+ cache_write_cost = (
557+ self .input_cost * tracking .ANTHROPIC_CACHE_PRICING_FACTOR ["cache_write_tokens" ]
558+ )
559+ cost = (
560+ new_input_tokens * self .input_cost
561+ + output_tokens * self .output_cost
562+ + cache_read_tokens * cache_read_cost
563+ + cache_write_tokens * cache_write_cost
564+ )
565+
531566 # Track usage if available
532- if hasattr (tracking .TRACKER , "instance" ):
533- tracking .TRACKER .instance (
534- response .usage .input_tokens ,
535- response .usage .output_tokens ,
536- 0 , # cost calculation would need pricing info
537- )
567+ if hasattr (tracking .TRACKER , "instance" ) and isinstance (
568+ tracking .TRACKER .instance , tracking .LLMTracker
569+ ):
570+ tracking .TRACKER .instance (new_input_tokens , output_tokens , cost )
538571
539572 return AIMessage (response .content [0 ].text )
540573
@@ -552,6 +585,7 @@ def make_model(self):
552585 model_name = self .model_name ,
553586 temperature = self .temperature ,
554587 max_tokens = self .max_new_tokens ,
588+ pricing_func = partial (tracking .get_pricing_litellm , model_name = self .model_name ),
555589 )
556590
557591
0 commit comments