Skip to content

Commit 52ae02b

Browse files
committed
Add reasoning support for Cohere chat generator
1 parent 24676f2 commit 52ae02b

File tree

2 files changed

+387
-6
lines changed

2 files changed

+387
-6
lines changed

integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py

Lines changed: 134 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import json
2+
import re
23
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union, get_args
34

45
from haystack import component, default_from_dict, default_to_dict, logging
56
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
6-
from haystack.dataclasses import ChatMessage, ComponentInfo, ImageContent, TextContent, ToolCall
7+
from haystack.dataclasses import ChatMessage, ComponentInfo, ImageContent, ReasoningContent, TextContent, ToolCall
78
from haystack.dataclasses.streaming_chunk import (
89
AsyncStreamingCallbackT,
910
FinishReason,
@@ -202,11 +203,20 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage:
202203
)
203204
)
204205

206+
# Extract reasoning from content if present, even with tool calls
207+
reasoning_content = None
208+
if chat_response.message.content and hasattr(chat_response.message.content[0], "text"):
209+
raw_content = chat_response.message.content[0].text
210+
reasoning_content, _ = _extract_reasoning_from_response(raw_content)
211+
205212
# Create message with tool plan as text and tool calls in the format Haystack expects
206213
tool_plan = chat_response.message.tool_plan or ""
207-
message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls)
214+
message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls, reasoning=reasoning_content)
208215
elif chat_response.message.content and hasattr(chat_response.message.content[0], "text"):
209-
message = ChatMessage.from_assistant(chat_response.message.content[0].text)
216+
raw_content = chat_response.message.content[0].text
217+
# Extract reasoning content if present
218+
reasoning_content, cleaned_content = _extract_reasoning_from_response(raw_content)
219+
message = ChatMessage.from_assistant(cleaned_content, reasoning=reasoning_content)
210220
else:
211221
# Handle the case where neither tool_calls nor content exists
212222
logger.warning(f"Received empty response from Cohere API: {chat_response.message}")
@@ -350,6 +360,125 @@ def _convert_cohere_chunk_to_streaming_chunk(
350360
)
351361

352362

