Skip to content

Commit dd4c205

Browse files
authored
ToolCall Signature and additional parameters (#1154)
* Box message in error * Add signature and additional_params to ToolCall OpenRouter wants the user to pass back the reasoning as is if possible
1 parent f4e0df5 commit dd4c205

File tree

25 files changed

+401
-402
lines changed

25 files changed

+401
-402
lines changed

rig-integrations/rig-bedrock/src/streaming.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ use async_stream::stream;
44
use aws_sdk_bedrockruntime::types as aws_bedrock;
55
use rig::completion::GetTokenUsage;
66
use rig::streaming::StreamingCompletionResponse;
7-
use rig::{completion::CompletionError, streaming::RawStreamingChoice};
7+
use rig::{
8+
completion::CompletionError,
9+
streaming::{RawStreamingChoice, RawStreamingToolCall},
10+
};
811
use serde::{Deserialize, Serialize};
912

1013
#[derive(Clone, Deserialize, Serialize)]
@@ -160,12 +163,7 @@ impl CompletionModel {
160163
} else {
161164
serde_json::from_str(tool_call.input_json.as_str())?
162165
};
163-
yield Ok(RawStreamingChoice::ToolCall {
164-
name: tool_call.name,
165-
call_id: None,
166-
id: tool_call.id,
167-
arguments: tool_input
168-
});
166+
yield Ok(RawStreamingChoice::ToolCall(RawStreamingToolCall::new(tool_call.id, tool_call.name, tool_input)));
169167
} else {
170168
yield Err(CompletionError::ProviderError("Failed to call tool".into()))
171169
}

rig-integrations/rig-bedrock/src/types/assistant_content.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,10 @@ impl TryFrom<AwsConverseOutput> for completion::CompletionResponse<AwsConverseOu
5757
_ => None,
5858
}) {
5959
return Ok(completion::CompletionResponse {
60-
choice: OneOrMany::one(AssistantContent::ToolCall(ToolCall {
61-
id: tool_use.id,
62-
call_id: None,
63-
function: ToolFunction {
64-
name: tool_use.function.name,
65-
arguments: tool_use.function.arguments,
66-
},
67-
})),
60+
choice: OneOrMany::one(AssistantContent::ToolCall(ToolCall::new(
61+
tool_use.id,
62+
ToolFunction::new(tool_use.function.name, tool_use.function.arguments),
63+
))),
6864
usage,
6965
raw_response: value,
7066
});

rig-integrations/rig-eternalai/src/providers/eternalai.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use rig::http_client;
2323
use rig::message;
2424
use rig::message::AssistantContent;
2525
use rig::providers::openai::{self, Message};
26-
use rig::streaming::{RawStreamingChoice, StreamingCompletionResponse};
26+
use rig::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse};
2727
use rig::{Embed, completion, embeddings};
2828
use schemars::JsonSchema;
2929
use serde::{Deserialize, Serialize};
@@ -734,12 +734,7 @@ impl completion::CompletionModel for CompletionModel {
734734
yield Ok(RawStreamingChoice::Message(text.text.clone()))
735735
}
736736
AssistantContent::ToolCall(tc) => {
737-
yield Ok(RawStreamingChoice::ToolCall {
738-
id: tc.id.clone(),
739-
call_id: None,
740-
name: tc.function.name.clone(),
741-
arguments: tc.function.arguments.clone(),
742-
})
737+
yield Ok(RawStreamingChoice::ToolCall(RawStreamingToolCall::new(tc.id.clone(), tc.function.name.clone(), tc.function.arguments.clone())));
743738
}
744739
AssistantContent::Image(_) => {
745740
panic!("Image content is currently unimplemented on Eternal AI. If you need this, please open a ticket!")

rig-integrations/rig-vertexai/src/types/completion_response.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,10 @@ impl TryFrom<VertexGenerateContentOutput> for CompletionResponse<VertexGenerateC
3434
.map(|s| serde_json::Value::Object(s.clone()))
3535
.unwrap_or_else(|| serde_json::json!({}));
3636

37-
assistant_contents.push(AssistantContent::ToolCall(ToolCall {
38-
id: function_call.name.clone(),
39-
call_id: None,
40-
function: ToolFunction {
41-
name: function_call.name.clone(),
42-
arguments: args_json,
43-
},
44-
}));
37+
assistant_contents.push(AssistantContent::ToolCall(ToolCall::new(
38+
function_call.name.clone(),
39+
ToolFunction::new(function_call.name.clone(), args_json),
40+
)));
4541
} else if let Some(text) = part.text() {
4642
assistant_contents.push(AssistantContent::Text(Text { text: text.clone() }));
4743
}

