|
| 1 | +""" |
| 2 | +Helper utilities for constructing and validating OpenAI tool calling requests. |
| 3 | +
|
| 4 | +This module provides functions and constants for testing OpenAI-compatible |
| 5 | +tool calling functionality with Clarifai models. |
| 6 | +""" |
| 7 | + |
| 8 | +import json |
| 9 | +import os |
| 10 | +import time |
| 11 | + |
| 12 | +from openai import APIConnectionError, APITimeoutError, OpenAI, RateLimitError |
| 13 | + |
| 14 | +from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel |
| 15 | +from clarifai_grpc.channel.http_client import CLIENT_VERSION |
| 16 | +from clarifai_grpc.grpc.api import service_pb2, service_pb2_grpc |
| 17 | +from clarifai_grpc.grpc.api.status import status_code_pb2 |
| 18 | +from tests.common import metadata |
| 19 | + |
# Maximum retry attempts for API calls that fail with transient errors
# (connection / timeout / rate-limit); see call_openai_tool_calling.
MAX_RETRY_ATTEMPTS = 3

# Parameter combinations to test.
# For now, we only test non-streaming with tool calling: every config has
# stream=False, crossed with tool_choice in {"required", "auto"} and
# strict in {True, False}.
TOOL_CALLING_CONFIGS = [
    {"stream": False, "tool_choice": "required", "strict": True},
    {"stream": False, "tool_choice": "required", "strict": False},
    {"stream": False, "tool_choice": "auto", "strict": True},
    {"stream": False, "tool_choice": "auto", "strict": False},
]

# Tool definition for weather query, in OpenAI function-calling format.
# The schema lists every property in "required" and sets
# additionalProperties=False, which the OpenAI structured-outputs docs
# mandate for strict mode, so the same definition works for both
# strict=True and strict=False configs.
WEATHER_TOOL = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "additionalProperties": False,
            "required": ["location", "unit"],
        },
    },
}
| 52 | + |
| 53 | + |
def call_openai_tool_calling(model_url, config):
    """
    Call OpenAI-compatible endpoint with tool calling using the specified configuration.
    Includes retry mechanism for transient errors.

    Args:
        model_url: Full URL to the Clarifai model
        config: Dictionary with stream, tool_choice, and strict parameters

    Returns:
        tuple: (response, error) where error is None on success. For
        streaming requests, response is the list of consumed chunks;
        otherwise it is the completion object. On failure, response is
        None and error is the stringified last exception.
    """
    channel = ClarifaiChannel.get_grpc_channel()
    client = OpenAI(
        api_key=os.environ.get('CLARIFAI_PAT_KEY'),
        base_url=f"https://{channel._target}/v2/ext/openai/v1",
        default_headers={"X-Clarifai-Request-Id-Prefix": f"python-openai-{CLIENT_VERSION}"},
        timeout=10  # 10 seconds timeout to avoid hanging
    )

    # Build tool definition with the per-config strict flag.
    # BUG FIX: WEATHER_TOOL.copy() is a shallow copy, so assigning
    # tool["function"]["strict"] mutated the shared module-level constant's
    # nested "function" dict on every call (the flag leaked across configs).
    # Copy the nested dict too so WEATHER_TOOL stays pristine.
    tool = {**WEATHER_TOOL, "function": {**WEATHER_TOOL["function"], "strict": config["strict"]}}

    # Build request parameters
    request_params = {
        "model": model_url,
        "messages": [
            {"role": "user", "content": "What is the weather like in Boston in fahrenheit?"}
        ],
        "temperature": 1,
        "top_p": 1,
        "max_tokens": 32768,
        "stream": config["stream"],
        "tool_choice": config["tool_choice"],
        "tools": [tool],
    }

    # Add stream_options only if streaming is enabled
    if config["stream"]:
        request_params["stream_options"] = {"include_usage": True}

    last_error = None

    # Retry loop for transient errors (connection / timeout / rate limit).
    for attempt in range(MAX_RETRY_ATTEMPTS):
        try:
            response = client.chat.completions.create(**request_params)

            # Handle streaming vs non-streaming responses differently
            if config["stream"]:
                # For streaming, consume the iterator here so transport
                # errors surface inside this try (and can be retried)
                # rather than in the caller.
                chunks = []
                for chunk in response:
                    chunks.append(chunk)
                return chunks, None
            else:
                # For non-streaming, return the response directly
                return response, None

        except (APIConnectionError, APITimeoutError, RateLimitError) as e:
            last_error = e
            if attempt == MAX_RETRY_ATTEMPTS - 1:
                break
            print(
                f"Retrying tool calling for '{model_url}' after error: {e}. "
                f"Attempt #{attempt + 1}"
            )
            time.sleep(attempt + 1)  # linear backoff: 1s, 2s, ...
        except Exception as e:
            # Any other error is treated as non-transient: stop retrying.
            last_error = e
            break

    return None, str(last_error)
| 128 | + |
| 129 | + |
def is_valid_tool_arguments(arguments):
    """Check if arguments string is valid JSON with required fields."""
    try:
        parsed = json.loads(arguments)
    except (json.JSONDecodeError, TypeError):
        return False
    if not isinstance(parsed, dict):
        return False
    return "location" in parsed and "unit" in parsed
