diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index 1ab32d42b7..ab04b0223c 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -102,6 +102,7 @@ def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]: Currently handles: - Converting JSON dictionaries to Pydantic model instances where expected + - Converting lists of JSON dictionaries to lists of Pydantic model instances Future extensions could include: - Type coercion for other complex types @@ -129,8 +130,36 @@ def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]: if len(non_none_types) == 1: target_type = non_none_types[0] + # Check if the target type is a list + if get_origin(target_type) is list: + list_args = get_args(target_type) + if list_args: + element_type = list_args[0] + + # Check if the element type is a Pydantic model + if inspect.isclass(element_type) and issubclass( + element_type, pydantic.BaseModel + ): + # Skip conversion if the value is None + if args[param_name] is None: + continue + + # Convert list elements to Pydantic models + if isinstance(args[param_name], list): + converted_list = [] + for item in args[param_name]: + try: + converted_list.append(element_type.model_validate(item)) + except Exception as e: + # Skip items that fail validation + logger.warning( + f"Skipping item in '{param_name}': " + f'Failed to convert to {element_type.__name__}: {e}' + ) + converted_args[param_name] = converted_list + # Check if the target type is a Pydantic model - if inspect.isclass(target_type) and issubclass( + elif inspect.isclass(target_type) and issubclass( target_type, pydantic.BaseModel ): # Skip conversion if the value is None and the parameter is Optional diff --git a/tests/unittests/tools/test_function_tool_pydantic.py b/tests/unittests/tools/test_function_tool_pydantic.py index 1af5d68345..acacd5e8be 100644 --- a/tests/unittests/tools/test_function_tool_pydantic.py +++ b/tests/unittests/tools/test_function_tool_pydantic.py @@ -280,5 +280,121 @@ async def test_run_async_with_optional_pydantic_models(): assert result["theme"] == "dark" assert result["notifications"] is True assert result["preferences_type"] == "PreferencesModel" - assert result["preferences_type"] == "PreferencesModel" - assert result["preferences_type"] == "PreferencesModel" + + +def function_with_list_of_pydantic_models(users: list[UserModel]) -> dict: + """Function that takes a list of Pydantic models.""" + return { + "count": len(users), + "names": [user.name for user in users], + "ages": [user.age for user in users], + "types": [type(user).__name__ for user in users], + } + + +def function_with_optional_list_of_pydantic_models( + users: Optional[list[UserModel]] = None, +) -> dict: + """Function that takes an optional list of Pydantic models.""" + if users is None: + return {"count": 0, "names": []} + return { + "count": len(users), + "names": [user.name for user in users], + } + + +def test_preprocess_args_with_list_of_dicts_to_pydantic_models(): + """Test _preprocess_args converts list of dicts to list of Pydantic models.""" + tool = FunctionTool(function_with_list_of_pydantic_models) + + input_args = { + "users": [ + {"name": "Alice", "age": 30, "email": "alice@example.com"}, + {"name": "Bob", "age": 25}, + {"name": "Charlie", "age": 35, "email": "charlie@example.com"}, + ] + } + + processed_args = tool._preprocess_args(input_args) + + # Check that the list of dicts was converted to a list of Pydantic models + assert "users" in processed_args + users = processed_args["users"] + assert isinstance(users, list) + assert len(users) == 3 + + # Check each element is a Pydantic model with correct data + assert isinstance(users[0], UserModel) + assert users[0].name == "Alice" + assert users[0].age == 30 + assert users[0].email == "alice@example.com" + + assert isinstance(users[1], UserModel) + assert users[1].name == "Bob" + assert users[1].age == 25 + assert users[1].email is None + + assert isinstance(users[2], UserModel) + assert users[2].name == "Charlie" + assert users[2].age == 35 + assert users[2].email == "charlie@example.com" + + +def test_preprocess_args_with_optional_list_of_pydantic_models_none(): + """Test _preprocess_args handles None for optional list parameter.""" + tool = FunctionTool(function_with_optional_list_of_pydantic_models) + + input_args = {"users": None} + + processed_args = tool._preprocess_args(input_args) + + # Check that None is preserved + assert "users" in processed_args + assert processed_args["users"] is None + + +def test_preprocess_args_with_optional_list_of_pydantic_models_with_data(): + """Test _preprocess_args converts list for optional list parameter.""" + tool = FunctionTool(function_with_optional_list_of_pydantic_models) + + input_args = { + "users": [ + {"name": "Alice", "age": 30}, + {"name": "Bob", "age": 25}, + ] + } + + processed_args = tool._preprocess_args(input_args) + + # Check conversion + assert "users" in processed_args + users = processed_args["users"] + assert len(users) == 2 + assert all(isinstance(user, UserModel) for user in users) + assert users[0].name == "Alice" + assert users[1].name == "Bob" + + +def test_preprocess_args_with_list_skips_invalid_items(): + """Test _preprocess_args skips items that fail validation.""" + tool = FunctionTool(function_with_list_of_pydantic_models) + + input_args = { + "users": [ + {"name": "Alice", "age": 30}, + {"name": "Invalid"}, # Missing required 'age' field + {"name": "Bob", "age": 25}, + ] + } + + processed_args = tool._preprocess_args(input_args) + + # Check that invalid item was skipped + assert "users" in processed_args + users = processed_args["users"] + assert len(users) == 2 # Only 2 valid items + assert users[0].name == "Alice" + assert users[0].age == 30 + assert users[1].name == "Bob" + assert users[1].age == 25