"""Common types for the project."""

import ast
import json
import logging
import re
from typing import Any, Optional

from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
from llama_stack_client.lib.agents.tool_parser import ToolParser
from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.shared.tool_call import ToolCall
from llama_stack_client.types.tool_execution_step import ToolExecutionStep
from pydantic import BaseModel

from constants import DEFAULT_RAG_TOOL
from models.responses import RAGChunk

logger = logging.getLogger(__name__)

# -- RAG response format patterns -------------------------------------------
#
# llama-stack's knowledge_search tool (source:
# llama_stack/providers/inline/tool_runtime/rag/memory.py) renders results as:
#
#   knowledge_search tool found N chunks:
#   BEGIN of knowledge_search tool results.      <- hardcoded header
#   Result {index}
#   Content: {chunk.content}
#   Metadata: {metadata}                         <- per-chunk template
#   ...
#   END of knowledge_search tool results.        <- hardcoded footer
#
# NOTE: the per-chunk template is configurable via RAGQueryConfig.chunk_template
# in llama-stack, so a customized deployment may emit text these patterns do
# not match.  The parsing code logs a warning and falls back to single-chunk
# extraction in that case.

# Matches one "Result N / Content: ..." block, capturing the result number and
# the content up to the next result block or the end marker (lookahead keeps
# the delimiter available for the following match).
RAG_RESULT_PATTERN = re.compile(
    r"\s*Result\s+(\d+)\s*\nContent:\s*(.*?)(?=\s*Result\s+\d+\s*\n|END of knowledge_search)",
    re.DOTALL,
)

