|
1 | 1 | import json |
| 2 | +import re |
2 | 3 | from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union, get_args |
3 | 4 |
|
4 | 5 | from haystack import component, default_from_dict, default_to_dict, logging |
5 | 6 | from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message |
6 | | -from haystack.dataclasses import ChatMessage, ComponentInfo, ImageContent, TextContent, ToolCall |
| 7 | +from haystack.dataclasses import ChatMessage, ComponentInfo, ImageContent, ReasoningContent, TextContent, ToolCall |
7 | 8 | from haystack.dataclasses.streaming_chunk import ( |
8 | 9 | AsyncStreamingCallbackT, |
9 | 10 | FinishReason, |
@@ -202,11 +203,20 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage: |
202 | 203 | ) |
203 | 204 | ) |
204 | 205 |
|
| 206 | + # Extract reasoning from content if present, even with tool calls |
| 207 | + reasoning_content = None |
| 208 | + if chat_response.message.content and hasattr(chat_response.message.content[0], "text"): |
| 209 | + raw_content = chat_response.message.content[0].text |
| 210 | + reasoning_content, _ = _extract_reasoning_from_response(raw_content) |
| 211 | + |
205 | 212 | # Create message with tool plan as text and tool calls in the format Haystack expects |
206 | 213 | tool_plan = chat_response.message.tool_plan or "" |
207 | | - message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls) |
| 214 | + message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls, reasoning=reasoning_content) |
208 | 215 | elif chat_response.message.content and hasattr(chat_response.message.content[0], "text"): |
209 | | - message = ChatMessage.from_assistant(chat_response.message.content[0].text) |
| 216 | + raw_content = chat_response.message.content[0].text |
| 217 | + # Extract reasoning content if present |
| 218 | + reasoning_content, cleaned_content = _extract_reasoning_from_response(raw_content) |
| 219 | + message = ChatMessage.from_assistant(cleaned_content, reasoning=reasoning_content) |
210 | 220 | else: |
211 | 221 | # Handle the case where neither tool_calls nor content exists |
212 | 222 | logger.warning(f"Received empty response from Cohere API: {chat_response.message}") |
@@ -350,6 +360,125 @@ def _convert_cohere_chunk_to_streaming_chunk( |
350 | 360 | ) |
351 | 361 |
|
352 | 362 |
|
| 363 | +def _extract_reasoning_from_response(response_text: str) -> tuple[Optional[ReasoningContent], str]: |
| 364 | + """ |
| 365 | + Extract reasoning content from Cohere's response if present. |
| 366 | +
|
| 367 | + Cohere's reasoning-capable models (like Command A Reasoning) may include reasoning content |
| 368 | + in various formats. This function attempts to identify and extract such content. |
| 369 | +
|
| 370 | + :param response_text: The raw response text from Cohere |
| 371 | + :returns: A tuple of (ReasoningContent or None, cleaned_response_text) |
| 372 | + """ |
| 373 | + if not response_text or not isinstance(response_text, str): |
| 374 | + return None, response_text |
| 375 | + |
| 376 | + # Check for reasoning markers that Cohere might use |
| 377 | + |
| 378 | + # Pattern 1: Look for thinking/reasoning tags |
| 379 | + thinking_patterns = [ |
| 380 | + r"<thinking>(.*?)</thinking>", |
| 381 | + r"<reasoning>(.*?)</reasoning>", |
| 382 | + r"## Reasoning\s*\n(.*?)(?=\n## |$)", |
| 383 | + r"## Thinking\s*\n(.*?)(?=\n## |$)", |
| 384 | + ] |
| 385 | + |
| 386 | + for pattern in thinking_patterns: |
| 387 | + match = re.search(pattern, response_text, re.DOTALL | re.IGNORECASE) |
| 388 | + if match: |
| 389 | + reasoning_text = match.group(1).strip() |
| 390 | + cleaned_content = re.sub(pattern, "", response_text, flags=re.DOTALL | re.IGNORECASE).strip() |
| 391 | + # Apply minimum length threshold for tag-based reasoning |
| 392 | + min_reasoning_length = 30 |
| 393 | + if len(reasoning_text) > min_reasoning_length: |
| 394 | + return ReasoningContent(reasoning_text=reasoning_text), cleaned_content |
| 395 | + else: |
| 396 | + # Content too short, but still clean the tags |
| 397 | + return None, cleaned_content |
| 398 | + |
| 399 | + # Pattern 2: Look for step-by-step reasoning at start |
| 400 | + lines = response_text.split("\n") |
| 401 | + reasoning_lines = [] |
| 402 | + content_lines = [] |
| 403 | + found_separator = False |
| 404 | + |
| 405 | + for i, line in enumerate(lines): |
| 406 | + stripped_line = line.strip() |
| 407 | + # Look for reasoning indicators at the beginning of lines (more precise) |
| 408 | + if ( |
| 409 | + stripped_line.startswith(("Step ", "First,", "Let me think", "I need to solve", "To solve")) |
| 410 | + or stripped_line.startswith(("## Reasoning", "## Thinking", "## My reasoning")) |
| 411 | + or ( |
| 412 | + len(stripped_line) > 0 |
| 413 | + and stripped_line.endswith(":") |
| 414 | + and ("reasoning" in stripped_line.lower() or "thinking" in stripped_line.lower()) |
| 415 | + ) |
| 416 | + ): |
| 417 | + # Look for a clear separator to determine where reasoning ends |
| 418 | + reasoning_end = len(lines) # Default to end of text |
| 419 | + for j in range(i + 1, len(lines)): |
| 420 | + next_line = lines[j].strip() |
| 421 | + if next_line.startswith( |
| 422 | + ("Based on", "Therefore", "In conclusion", "So,", "Thus,", "## Solution", "## Answer") |
| 423 | + ): |
| 424 | + reasoning_end = j |
| 425 | + break |
| 426 | + |
| 427 | + reasoning_lines = lines[:reasoning_end] |
| 428 | + content_lines = lines[reasoning_end:] |
| 429 | + found_separator = True |
| 430 | + break |
| 431 | + # Stop looking after first few lines |
| 432 | + max_lines_to_check = 10 |
| 433 | + if i > max_lines_to_check: |
| 434 | + break |
| 435 | + |
| 436 | + if found_separator and reasoning_lines: |
| 437 | + reasoning_text = "\n".join(reasoning_lines).strip() |
| 438 | + cleaned_content = "\n".join(content_lines).strip() |
| 439 | + min_reasoning_length = 30 |
| 440 | + if len(reasoning_text) > min_reasoning_length: # Minimum threshold |
| 441 | + return ReasoningContent(reasoning_text=reasoning_text), cleaned_content |
| 442 | + |
| 443 | + # No reasoning detected |
| 444 | + return None, response_text |
| 445 | + |
| 446 | + |
def _convert_streaming_chunks_to_chat_message_with_reasoning(chunks: List[StreamingChunk]) -> ChatMessage:
    """
    Convert streaming chunks to a ChatMessage, splitting out Cohere reasoning content when present.

    Wraps the core `_convert_streaming_chunks_to_chat_message` utility; if the assembled text
    contains reasoning markers, the message is rebuilt with the reasoning carried separately
    and the cleaned text as the message body.
    """
    message = _convert_streaming_chunks_to_chat_message(chunks=chunks)

    # Without text content there is nothing to inspect for reasoning.
    if not message.text:
        return message

    reasoning, cleaned_text = _extract_reasoning_from_response(message.text)
    if reasoning is None:
        # No reasoning markers found; keep the message as assembled.
        return message

    # Rebuild the message so the reasoning lives alongside the cleaned text.
    return ChatMessage.from_assistant(
        text=cleaned_text,
        reasoning=reasoning,
        tool_calls=message.tool_calls,
        meta=message.meta,
    )
| 480 | + |
| 481 | + |
353 | 482 | def _parse_streaming_response( |
354 | 483 | response: Iterator[StreamedChatResponseV2], |
355 | 484 | model: str, |
@@ -381,7 +510,7 @@ def _parse_streaming_response( |
381 | 510 | chunks.append(streaming_chunk) |
382 | 511 | streaming_callback(streaming_chunk) |
383 | 512 |
|
384 | | - return _convert_streaming_chunks_to_chat_message(chunks=chunks) |
| 513 | + return _convert_streaming_chunks_to_chat_message_with_reasoning(chunks=chunks) |
385 | 514 |
|
386 | 515 |
|
387 | 516 | async def _parse_async_streaming_response( |
@@ -409,7 +538,7 @@ async def _parse_async_streaming_response( |
409 | 538 | chunks.append(streaming_chunk) |
410 | 539 | await streaming_callback(streaming_chunk) |
411 | 540 |
|
412 | | - return _convert_streaming_chunks_to_chat_message(chunks=chunks) |
| 541 | + return _convert_streaming_chunks_to_chat_message_with_reasoning(chunks=chunks) |
413 | 542 |
|
414 | 543 |
|
415 | 544 | @component |
|
0 commit comments