rig-integrations/rig-vertexai/src/types/message.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,17 +148,16 @@ mod tests {
148148
#[test]
149149
fn test_assistant_tool_call_message_conversion() {
150150
use rig::message::{ToolCall, ToolFunction};
151-
let tool_call = ToolCall {
152-
id: "add".to_string(),
153-
call_id: None,
154-
function: ToolFunction {
155-
name: "add".to_string(),
156-
arguments: serde_json::json!({
151+
let tool_call = ToolCall::new(
152+
"add".to_string(),
153+
ToolFunction::new(
154+
"add".to_string(),
155+
serde_json::json!({
157156
"x": 5,
158157
"y": 3
159158
}),
160-
},
161-
};
159+
),
160+
);
162161

163162
let message = Message::Assistant {
164163
id: None,

rig/rig-core/src/agent/prompt_request/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ where
562562
Err(PromptError::MaxDepthError {
563563
max_depth: self.max_depth,
564564
chat_history: Box::new(chat_history.clone()),
565-
prompt: last_prompt,
565+
prompt: Box::new(last_prompt),
566566
})
567567
}
568568
}

rig/rig-core/src/agent/prompt_request/streaming.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ where
437437
yield Err(Box::new(PromptError::MaxDepthError {
438438
max_depth: self.max_depth,
439439
chat_history: Box::new((*chat_history.read().await).clone()),
440-
prompt: last_prompt_error.clone().into(),
440+
prompt: Box::new(last_prompt_error.clone().into()),
441441
}).into());
442442
}
443443
};

rig/rig-core/src/completion/message.rs

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,45 @@ pub struct ToolCall {
142142
pub id: String,
143143
pub call_id: Option<String>,
144144
pub function: ToolFunction,
145+
/// Optional cryptographic signature for the tool call.
146+
///
147+
/// This field is used by some providers (e.g., Google) to provide a signature
148+
/// that can verify the authenticity and integrity of the tool call. When present,
149+
/// it allows verification that the tool call was actually generated by the model
150+
/// and has not been tampered with.
151+
///
152+
/// This is an optional, provider-specific feature and will be `None` for providers
153+
/// that don't support tool call signatures.
154+
pub signature: Option<String>,
155+
/// Additional provider-specific parameters to be sent to the completion model provider
156+
pub additional_params: Option<serde_json::Value>,
157+
}
158+
159+
impl ToolCall {
160+
pub fn new(id: String, function: ToolFunction) -> Self {
161+
Self {
162+
id,
163+
call_id: None,
164+
function,
165+
signature: None,
166+
additional_params: None,
167+
}
168+
}
169+
170+
pub fn with_call_id(mut self, call_id: String) -> Self {
171+
self.call_id = Some(call_id);
172+
self
173+
}
174+
175+
pub fn with_signature(mut self, signature: Option<String>) -> Self {
176+
self.signature = signature;
177+
self
178+
}
179+
180+
pub fn with_additional_params(mut self, additional_params: Option<serde_json::Value>) -> Self {
181+
self.additional_params = additional_params;
182+
self
183+
}
145184
}
146185

147186
/// Describes a tool function to call with a name and arguments, generally produced by a provider.
@@ -151,6 +190,12 @@ pub struct ToolFunction {
151190
pub arguments: serde_json::Value,
152191
}
153192

193+
impl ToolFunction {
194+
pub fn new(name: String, arguments: serde_json::Value) -> Self {
195+
Self { name, arguments }
196+
}
197+
}
198+
154199
// ================================================================
155200
// Base content models
156201
// ================================================================
@@ -612,14 +657,13 @@ impl AssistantContent {
612657
name: impl Into<String>,
613658
arguments: serde_json::Value,
614659
) -> Self {
615-
AssistantContent::ToolCall(ToolCall {
616-
id: id.into(),
617-
call_id: None,
618-
function: ToolFunction {
660+
AssistantContent::ToolCall(ToolCall::new(
661+
id.into(),
662+
ToolFunction {
619663
name: name.into(),
620664
arguments,
621665
},
622-
})
666+
))
623667
}
624668

625669
pub fn tool_call_with_call_id(
@@ -628,14 +672,16 @@ impl AssistantContent {
628672
name: impl Into<String>,
629673
arguments: serde_json::Value,
630674
) -> Self {
631-
AssistantContent::ToolCall(ToolCall {
632-
id: id.into(),
633-
call_id: Some(call_id),
634-
function: ToolFunction {
635-
name: name.into(),
636-
arguments,
637-
},
638-
})
675+
AssistantContent::ToolCall(
676+
ToolCall::new(
677+
id.into(),
678+
ToolFunction {
679+
name: name.into(),
680+
arguments,
681+
},
682+
)
683+
.with_call_id(call_id),
684+
)
639685
}
640686

641687
pub fn reasoning(reasoning: impl AsRef<str>) -> Self {

rig/rig-core/src/completion/request.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ pub enum PromptError {
140140
MaxDepthError {
141141
max_depth: usize,
142142
chat_history: Box<Vec<Message>>,
143-
prompt: Message,
143+
prompt: Box<Message>,
144144
},
145145

146146
/// A prompting loop was cancelled.

rig/rig-core/src/providers/anthropic/streaming.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
1313
use crate::http_client::sse::{Event, GenericEventSource};
1414
use crate::http_client::{self, HttpClientExt};
1515
use crate::json_utils::merge_inplace;
16-
use crate::streaming::{self, RawStreamingChoice, StreamingResult};
16+
use crate::streaming::{self, RawStreamingChoice, RawStreamingToolCall, StreamingResult};
1717
use crate::telemetry::SpanCombinator;
1818

1919
#[derive(Debug, Deserialize)]
@@ -408,12 +408,9 @@ fn handle_event(
408408
&tool_call.input_json
409409
};
410410
match serde_json::from_str(json_str) {
411-
Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall {
412-
name: tool_call.name,
413-
id: tool_call.id,
414-
arguments: json_value,
415-
call_id: None,
416-
})),
411+
Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall(
412+
RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value),
413+
))),
417414
Err(e) => Some(Err(CompletionError::from(e))),
418415
}
419416
} else {
@@ -709,12 +706,12 @@ mod tests {
709706
assert!(final_result.is_some());
710707

711708
match final_result.unwrap().unwrap() {
712-
RawStreamingChoice::ToolCall {
709+
RawStreamingChoice::ToolCall(RawStreamingToolCall {
713710
id,
714711
name,
715712
arguments,
716713
..
717-
} => {
714+
}) => {
718715
assert_eq!(id, "tool_123");
719716
assert_eq!(name, "test_tool");
720717
assert_eq!(

0 commit comments

Comments
 (0)