Skip to content

Commit 2b48310

Browse files
OpenAI tool calling tests (#235)
* add tool call tests * optimize tests * remove qwen from hardcoded list for tests * use use case to filter * remove hardcode list and streaming * fix
1 parent f107fdd commit 2b48310

File tree

2 files changed

+278
-0
lines changed

2 files changed

+278
-0
lines changed
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
"""
2+
Helper utilities for constructing and validating OpenAI tool calling requests.
3+
4+
This module provides functions and constants for testing OpenAI-compatible
5+
tool calling functionality with Clarifai models.
6+
"""
7+
8+
import json
9+
import os
10+
import time
11+
12+
from openai import APIConnectionError, APITimeoutError, OpenAI, RateLimitError
13+
14+
from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
15+
from clarifai_grpc.channel.http_client import CLIENT_VERSION
16+
from clarifai_grpc.grpc.api import service_pb2, service_pb2_grpc
17+
from clarifai_grpc.grpc.api.status import status_code_pb2
18+
from tests.common import metadata
19+
20+
# Maximum retry attempts for API calls
21+
MAX_RETRY_ATTEMPTS = 3
22+
23+
# Parameter combinations to test
24+
# For now, we only test non-streaming with tool calling
25+
TOOL_CALLING_CONFIGS = [
26+
{"stream": False, "tool_choice": "required", "strict": True},
27+
{"stream": False, "tool_choice": "required", "strict": False},
28+
{"stream": False, "tool_choice": "auto", "strict": True},
29+
{"stream": False, "tool_choice": "auto", "strict": False},
30+
]
31+
32+
# Tool definition for weather query
33+
WEATHER_TOOL = {
34+
"type": "function",
35+
"function": {
36+
"name": "get_current_weather",
37+
"description": "Get the current weather in a given location",
38+
"parameters": {
39+
"type": "object",
40+
"properties": {
41+
"location": {
42+
"type": "string",
43+
"description": "The city and state, e.g. San Francisco, CA",
44+
},
45+
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
46+
},
47+
"additionalProperties": False,
48+
"required": ["location", "unit"],
49+
},
50+
},
51+
}
52+
53+
54+
def call_openai_tool_calling(model_url, config):
55+
"""
56+
Call OpenAI-compatible endpoint with tool calling using the specified configuration.
57+
Includes retry mechanism for transient errors.
58+
59+
Args:
60+
model_url: Full URL to the Clarifai model
61+
config: Dictionary with stream, tool_choice, and strict parameters
62+
63+
Returns:
64+
tuple: (response, error) where error is None on success
65+
"""
66+
channel = ClarifaiChannel.get_grpc_channel()
67+
client = OpenAI(
68+
api_key=os.environ.get('CLARIFAI_PAT_KEY'),
69+
base_url=f"https://{channel._target}/v2/ext/openai/v1",
70+
default_headers={"X-Clarifai-Request-Id-Prefix": f"python-openai-{CLIENT_VERSION}"},
71+
timeout=10 # 10 seconds timeout to avoid hanging
72+
)
73+
74+
# Build tool definition with strict parameter
75+
tool = WEATHER_TOOL.copy()
76+
tool["function"]["strict"] = config["strict"]
77+
78+
# Build request parameters
79+
request_params = {
80+
"model": model_url,
81+
"messages": [
82+
{"role": "user", "content": "What is the weather like in Boston in fahrenheit?"}
83+
],
84+
"temperature": 1,
85+
"top_p": 1,
86+
"max_tokens": 32768,
87+
"stream": config["stream"],
88+
"tool_choice": config["tool_choice"],
89+
"tools": [tool],
90+
}
91+
92+
# Add stream_options only if streaming is enabled
93+
if config["stream"]:
94+
request_params["stream_options"] = {"include_usage": True}
95+
96+
last_error = None
97+
98+
# Retry loop for transient errors
99+
for attempt in range(MAX_RETRY_ATTEMPTS):
100+
try:
101+
response = client.chat.completions.create(**request_params)
102+
103+
# Handle streaming vs non-streaming responses differently
104+
if config["stream"]:
105+
# For streaming, we need to consume the iterator
106+
chunks = []
107+
for chunk in response:
108+
chunks.append(chunk)
109+
return chunks, None
110+
else:
111+
# For non-streaming, return the response directly
112+
return response, None
113+
114+
except (APIConnectionError, APITimeoutError, RateLimitError) as e:
115+
last_error = e
116+
if attempt == MAX_RETRY_ATTEMPTS - 1:
117+
break
118+
print(
119+
f"Retrying tool calling for '{model_url}' after error: {e}. "
120+
f"Attempt #{attempt + 1}"
121+
)
122+
time.sleep(attempt + 1)
123+
except Exception as e:
124+
last_error = e
125+
break
126+
127+
return None, str(last_error)
128+
129+
130+
def is_valid_tool_arguments(arguments):
131+
"""Check if arguments string is valid JSON with required fields."""
132+
try:
133+
args = json.loads(arguments)
134+
return isinstance(args, dict) and "location" in args and "unit" in args
135+
except (json.JSONDecodeError, TypeError):
136+
return False
137+
138+
139+
def validate_tool_calling_response(response, config):
140+
"""
141+
Validate tool calling response with clear assertion messages.
142+
143+
Validation criteria:
144+
- Streaming: finish_reason='tool_calls', usage info, exactly one tool call with valid JSON
145+
- Non-streaming: tool_calls present with valid JSON arguments
146+
"""
147+
assert response is not None, "Response is None"
148+
149+
if config["stream"]:
150+
assert isinstance(
151+
response, list
152+
) and response, f"Invalid streaming response: {type(response)}"
153+
154+
# Check finish_reason and usage
155+
has_finish_reason = any(
156+
chunk.choices and chunk.choices[0].finish_reason == 'tool_calls' for chunk in response
157+
)
158+
has_usage = any(hasattr(chunk, 'usage') and chunk.usage for chunk in response)
159+
160+
assert has_usage, "Missing usage info in streaming response"
161+
162+
if config["tool_choice"] == "required":
163+
assert has_finish_reason, "Missing finish_reason='tool_calls'"
164+
165+
# Find chunks that contain tool calls
166+
chunks_with_tool_calls = [
167+
chunk for chunk in response if chunk.choices and chunk.choices[0].delta.tool_calls
168+
]
169+
170+
# Validate exactly ONE chunk contains tool calls
171+
assert (
172+
len(chunks_with_tool_calls) == 1
173+
), f"Expected exactly 1 chunk with tool calls, got {len(chunks_with_tool_calls)}"
174+
175+
# Get the single chunk's tool calls
176+
tool_calls = chunks_with_tool_calls[0].choices[0].delta.tool_calls
177+
178+
# Validate exactly one tool call
179+
assert len(tool_calls) == 1, f"Expected exactly 1 tool call, got {len(tool_calls)}"
180+
181+
tool_call = tool_calls[0]
182+
183+
# Validate has function name
184+
assert (
185+
tool_call.function and tool_call.function.name
186+
), "Tool call missing function or name"
187+
188+
# Validate has complete valid JSON arguments
189+
assert (
190+
tool_call.function.arguments and is_valid_tool_arguments(tool_call.function.arguments)
191+
), f"Invalid or missing arguments: {tool_call.function.arguments if tool_call.function else 'N/A'}"
192+
193+
else:
194+
# Non-streaming
195+
assert hasattr(response, 'choices') and response.choices, "Response missing choices"
196+
197+
message = response.choices[0].message
198+
199+
if config["tool_choice"] == "required":
200+
assert hasattr(message, 'tool_calls') and message.tool_calls, "Message missing tool_calls"
201+
202+
tool_call = message.tool_calls[0]
203+
assert tool_call.function and tool_call.function.name, "Tool call missing function or name"
204+
205+
assert is_valid_tool_arguments(
206+
tool_call.function.arguments
207+
), f"Invalid arguments: {tool_call.function.arguments}"
208+
209+
210+
def _list_featured_models_with_use_case_filters(per_page=50, use_cases=None):
211+
"""Lists featured models from the Clarifai platform."""
212+
channel = ClarifaiChannel.get_grpc_channel()
213+
stub = service_pb2_grpc.V2Stub(channel)
214+
request = service_pb2.ListModelsRequest(per_page=per_page, featured_only=True, use_cases=use_cases)
215+
response = stub.ListModels(request, metadata=metadata(pat=True))
216+
if response.status.code != status_code_pb2.SUCCESS:
217+
raise Exception(f"ListModels failed: {response.status.description}")
218+
return response.models
219+
220+
221+
def get_tool_calling_models():
222+
"""
223+
Get the list of models to test for tool calling.
224+
"""
225+
if not os.environ.get('CLARIFAI_PAT_KEY'):
226+
return ["Missing API KEY"]
227+
228+
# Get models with function-calling use case
229+
models_with_use_case = _list_featured_models_with_use_case_filters(
230+
per_page=100, use_cases=['function-calling']
231+
)
232+
233+
tool_calling_models = []
234+
for model in models_with_use_case:
235+
# Also check for openai_transport support
236+
method_signatures = getattr(model.model_version, "method_signatures", [])
237+
if any(ms.name == "openai_transport" for ms in method_signatures):
238+
model_url = f"https://clarifai.com/{model.user_id}/{model.app_id}/models/{model.id}"
239+
tool_calling_models.append(model_url)
240+
241+
return tool_calling_models
242+
243+
244+
def generate_tool_calling_test_params():
245+
"""Generate all combinations of models and configurations for testing."""
246+
models = get_tool_calling_models()
247+
params = []
248+
for model in models:
249+
for config in TOOL_CALLING_CONFIGS:
250+
# Create a readable test ID
251+
test_id = f"{model.split('/')[-1]}-stream_{config['stream']}-choice_{config['tool_choice']}-strict_{config['strict']}"
252+
params.append((model, config, test_id))
253+
return params

tests/public_models/test_public_models_predicts.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@
3131
post_model_outputs_and_maybe_allow_retries,
3232
raise_on_failure,
3333
)
34+
from tests.public_models.openai_tool_calling_helper import (
35+
call_openai_tool_calling,
36+
generate_tool_calling_test_params,
37+
validate_tool_calling_response,
38+
)
3439
from tests.public_models.public_test_helper import (
3540
AUDIO_MODEL_TITLE_IDS_TUPLE,
3641
DETECTION_MODEL_TITLE_AND_IDS,
@@ -556,3 +561,23 @@ async def test_openai_compatible_endpoint_on_featured_models_async():
556561
failed_models.append({model_identifiers[i]: error})
557562

558563
assert not failed_models, f"The following OpenAI compatible models failed: {failed_models}"
564+
565+
566+
@pytest.mark.parametrize(
567+
"model_url,config",
568+
[pytest.param(m, c, id=tid) for m, c, tid in generate_tool_calling_test_params()],
569+
)
570+
def test_openai_tool_calling_with_parameter_combinations(model_url, config):
571+
"""
572+
Test OpenAI-compatible tool calling with various parameter combinations.
573+
"""
574+
if not os.environ.get('CLARIFAI_PAT_KEY'):
575+
pytest.skip("Skipping test: CLARIFAI_PAT_KEY environment variable not set.")
576+
577+
response, error = call_openai_tool_calling(model_url, config)
578+
579+
# Assert no error occurred
580+
assert not error, f"Tool calling failed for {model_url} with config {config}: {error}"
581+
582+
# Validate response (raises AssertionError with clear message if validation fails)
583+
validate_tool_calling_response(response, config)

0 commit comments

Comments
 (0)