363+
def _extract_reasoning_from_response(response_text: str) -> tuple[Optional[ReasoningContent], str]:
364+
"""
365+
Extract reasoning content from Cohere's response if present.
366+
367+
Cohere's reasoning-capable models (like Command A Reasoning) may include reasoning content
368+
in various formats. This function attempts to identify and extract such content.
369+
370+
:param response_text: The raw response text from Cohere
371+
:returns: A tuple of (ReasoningContent or None, cleaned_response_text)
372+
"""
373+
if not response_text or not isinstance(response_text, str):
374+
return None, response_text
375+
376+
# Check for reasoning markers that Cohere might use
377+
378+
# Pattern 1: Look for thinking/reasoning tags
379+
thinking_patterns = [
380+
r"<thinking>(.*?)</thinking>",
381+
r"<reasoning>(.*?)</reasoning>",
382+
r"## Reasoning\s*\n(.*?)(?=\n## |$)",
383+
r"## Thinking\s*\n(.*?)(?=\n## |$)",
384+
]
385+
386+
for pattern in thinking_patterns:
387+
match = re.search(pattern, response_text, re.DOTALL | re.IGNORECASE)
388+
if match:
389+
reasoning_text = match.group(1).strip()
390+
cleaned_content = re.sub(pattern, "", response_text, flags=re.DOTALL | re.IGNORECASE).strip()
391+
# Apply minimum length threshold for tag-based reasoning
392+
min_reasoning_length = 30
393+
if len(reasoning_text) > min_reasoning_length:
394+
return ReasoningContent(reasoning_text=reasoning_text), cleaned_content
395+
else:
396+
# Content too short, but still clean the tags
397+
return None, cleaned_content
398+
399+
# Pattern 2: Look for step-by-step reasoning at start
400+
lines = response_text.split("\n")
401+
reasoning_lines = []
402+
content_lines = []
403+
found_separator = False
404+
405+
for i, line in enumerate(lines):
406+
stripped_line = line.strip()
407+
# Look for reasoning indicators at the beginning of lines (more precise)
408+
if (
409+
stripped_line.startswith(("Step ", "First,", "Let me think", "I need to solve", "To solve"))
410+
or stripped_line.startswith(("## Reasoning", "## Thinking", "## My reasoning"))
411+
or (
412+
len(stripped_line) > 0
413+
and stripped_line.endswith(":")
414+
and ("reasoning" in stripped_line.lower() or "thinking" in stripped_line.lower())
415+
)
416+
):
417+
# Look for a clear separator to determine where reasoning ends
418+
reasoning_end = len(lines) # Default to end of text
419+
for j in range(i + 1, len(lines)):
420+
next_line = lines[j].strip()
421+
if next_line.startswith(
422+
("Based on", "Therefore", "In conclusion", "So,", "Thus,", "## Solution", "## Answer")
423+
):
424+
reasoning_end = j
425+
break
426+
427+
reasoning_lines = lines[:reasoning_end]
428+
content_lines = lines[reasoning_end:]
429+
found_separator = True
430+
break
431+
# Stop looking after first few lines
432+
max_lines_to_check = 10
433+
if i > max_lines_to_check:
434+
break
435+
436+
if found_separator and reasoning_lines:
437+
reasoning_text = "\n".join(reasoning_lines).strip()
438+
cleaned_content = "\n".join(content_lines).strip()
439+
min_reasoning_length = 30
440+
if len(reasoning_text) > min_reasoning_length: # Minimum threshold
441+
return ReasoningContent(reasoning_text=reasoning_text), cleaned_content
442+
443+
# No reasoning detected
444+
return None, response_text
445+
446+
447+
def _convert_streaming_chunks_to_chat_message_with_reasoning(chunks: List[StreamingChunk]) -> ChatMessage:
    """
    Convert streaming chunks to ChatMessage with reasoning extraction support.

    Wraps the core ``_convert_streaming_chunks_to_chat_message`` utility and, when the
    assembled text contains Cohere reasoning markers, rebuilds the message so the
    reasoning travels separately from the cleaned answer text.
    """
    assembled = _convert_streaming_chunks_to_chat_message(chunks=chunks)

    # Nothing to inspect when the assembled message carries no text.
    if not assembled.text:
        return assembled

    reasoning, answer_text = _extract_reasoning_from_response(assembled.text)
    if reasoning is None:
        # No reasoning markers found; the assembled message is already correct.
        return assembled

    # Rebuild the message with the reasoning split out, preserving tool calls and meta.
    return ChatMessage.from_assistant(
        text=answer_text,
        reasoning=reasoning,
        tool_calls=assembled.tool_calls,
        meta=assembled.meta,
    )
480+
481+
353482
def _parse_streaming_response(
354483
response: Iterator[StreamedChatResponseV2],
355484
model: str,
@@ -381,7 +510,7 @@ def _parse_streaming_response(
381510
chunks.append(streaming_chunk)
382511
streaming_callback(streaming_chunk)
383512

384-
return _convert_streaming_chunks_to_chat_message(chunks=chunks)
513+
return _convert_streaming_chunks_to_chat_message_with_reasoning(chunks=chunks)
385514

386515

387516
async def _parse_async_streaming_response(
@@ -409,7 +538,7 @@ async def _parse_async_streaming_response(
409538
chunks.append(streaming_chunk)
410539
await streaming_callback(streaming_chunk)
411540

412-
return _convert_streaming_chunks_to_chat_message(chunks=chunks)
541+
return _convert_streaming_chunks_to_chat_message_with_reasoning(chunks=chunks)
413542

414543

415544
@component

0 commit comments

Comments
 (0)