| 137 | + |
| 138 | + |
def validate_tool_calling_response(response, config):
    """
    Validate tool calling response with clear assertion messages.

    Validation criteria:
    - Streaming: finish_reason='tool_calls', usage info, exactly one tool call with valid JSON
    - Non-streaming: tool_calls present with valid JSON arguments

    Args:
        response: List of streamed chunks or a completion object, as
            returned by call_openai_tool_calling.
        config: Dict with stream, tool_choice, and strict parameters.

    Raises:
        AssertionError: if the response does not satisfy the criteria.
    """
    assert response is not None, "Response is None"

    if config["stream"]:
        assert isinstance(
            response, list
        ) and response, f"Invalid streaming response: {type(response)}"

        # Check finish_reason and usage
        has_finish_reason = any(
            chunk.choices and chunk.choices[0].finish_reason == 'tool_calls' for chunk in response
        )
        has_usage = any(hasattr(chunk, 'usage') and chunk.usage for chunk in response)

        assert has_usage, "Missing usage info in streaming response"

        # finish_reason='tool_calls' is only guaranteed when the call is forced.
        if config["tool_choice"] == "required":
            assert has_finish_reason, "Missing finish_reason='tool_calls'"

        # Find chunks that contain tool calls
        chunks_with_tool_calls = [
            chunk for chunk in response if chunk.choices and chunk.choices[0].delta.tool_calls
        ]

        # Validate exactly ONE chunk contains tool calls
        assert (
            len(chunks_with_tool_calls) == 1
        ), f"Expected exactly 1 chunk with tool calls, got {len(chunks_with_tool_calls)}"

        # Get the single chunk's tool calls
        tool_calls = chunks_with_tool_calls[0].choices[0].delta.tool_calls

        # Validate exactly one tool call
        assert len(tool_calls) == 1, f"Expected exactly 1 tool call, got {len(tool_calls)}"

        tool_call = tool_calls[0]

        # Validate has function name
        assert (
            tool_call.function and tool_call.function.name
        ), "Tool call missing function or name"

        # Validate has complete valid JSON arguments
        assert (
            tool_call.function.arguments and is_valid_tool_arguments(tool_call.function.arguments)
        ), f"Invalid or missing arguments: {tool_call.function.arguments if tool_call.function else 'N/A'}"

    else:
        # Non-streaming
        assert hasattr(response, 'choices') and response.choices, "Response missing choices"

        message = response.choices[0].message

        # BUG FIX: the code below indexes message.tool_calls unconditionally,
        # but the original only asserted its presence for tool_choice="required".
        # An "auto" response without tool calls therefore crashed with an opaque
        # TypeError/IndexError. Assert up front for every mode so the same
        # failing inputs now fail with a clear message instead.
        assert hasattr(message, 'tool_calls') and message.tool_calls, "Message missing tool_calls"

        tool_call = message.tool_calls[0]
        assert tool_call.function and tool_call.function.name, "Tool call missing function or name"

        assert is_valid_tool_arguments(
            tool_call.function.arguments
        ), f"Invalid arguments: {tool_call.function.arguments}"
| 208 | + |
| 209 | + |
def _list_featured_models_with_use_case_filters(per_page=50, use_cases=None):
    """Lists featured models from the Clarifai platform.

    Raises Exception when the ListModels call does not return SUCCESS.
    """
    stub = service_pb2_grpc.V2Stub(ClarifaiChannel.get_grpc_channel())
    list_request = service_pb2.ListModelsRequest(
        per_page=per_page, featured_only=True, use_cases=use_cases
    )
    list_response = stub.ListModels(list_request, metadata=metadata(pat=True))
    if list_response.status.code != status_code_pb2.SUCCESS:
        raise Exception(f"ListModels failed: {list_response.status.description}")
    return list_response.models
| 219 | + |
| 220 | + |
def get_tool_calling_models():
    """
    Get the list of models to test for tool calling.
    """
    if not os.environ.get('CLARIFAI_PAT_KEY'):
        return ["Missing API KEY"]

    # Get models with function-calling use case
    candidates = _list_featured_models_with_use_case_filters(
        per_page=100, use_cases=['function-calling']
    )

    model_urls = []
    for candidate in candidates:
        # Also check for openai_transport support among the version's
        # declared method signatures.
        signatures = getattr(candidate.model_version, "method_signatures", [])
        supports_transport = any(sig.name == "openai_transport" for sig in signatures)
        if supports_transport:
            model_urls.append(
                f"https://clarifai.com/{candidate.user_id}/{candidate.app_id}/models/{candidate.id}"
            )

    return model_urls
| 242 | + |
| 243 | + |
def generate_tool_calling_test_params():
    """Generate all combinations of models and configurations for testing."""
    combos = []
    for model in get_tool_calling_models():
        short_name = model.split('/')[-1]
        for config in TOOL_CALLING_CONFIGS:
            # Create a readable test ID
            test_id = (
                f"{short_name}-stream_{config['stream']}"
                f"-choice_{config['tool_choice']}-strict_{config['strict']}"
            )
            combos.append((model, config, test_id))
    return combos
0 commit comments