# Extracts the Python-dict-repr metadata blob from a result block.
RAG_METADATA_PATTERN = re.compile(r"Metadata:\s*(\{[^}]+\})", re.DOTALL)
+# Captures result number and everything until the next result or end marker +RAG_RESULT_PATTERN = re.compile( + r"\s*Result\s+(\d+)\s*\nContent:\s*(.*?)(?=\s*Result\s+\d+\s*\n|END of knowledge_search)", + re.DOTALL, +) + +# Pattern to extract metadata dict from a result block +RAG_METADATA_PATTERN = re.compile(r"Metadata:\s*(\{[^}]+\})", re.DOTALL) class Singleton(type): @@ -117,47 +147,135 @@ def append_tool_calls_from_llama(self, tec: ToolExecutionStep) -> None: self._extract_rag_chunks_from_response(response_content) def _extract_rag_chunks_from_response(self, response_content: str) -> None: - """Extract RAG chunks from tool response content.""" + """Extract RAG chunks from tool response content. + + Parses RAG tool responses in multiple formats: + 1. JSON format with "chunks" array or list of chunk objects + 2. Formatted text with "Result N" blocks containing Content and Metadata + + For formatted text responses, extracts: + - Content text for each result + - Metadata including docs_url, title, chunk_id, document_id + """ + if not response_content or not response_content.strip(): + return + + # Try JSON format first + if self._try_parse_json_chunks(response_content): + return + + # Try formatted text with "Result N" blocks + if self._try_parse_formatted_chunks(response_content): + return + + # Fallback: treat entire response as single chunk + # This may indicate the RAG response format has changed + logger.warning( + "Unable to parse individual RAG chunks from response. " + "Falling back to single-chunk extraction. " + "This may indicate a change in the RAG tool response format. " + "Response preview: %.200s...", + response_content[:200] if len(response_content) > 200 else response_content, + ) + self.rag_chunks.append( + RAGChunk( + content=response_content, + source=DEFAULT_RAG_TOOL, + score=None, + ) + ) + + def _try_parse_json_chunks(self, response_content: str) -> bool: + """Try to parse response as JSON chunks. 
+ + Returns True if successfully parsed, False otherwise. + """ try: - # Parse the response to get chunks - # Try JSON first - try: - data = json.loads(response_content) - if isinstance(data, dict) and "chunks" in data: - for chunk in data["chunks"]: + data = json.loads(response_content) + if isinstance(data, dict) and "chunks" in data: + for chunk in data["chunks"]: + self.rag_chunks.append( + RAGChunk( + content=chunk.get("content", ""), + source=chunk.get("source"), + score=chunk.get("score"), + ) + ) + return True + if isinstance(data, list): + for chunk in data: + if isinstance(chunk, dict): self.rag_chunks.append( RAGChunk( - content=chunk.get("content", ""), + content=chunk.get("content", str(chunk)), source=chunk.get("source"), score=chunk.get("score"), ) ) - elif isinstance(data, list): - # Handle list of chunks - for chunk in data: - if isinstance(chunk, dict): - self.rag_chunks.append( - RAGChunk( - content=chunk.get("content", str(chunk)), - source=chunk.get("source"), - score=chunk.get("score"), - ) - ) - except json.JSONDecodeError: - # If not JSON, treat the entire response as a single chunk - if response_content.strip(): - self.rag_chunks.append( - RAGChunk( - content=response_content, - source=DEFAULT_RAG_TOOL, - score=None, - ) - ) - except (KeyError, AttributeError, TypeError, ValueError): - # Treat response as single chunk on data access/structure errors - if response_content.strip(): - self.rag_chunks.append( - RAGChunk( - content=response_content, source=DEFAULT_RAG_TOOL, score=None - ) - ) + return bool(data) + except (json.JSONDecodeError, KeyError, AttributeError, TypeError, ValueError): + pass + return False + + def _try_parse_formatted_chunks(self, response_content: str) -> bool: + """Try to parse formatted text response with 'Result N' blocks. + + Parses responses in format: + knowledge_search tool found N chunks: + BEGIN of knowledge_search tool results. 
+ Result 1 + Content: + Metadata: {'chunk_id': '...', 'docs_url': '...', 'title': '...', ...} + Result 2 + ... + END of knowledge_search tool results. + + Returns True if at least one chunk was parsed, False otherwise. + """ + # Check if this looks like a formatted RAG response + if "Result" not in response_content or "Content:" not in response_content: + return False + + matches = RAG_RESULT_PATTERN.findall(response_content) + if not matches: + return False + + for _result_num, content_block in matches: + chunk = self._parse_single_chunk(content_block) + if chunk: + self.rag_chunks.append(chunk) + + return bool(self.rag_chunks) + + def _parse_single_chunk(self, content_block: str) -> RAGChunk | None: + """Parse a single chunk from a content block. + + Args: + content_block: Text containing content and optionally metadata + + Returns: + RAGChunk if successfully parsed, None otherwise + """ + # Extract metadata if present + metadata: dict[str, Any] = {} + metadata_match = RAG_METADATA_PATTERN.search(content_block) + if metadata_match: + try: + metadata = ast.literal_eval(metadata_match.group(1)) + except (ValueError, SyntaxError) as e: + logger.debug("Failed to parse chunk metadata: %s", e) + + # Extract content (everything before "Metadata:" if present) + if metadata_match: + content = content_block[: metadata_match.start()].strip() + else: + content = content_block.strip() + + if not content: + return None + + return RAGChunk( + content=content, + source=metadata.get("docs_url") or metadata.get("source"), + score=metadata.get("score"), + ) diff --git a/tests/unit/utils/test_types.py b/tests/unit/utils/test_types.py index a2429baa..dbf468e4 100644 --- a/tests/unit/utils/test_types.py +++ b/tests/unit/utils/test_types.py @@ -1,8 +1,13 @@ """Unit tests for functions defined in utils/types.py.""" +import json +import logging + +import pytest from pytest_mock import MockerFixture -from utils.types import GraniteToolParser +from constants import DEFAULT_RAG_TOOL +from 
"""Unit tests for functions defined in utils/types.py."""

import json
import logging

import pytest
from pytest_mock import MockerFixture

from constants import DEFAULT_RAG_TOOL
from utils.types import GraniteToolParser, TurnSummary


class TestTurnSummaryExtractRagChunks:
    """Unit tests for TurnSummary._extract_rag_chunks_from_response."""

    # pylint: disable=protected-access

    def _create_turn_summary(self) -> TurnSummary:
        """Create a TurnSummary instance for testing."""
        return TurnSummary(llm_response="test response", tool_calls=[])

    def test_empty_response(self) -> None:
        """Test that empty response produces no chunks."""
        summary = self._create_turn_summary()
        summary._extract_rag_chunks_from_response("")
        assert len(summary.rag_chunks) == 0

    def test_whitespace_only_response(self) -> None:
        """Test that whitespace-only response produces no chunks."""
        summary = self._create_turn_summary()
        summary._extract_rag_chunks_from_response("   \n\t  ")
        assert len(summary.rag_chunks) == 0

    def test_json_dict_with_chunks(self) -> None:
        """Test parsing JSON dict with chunks array."""
        summary = self._create_turn_summary()
        response = json.dumps(
            {
                "chunks": [
                    {"content": "Chunk 1 content", "source": "doc1.md", "score": 0.95},
                    {"content": "Chunk 2 content", "source": "doc2.md", "score": 0.85},
                ]
            }
        )
        summary._extract_rag_chunks_from_response(response)

        assert len(summary.rag_chunks) == 2
        assert summary.rag_chunks[0].content == "Chunk 1 content"
        assert summary.rag_chunks[0].source == "doc1.md"
        assert summary.rag_chunks[0].score == 0.95
        assert summary.rag_chunks[1].content == "Chunk 2 content"
        assert summary.rag_chunks[1].source == "doc2.md"
        assert summary.rag_chunks[1].score == 0.85

    def test_json_list_of_chunks(self) -> None:
        """Test parsing JSON list of chunk objects."""
        summary = self._create_turn_summary()
        response = json.dumps(
            [
                {"content": "First chunk", "source": "source1"},
                {"content": "Second chunk", "source": "source2"},
            ]
        )
        summary._extract_rag_chunks_from_response(response)

        assert len(summary.rag_chunks) == 2
        assert summary.rag_chunks[0].content == "First chunk"
        assert summary.rag_chunks[1].content == "Second chunk"

    def test_formatted_text_single_result(self) -> None:
        """Test parsing formatted text with single Result block."""
        summary = self._create_turn_summary()
        response = """knowledge_search tool found 1 chunks:
