|
3 | 3 |
|
4 | 4 | use super::config::JsonParserConfig;
|
5 | 5 | use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
|
6 |
| -use openai_harmony::StreamableParser; |
7 | 6 | use openai_harmony::chat::{Content::Text, Role};
|
8 |
| -use openai_harmony::{HarmonyEncoding, HarmonyEncodingName, load_harmony_encoding}; |
| 7 | +use openai_harmony::{ |
| 8 | + HarmonyEncoding, HarmonyEncodingName, StreamableParser, load_harmony_encoding, |
| 9 | +}; |
9 | 10 | use serde_json::Value;
|
10 | 11 |
|
11 | 12 | static GLOBAL_HARMONY_GPTOSS_ENCODING: tokio::sync::OnceCell<
|
@@ -162,6 +163,109 @@ pub async fn parse_tool_calls_harmony(
|
162 | 163 | Ok((res, Some(normal_text.to_string())))
|
163 | 164 | }
|
164 | 165 |
|
| 166 | +/// Parse tool calls from a complete Harmony Format text chunk using direct token parsing. |
| 167 | +/// |
| 168 | +/// This function is optimized for parsing complete text chunks where the entire content |
| 169 | +/// is available at once. It uses `parse_messages_from_completion_tokens` to directly |
| 170 | +/// parse all tokens into Harmony Format messages, then extracts tool calls from messages |
| 171 | +/// with the "commentary" channel and "functions.*" recipients. |
| 172 | +/// |
| 173 | +/// Unlike `parse_tool_calls_harmony`, this function doesn't perform start token detection |
| 174 | +/// or token-by-token streaming, making it more efficient for complete chunks. |
| 175 | +/// |
| 176 | +/// # Arguments |
| 177 | +/// * `text` - The full Harmony-format string to be parsed, excluding any trailing stop tokens. |
| 178 | +/// Example: |
| 179 | +/// `<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco"}` |
| 180 | +/// * `_config` - Parser configuration (currently unused but kept for API consistency) |
| 181 | +/// |
| 182 | +/// # Returns |
| 183 | +/// * `Ok((tool_calls, normal_text))` - Tuple containing extracted tool calls and any normal text |
| 184 | +/// * `Err(e)` - If parsing fails due to encoding or tokenization errors |
| 185 | +pub async fn parse_tool_calls_harmony_complete( |
| 186 | + text: &str, |
| 187 | + _config: &JsonParserConfig, |
| 188 | +) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> { |
| 189 | + let enc = match get_harmony_encoding().await.as_ref() { |
| 190 | + Ok(e) => e, |
| 191 | + Err(e) => { |
| 192 | + tracing::debug!("Failed to load harmony encoding: {e}. Tool calls will not be parsed."); |
| 193 | + return Ok((vec![], Some(text.to_string()))); |
| 194 | + } |
| 195 | + }; |
| 196 | + |
| 197 | + // // Encode the text into tokens using harmony encoding |
| 198 | + let tokens: Vec<u32> = enc.tokenizer().encode_with_special_tokens(text); |
| 199 | + let messages = match enc.parse_messages_from_completion_tokens(tokens, Some(Role::Assistant)) { |
| 200 | + Ok(messages) => messages, |
| 201 | + Err(e) => { |
| 202 | + tracing::debug!( |
| 203 | + "Failed to parse messages from completion tokens: {e}. Tool calls will not be parsed." |
| 204 | + ); |
| 205 | + return Ok((vec![], Some(text.to_string()))); |
| 206 | + } |
| 207 | + }; |
| 208 | + |
| 209 | + let mut normal_text = String::new(); |
| 210 | + |
| 211 | + let mut res = Vec::with_capacity(messages.len()); |
| 212 | + let mut call_idx = 0; // Index of the tool call |
| 213 | + |
| 214 | + for message in messages.iter() { |
| 215 | + if message.author.role != Role::Assistant { |
| 216 | + continue; |
| 217 | + } |
| 218 | + |
| 219 | + let channel = message.channel.as_deref(); |
| 220 | + let recipient = message.recipient.as_deref().unwrap_or_default(); |
| 221 | + |
| 222 | + // Handle commentary channel |
| 223 | + if channel == Some("commentary") && recipient.starts_with("functions.") { |
| 224 | + let Some(fname) = message |
| 225 | + .recipient |
| 226 | + .as_ref() |
| 227 | + .and_then(|r| r.split('.').nth(1)) |
| 228 | + .filter(|s| !s.is_empty()) |
| 229 | + .map(|s| s.to_string()) |
| 230 | + else { |
| 231 | + continue; |
| 232 | + }; |
| 233 | + |
| 234 | + let args = match message.content.first() { |
| 235 | + Some(Text(text)) => match serde_json::from_str::<Value>(text.text.trim()) { |
| 236 | + Ok(value) => value, |
| 237 | + Err(_) => { |
| 238 | + Value::Null // Set args to null if it's not valid JSON |
| 239 | + } |
| 240 | + }, |
| 241 | + _ => { |
| 242 | + Value::Null // Set args to null if it's not a text content |
| 243 | + } |
| 244 | + }; |
| 245 | + // Add tool call to result if args is valid JSON |
| 246 | + if !args.is_null() { |
| 247 | + call_idx += 1; |
| 248 | + res.push(ToolCallResponse { |
| 249 | + id: format!("call-{}", call_idx), |
| 250 | + tp: ToolCallType::Function, |
| 251 | + function: CalledFunction { |
| 252 | + name: fname.to_string(), |
| 253 | + // Safety: `Value::Object` is always valid JSON, so serialization cannot fail |
| 254 | + arguments: serde_json::to_string(&args).unwrap(), |
| 255 | + }, |
| 256 | + }); |
| 257 | + } |
| 258 | + // Handle reasoning(analysis) channel |
| 259 | + } else if channel == Some("analysis") { |
| 260 | + normal_text.push_str(match &message.content[0] { |
| 261 | + Text(t) => &t.text, |
| 262 | + _ => "", |
| 263 | + }); |
| 264 | + } |
| 265 | + } |
| 266 | + Ok((res, Some(normal_text.to_string()))) |
| 267 | +} |
| 268 | + |
165 | 269 | pub fn detect_tool_call_start_harmony(
|
166 | 270 | chunk: &str,
|
167 | 271 | config: &JsonParserConfig,
|
@@ -266,6 +370,20 @@ mod tests {
|
266 | 370 | assert_eq!(args["location"], "San Francisco");
|
267 | 371 | }
|
268 | 372 |
|
| 373 | + #[tokio::test] |
| 374 | + async fn test_parse_tool_calls_harmony_complete_basic() { |
| 375 | + let text = r#"<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"format":"celsius","location":"San Francisco"}"#; |
| 376 | + let (tool_calls, normal_content) = |
| 377 | + parse_tool_calls_harmony_complete(text, &Default::default()) |
| 378 | + .await |
| 379 | + .unwrap(); |
| 380 | + assert_eq!(normal_content, Some("".to_string())); |
| 381 | + let (name, args) = extract_name_and_args(tool_calls[0].clone()); |
| 382 | + assert_eq!(name, "get_current_weather"); |
| 383 | + assert_eq!(args["location"], "San Francisco"); |
| 384 | + assert_eq!(args["format"], "celsius"); |
| 385 | + } |
| 386 | + |
269 | 387 | #[tokio::test]
|
270 | 388 | async fn test_parse_tools_harmony_without_start_token() {
|
271 | 389 | let text = r#"
|
|
0 commit comments