From f1f18530bc9d94262d0dd813894d7bb1ac68c428 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=89=AC=E5=A4=9C?= Date: Mon, 19 Aug 2024 11:38:15 +0800 Subject: [PATCH 1/4] adopt to tool_calls format --- sgpt/handlers/handler.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/sgpt/handlers/handler.py b/sgpt/handlers/handler.py index c6302132..cad9f22a 100644 --- a/sgpt/handlers/handler.py +++ b/sgpt/handlers/handler.py @@ -58,6 +58,7 @@ def make_messages(self, prompt: str) -> List[Dict[str, str]]: def handle_function_call( self, messages: List[dict[str, Any]], + id: str, name: str, arguments: str, ) -> Generator[str, None, None]: @@ -65,7 +66,13 @@ def handle_function_call( { "role": "assistant", "content": "", - "function_call": {"name": name, "arguments": arguments}, + "tool_calls": [ + { + "id": id, + "type": "function", + "function": {"name": name, "arguments": arguments}, + } + ], } ) @@ -79,7 +86,9 @@ def handle_function_call( result = get_function(name)(**dict_args) if cfg.get("SHOW_FUNCTIONS_OUTPUT") == "true": yield f"```text\n{result}\n```\n" - messages.append({"role": "function", "content": result, "name": name}) + messages.append( + {"role": "tool", "content": result, "tool_call_id": id, "name": name} + ) @cache def get_completion( @@ -121,12 +130,14 @@ def get_completion( ) if tool_calls: for tool_call in tool_calls: + if tool_call.id: + id = tool_call.id if tool_call.function.name: name = tool_call.function.name if tool_call.function.arguments: arguments += tool_call.function.arguments if chunk.choices[0].finish_reason == "tool_calls": - yield from self.handle_function_call(messages, name, arguments) + yield from self.handle_function_call(messages, id, name, arguments) yield from self.get_completion( model=model, temperature=temperature, From e2ce0bac72575633669b72352ea242c233d91bd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=89=AC=E5=A4=9C?= Date: Wed, 21 Aug 2024 14:08:08 +0800 Subject: [PATCH 2/4] Update import statement in __main__.py --- sgpt/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sgpt/__main__.py b/sgpt/__main__.py index 1c7bc113..1e1c9e8e 100644 --- a/sgpt/__main__.py +++ b/sgpt/__main__.py @@ -1,3 +1,3 @@ -from .app import entry_point +from sgpt.app import entry_point entry_point() From 1c3f4061581e499b9a8c1d643453fb8063af2133 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=89=AC=E5=A4=9C?= Date: Thu, 29 Aug 2024 16:17:24 +0800 Subject: [PATCH 3/4] add parallel function call feature --- sgpt/handlers/handler.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/sgpt/handlers/handler.py b/sgpt/handlers/handler.py index cad9f22a..aa189bb6 100644 --- a/sgpt/handlers/handler.py +++ b/sgpt/handlers/handler.py @@ -121,6 +121,8 @@ def get_completion( ) try: + arguments_map = {} + name_map = {} for chunk in response: delta = chunk.choices[0].delta @@ -128,16 +130,22 @@ def get_completion( tool_calls = ( delta.get("tool_calls") if use_litellm else delta.tool_calls ) + if tool_calls: for tool_call in tool_calls: - if tool_call.id: + if tool_call.id is not None: id = tool_call.id - if tool_call.function.name: - name = tool_call.function.name - if tool_call.function.arguments: - arguments += tool_call.function.arguments + name_map[id] = tool_call.function.name + arguments_map[id] = "" + else: + arguments_map[id] += tool_call.function.arguments if chunk.choices[0].finish_reason == "tool_calls": - yield from self.handle_function_call(messages, id, name, arguments) + for id, name, arguments in zip( + name_map.keys(), name_map.values(), arguments_map.values() + ): + yield from self.handle_function_call( + messages, id, name, arguments + ) yield from self.get_completion( model=model, temperature=temperature, From a3f8f842740e11f53937b5386db60b212833b953 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=89=AC=E5=A4=9C?= Date: Thu, 29 Aug 2024 16:55:42 +0800 Subject: [PATCH 4/4] handel parallel function calls in handle_function_call --- sgpt/handlers/handler.py | 59 ++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/sgpt/handlers/handler.py b/sgpt/handlers/handler.py index aa189bb6..04e36073 100644 --- a/sgpt/handlers/handler.py +++ b/sgpt/handlers/handler.py @@ -58,37 +58,46 @@ def make_messages(self, prompt: str) -> List[Dict[str, str]]: def handle_function_call( self, messages: List[dict[str, Any]], - id: str, - name: str, - arguments: str, + name_map, + arguments_map, ) -> Generator[str, None, None]: + all_tool_calls = [ + { + "id": id, + "type": "function", + "function": {"name": name, "arguments": arguments}, + } + for id, name, arguments in zip( + name_map.keys(), name_map.values(), arguments_map.values() + ) + ] messages.append( { "role": "assistant", "content": "", - "tool_calls": [ - { - "id": id, - "type": "function", - "function": {"name": name, "arguments": arguments}, - } - ], + "tool_calls": all_tool_calls, } ) if messages and messages[-1]["role"] == "assistant": yield "\n" - dict_args = json.loads(arguments) - joined_args = ", ".join(f'{k}="{v}"' for k, v in dict_args.items()) - yield f"> @FunctionCall `{name}({joined_args})` \n\n" + all_function_res_msgs = [] + for id, name, arguments in zip( + name_map.keys(), name_map.values(), arguments_map.values() + ): + dict_args = json.loads(arguments) + joined_args = ", ".join(f'{k}="{v}"' for k, v in dict_args.items()) + yield f"> @FunctionCall `{name}({joined_args})` \n\n" - result = get_function(name)(**dict_args) - if cfg.get("SHOW_FUNCTIONS_OUTPUT") == "true": - yield f"```text\n{result}\n```\n" - messages.append( - {"role": "tool", "content": result, "tool_call_id": id, "name": name} - ) + result = get_function(name)(**dict_args) + if cfg.get("SHOW_FUNCTIONS_OUTPUT") == "true": + yield f"```text\n{result}\n```\n" + + all_function_res_msgs.append( + {"role": "tool", "content": result, "tool_call_id": id} + ) + messages += all_function_res_msgs @cache def get_completion( @@ -99,7 +108,6 @@ def get_completion( messages: List[Dict[str, Any]], functions: Optional[List[Dict[str, str]]], ) -> Generator[str, None, None]: - name = arguments = "" is_shell_role = self.role.name == DefaultRoles.SHELL.value is_code_role = self.role.name == DefaultRoles.CODE.value is_dsc_shell_role = self.role.name == DefaultRoles.DESCRIBE_SHELL.value @@ -140,12 +148,11 @@ def get_completion( else: arguments_map[id] += tool_call.function.arguments if chunk.choices[0].finish_reason == "tool_calls": - for id, name, arguments in zip( - name_map.keys(), name_map.values(), arguments_map.values() - ): - yield from self.handle_function_call( - messages, id, name, arguments - ) + yield from self.handle_function_call( + messages, + name_map, + arguments_map, + ) yield from self.get_completion( model=model, temperature=temperature,