BEGIN of knowledge_search tool results.
 Result 1
Content: This is the content of the first chunk.
Metadata: {'chunk_id': 'abc123', 'document_id': 'doc1', 'docs_url': 'https://example.com/doc1', 'title': 'Example Doc'}
END of knowledge_search tool results.
"""
        summary._extract_rag_chunks_from_response(response)

        assert len(summary.rag_chunks) == 1
        assert (
            summary.rag_chunks[0].content == "This is the content of the first chunk."
        )
        assert summary.rag_chunks[0].source == "https://example.com/doc1"

    def test_formatted_text_multiple_results(self) -> None:
        """Test parsing formatted text with multiple Result blocks."""
        summary = self._create_turn_summary()
        response = """knowledge_search tool found 3 chunks:
BEGIN of knowledge_search tool results.
 Result 1
Content: First chunk content here.
Metadata: {'chunk_id': 'id1', 'docs_url': 'https://docs.example.com/page1', 'title': 'Page 1'}
 Result 2
Content: Second chunk with more text.
Metadata: {'chunk_id': 'id2', 'docs_url': 'https://docs.example.com/page2', 'title': 'Page 2'}
 Result 3
Content: Third and final chunk.
Metadata: {'chunk_id': 'id3', 'source': 'https://docs.example.com/page3', 'title': 'Page 3'}
END of knowledge_search tool results.
"""
        summary._extract_rag_chunks_from_response(response)

        assert len(summary.rag_chunks) == 3
        assert summary.rag_chunks[0].content == "First chunk content here."
        assert summary.rag_chunks[0].source == "https://docs.example.com/page1"
        assert summary.rag_chunks[1].content == "Second chunk with more text."
        assert summary.rag_chunks[1].source == "https://docs.example.com/page2"
        assert summary.rag_chunks[2].content == "Third and final chunk."
        # Falls back to 'source' when 'docs_url' is not present
        assert summary.rag_chunks[2].source == "https://docs.example.com/page3"

    def test_formatted_text_multiline_content(self) -> None:
        """Test parsing formatted text with multiline content."""
        summary = self._create_turn_summary()
        response = """knowledge_search tool found 1 chunks:
BEGIN of knowledge_search tool results.
 Result 1
Content: # Heading

This is a paragraph with multiple lines.

* Bullet point 1
* Bullet point 2

