77from  dataclasses  import  dataclass 
88from  inspect  import  isawaitable 
99from  logging  import  Logger 
10- from  typing  import  Any , Awaitable , Callable , Self , TypeVar , cast 
10+ from  typing  import  Any , Awaitable , Callable , Dict ,  Optional ,  Self , TypeVar , Union ,  cast ,  overload 
1111
1212from  microsoft .teams .common .logging  import  ConsoleLogger 
1313from  pydantic  import  BaseModel 
1414
1515from  .ai_model  import  AIModel 
16- from  .function  import  Function , FunctionHandler 
16+ from  .function  import  Function , FunctionHandler ,  FunctionHandlers ,  FunctionHandlerWithNoParams 
1717from  .memory  import  Memory 
1818from  .message  import  DeferredMessage , FunctionMessage , Message , ModelMessage , SystemMessage , UserMessage 
1919from  .plugin  import  AIPluginProtocol 
@@ -70,17 +70,67 @@ def __init__(
7070        self .logger  =  logger  or  ConsoleLogger ().create_logger ("@teams/ai/chat_prompt" )
7171        self .instructions  =  instructions 
7272
73-     def  with_function (self , function : Function [T ]) ->  Self :
73+     @overload  
74+     def  with_function (self , function : Function [T ]) ->  Self : ...
75+ 
76+     @overload  
77+     def  with_function (
78+         self ,
79+         * ,
80+         name : str ,
81+         description : str ,
82+         parameter_schema : Union [type [T ], Dict [str , Any ]],
83+         handler : FunctionHandlers ,
84+     ) ->  Self : ...
85+ 
86+     @overload  
87+     def  with_function (
88+         self ,
89+         * ,
90+         name : str ,
91+         description : str ,
92+         handler : FunctionHandlerWithNoParams ,
93+     ) ->  Self : ...
94+ 
95+     def  with_function (
96+         self ,
97+         function : Function [T ] |  None  =  None ,
98+         * ,
99+         name : str  |  None  =  None ,
100+         description : str  |  None  =  None ,
101+         parameter_schema : Union [type [T ], Dict [str , Any ], None ] =  None ,
102+         handler : FunctionHandlers  |  None  =  None ,
103+     ) ->  Self :
74104        """ 
75105        Add a function to the available functions for this prompt. 
76106
107+         Can be called in three ways: 
108+         1. with_function(function=Function(...)) 
109+         2. with_function(name=..., description=..., parameter_schema=..., handler=...) 
110+         3. with_function(name=..., description=..., handler=...) - for functions with no parameters 
111+ 
77112        Args: 
78-             function: Function to add to the available functions 
113+             function: Function object to add (first overload) 
114+             name: Function name (second and third overload) 
115+             description: Function description (second and third overload) 
116+             parameter_schema: Function parameter schema (second overload, optional) 
117+             handler: Function handler (second and third overload) 
79118
80119        Returns: 
81120            Self for method chaining 
82121        """ 
83-         self .functions [function .name ] =  function 
122+         if  function  is  not   None :
123+             self .functions [function .name ] =  function 
124+         else :
125+             if  name  is  None  or  description  is  None  or  handler  is  None :
126+                 raise  ValueError ("When not providing a Function object, name, description, and handler are required" )
127+             func  =  Function [T ](
128+                 name = name ,
129+                 description = description ,
130+                 parameter_schema = parameter_schema ,
131+                 handler = handler ,
132+             )
133+             self .functions [func .name ] =  func 
84134        return  self 
85135
86136    def  with_plugin (self , plugin : AIPluginProtocol ) ->  Self :
@@ -259,9 +309,7 @@ async def on_chunk_fn(chunk: str):
259309
260310        return  ChatSendResult (response = current_response )
261311
262-     def  _wrap_function_handler (
263-         self , original_handler : FunctionHandler [BaseModel ], function_name : str 
264-     ) ->  FunctionHandler [BaseModel ]:
312+     def  _wrap_function_handler (self , original_handler : FunctionHandlers , function_name : str ) ->  FunctionHandlers :
265313        """ 
266314        Wrap a function handler with plugin before/after hooks. 
267315
@@ -276,20 +324,28 @@ def _wrap_function_handler(
276324            Wrapped handler that includes plugin hook execution 
277325        """ 
278326
279-         async  def  wrapped_handler (params : BaseModel ) ->  str :
327+         async  def  wrapped_handler (params : Optional [ BaseModel ] ) ->  str :
280328            # Run before function call hooks 
281329            for  plugin  in  self .plugins :
282330                await  plugin .on_before_function_call (function_name , params )
283331
284-             # Call the original function (could be sync or async) 
285-             result  =  original_handler (params )
286-             if  isawaitable (result ):
287-                 result  =  await  result 
332+             if  params :
333+                 # Call the original function with params (could be sync or async) 
334+                 casted_handler  =  cast (FunctionHandler [BaseModel ], original_handler )
335+                 result  =  casted_handler (params )
336+                 if  isawaitable (result ):
337+                     result  =  await  result 
338+             else :
339+                 # Function with no parameters case 
340+                 casted_handler  =  cast (FunctionHandlerWithNoParams , original_handler )
341+                 result  =  casted_handler ()
342+                 if  isawaitable (result ):
343+                     result  =  await  result 
288344
289345            # Run after function call hooks 
290346            current_result  =  result 
291347            for  plugin  in  self .plugins :
292-                 plugin_result  =  await  plugin .on_after_function_call (function_name , params ,  current_result )
348+                 plugin_result  =  await  plugin .on_after_function_call (function_name , current_result ,  params )
293349                if  plugin_result  is  not   None :
294350                    current_result  =  plugin_result 
295351
0 commit comments