More text here.
Metadata: {'chunk_id': 'multi1', 'docs_url': 'https://example.com/multiline'}
END of knowledge_search tool results.
"""
        summary._extract_rag_chunks_from_response(response)

        assert len(summary.rag_chunks) == 1
        assert "# Heading" in summary.rag_chunks[0].content
        assert "* Bullet point 1" in summary.rag_chunks[0].content
        assert "More text here." in summary.rag_chunks[0].content

    def test_formatted_text_without_metadata(self) -> None:
        """Test parsing formatted text when metadata parsing fails."""
        summary = self._create_turn_summary()
        response = """knowledge_search tool found 1 chunks:
BEGIN of knowledge_search tool results.
 Result 1
Content: Content without valid metadata.
Metadata: {invalid json here}
END of knowledge_search tool results.
"""
        summary._extract_rag_chunks_from_response(response)

        assert len(summary.rag_chunks) == 1
        assert summary.rag_chunks[0].content == "Content without valid metadata."
        assert summary.rag_chunks[0].source is None

    def test_fallback_to_single_chunk(self, caplog: pytest.LogCaptureFixture) -> None:
        """Test fallback to treating response as single chunk with warning log."""
        summary = self._create_turn_summary()
        response = "This is just plain text without any special formatting."

        with caplog.at_level(logging.WARNING):
            summary._extract_rag_chunks_from_response(response)

        assert len(summary.rag_chunks) == 1
        assert summary.rag_chunks[0].content == response
        assert summary.rag_chunks[0].source == DEFAULT_RAG_TOOL
        assert summary.rag_chunks[0].score is None

        # Verify the warning was logged.  Filter to the specific record
        # instead of asserting on the total record count, which is brittle:
        # any other library logging at warning level during the call would
        # break the test for unrelated reasons.
        fallback_records = [
            r
            for r in caplog.records
            if "Unable to parse individual RAG chunks" in r.message
        ]
        assert len(fallback_records) == 1
        assert "Falling back to single-chunk extraction" in fallback_records[0].message

    def test_real_world_response_format(self) -> None:
        """Test with real-world formatted response from knowledge_search."""
        summary = self._create_turn_summary()
        response = """knowledge_search tool found 2 chunks:
BEGIN of knowledge_search tool results.
 Result 1
Content: # JobSet Operator overview

Use the JobSet Operator on Red Hat OpenShift Container Platform to easily manage and run large-scale, coordinated workloads.

[IMPORTANT]
----
JobSet Operator is a Technology Preview feature only.
----
Metadata: {'chunk_id': '901a76d0-dc86-438b-91a3-bfac880a0c17', 'document_id': '8a84b126-46ae-454d-a752-c21ea121cb0d', 'source': 'https://docs.openshift.com/container-platform//4.19', 'docs_url': 'https://docs.openshift.com/container-platform//4.19', 'title': 'JobSet Operator overview', 'url_reachable': False}
 Result 2
Content: The JobSet Operator automatically sets up stable headless service.
Metadata: {'chunk_id': '1240732d-33b5-4900-baeb-d63306c97080', 'document_id': '8a84b126-46ae-454d-a752-c21ea121cb0d', 'source': 'https://docs.openshift.com/container-platform//4.19', 'docs_url': 'https://docs.openshift.com/container-platform//4.19', 'title': 'JobSet Operator overview', 'url_reachable': False}
END of knowledge_search tool results.
 The above results were retrieved to help answer the user's query.
"""
        summary._extract_rag_chunks_from_response(response)

        assert len(summary.rag_chunks) == 2
        assert "# JobSet Operator overview" in summary.rag_chunks[0].content
        assert "Technology Preview" in summary.rag_chunks[0].content
        assert (
            summary.rag_chunks[0].source
            == "https://docs.openshift.com/container-platform//4.19"
        )
        assert "stable headless service" in summary.rag_chunks[1].content
        assert (
            summary.rag_chunks[1].source
            == "https://docs.openshift.com/container-platform//4.19"
        )

    def test_json_with_optional_fields(self) -> None:
        """Test parsing JSON chunks with missing optional fields."""
        summary = self._create_turn_summary()
        response = json.dumps(
            {"chunks": [{"content": "Content only, no source or score"}]}
        )
        summary._extract_rag_chunks_from_response(response)

        assert len(summary.rag_chunks) == 1
        assert summary.rag_chunks[0].content == "Content only, no source or score"
        assert summary.rag_chunks[0].source is None
        assert summary.rag_chunks[0].score is None