From 707d7fbca560af8b98aef519cc54a3dec7d690ef Mon Sep 17 00:00:00 2001 From: Senko Rasic Date: Fri, 20 Dec 2024 17:43:09 +0100 Subject: [PATCH 1/9] use gemini2.0 flash for llm integration tests --- tests/integration/test_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_llm.py b/tests/integration/test_llm.py index cb1db9b..71a92cf 100644 --- a/tests/integration/test_llm.py +++ b/tests/integration/test_llm.py @@ -39,7 +39,7 @@ def model_urls() -> list[str]: if getenv("ANTHROPIC_API_KEY"): retval.append("anthropic:///claude-3-haiku-20240307") if getenv("GEMINI_API_KEY"): - retval.append("google:///gemini-1.5-pro-latest") + retval.append("google:///gemini-2.0-flash-exp") if getenv("GROQ_API_KEY"): retval.append("groq:///llama-3.2-90b-vision-preview") if getenv("OLLAMA_MODEL"): From d5ce2030d87559536e6c5138632a9ac78f14713f Mon Sep 17 00:00:00 2001 From: Senko Rasic Date: Fri, 10 Jan 2025 12:59:52 +0100 Subject: [PATCH 2/9] Add support for RAG using txtai, chromadb and pinecone Work in progress, no tests or docs yet, API might change. --- pyproject.toml | 15 +++- think/rag/base.py | 146 ++++++++++++++++++++++++++++++++++++++ think/rag/chroma_rag.py | 75 ++++++++++++++++++++ think/rag/pinecone_rag.py | 139 ++++++++++++++++++++++++++++++++++++ think/rag/rag.py | 0 think/rag/txtai_rag.py | 68 ++++++++++++++++++ 6 files changed, 442 insertions(+), 1 deletion(-) create mode 100644 think/rag/base.py create mode 100644 think/rag/chroma_rag.py create mode 100644 think/rag/pinecone_rag.py create mode 100644 think/rag/rag.py create mode 100644 think/rag/txtai_rag.py diff --git a/pyproject.toml b/pyproject.toml index f7af16b..4dfa9b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ authors = [ license = { text = "MIT" } homepage = "https://github.com/senko/think" repository = "https://github.com/senko/think" -keywords = ["ai", "llm"] +keywords = ["ai", "llm", "rag"] [project.optional-dependencies] openai = [ @@ -36,6 +36,15 @@ ollama = [ bedrock = [ "aioboto3>=13.2.0", ] +txtai = [ + "txtai>=8.1.0", +] +chromadb = [ + "chromadb>=0.6.2", +] +pinecone = [ + "pinecone>=5.4.2", +] all = [ "openai>=1.53.0", @@ -43,6 +52,10 @@ all = [ "google-generativeai>=0.8.3", "groq>=0.12.0", "ollama>=0.3.3", + "txtai>=8.1.0", + "chromadb>=0.6.2", + "pinecone>=5.4.2", + "pinecone-client>=4.1.2", ] [dependency-groups] diff --git a/think/rag/base.py b/think/rag/base.py new file mode 100644 index 0000000..d183655 --- /dev/null +++ b/think/rag/base.py @@ -0,0 +1,146 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TypedDict, TypeVar + +from ..ai import ask +from ..llm.base import LLM + +PreparedQueryT = TypeVar("PreparedQueryT") + + +class RagDocument(TypedDict): + id: str + text: str + + +@dataclass +class RagResult: + doc: RagDocument + score: float + + +BASE_ANSWER_PROMPT = """Based ONLY on the provided context: + +{% for item in results %} +{{ item.doc.text }} +Score: {{ item.score | round(3) }} +{% if not loop.last %} +--- +{% endif %} +{% endfor %} + +Answer the question: + +{{query}} + +(Note: don't say "based on provided context" in the output, it's confusing for the reader.) +""" + + +class RAG(ABC): + PROVIDERS = ["txtai", "chroma", "pinecone"] + QUERY_PROMPT: str | None = None + ANSWER_PROMPT: str = BASE_ANSWER_PROMPT + + def __init__( + self, + llm: LLM, + ): + self.llm = llm + + @abstractmethod + async def add_documents(self, documents: list[RagDocument]): + """ + Add documents to the RAG index. 
+ + :param documents: Documents to add. + """ + + @abstractmethod + async def remove_documents(self, ids: list[str]): + """ + Remove documents from the RAG index. + + :param ids: Document IDs to remove. + """ + + async def prepare_query(self, query: str) -> PreparedQueryT: + """ + Process user input into query suitable for semantic search. + + :param query: User input. + :return: Query suitable for semantic search. + """ + if self.QUERY_PROMPT is None: + return query + return await ask(self.llm, self.QUERY_PROMPT, query=query) + + @abstractmethod + async def fetch_results( + self, + user_query: str, + prepared_query: PreparedQueryT, + limit: int, + ) -> list[RagResult]: + """ + Use the provided and processed query to search for relevant context. + + :param user_query: Unprocessed user input. + :param prepared_query: Processed user input. + :param limit: Maximum number of search results to return. + :return: Search results to include in the context. + """ + + async def get_answer(self, query: str, results: list[RagResult]) -> str: + """ + Ask the LLM to provide the answer based on the retrieved context + and the user's original query. + + :param query: User input. + :param results: Search results. + :return: Answer to the user query. + """ + return await ask(self.llm, self.ANSWER_PROMPT, results=results, query=query) + + async def __call__(self, query: str, limit: int = 10) -> str: + prepared_query = await self.prepare_query(query) + results = await self.fetch_results(query, prepared_query, limit) + return await self.get_answer(query, results) + + @abstractmethod + async def count(self) -> int: + """ + Get the number of documents in the RAG index. + + :return: The number of documents in the RAG index. + """ + + @classmethod + def for_provider(cls, provider: str) -> type["RAG"]: + """ + Get the RAG class for the specified provider/engine. + + :param provider: The provider name + :return The RAG class for the provider + + Raises a ValueError if the provider is not supported. + The list of supported providers is available in the + PROVIDERS class attribute. 
+ """ + if provider == "txtai": + from .txtai_rag import TxtAiRag + + return TxtAiRag + + elif provider == "chroma": + from .chroma_rag import ChromaRag + + return ChromaRag + + elif provider == "pinecone": + from .pinecone_rag import PineconeRag + + return PineconeRag + + else: + raise ValueError(f"Unknown provider: {provider}") diff --git a/think/rag/chroma_rag.py b/think/rag/chroma_rag.py new file mode 100644 index 0000000..9d59881 --- /dev/null +++ b/think/rag/chroma_rag.py @@ -0,0 +1,75 @@ +from pathlib import Path + +from ..llm.base import LLM +from .base import RAG, RagDocument, RagResult + +try: + import chromadb +except ImportError as err: + raise ImportError( + "ChromaDB embeddings require the chromadb library: pip install chromadb" + ) from err + + +class ChromaRag(RAG): + def __init__( + self, + llm: LLM, + *, + collection: str, + path: Path | str | None = None, + ): + super().__init__(llm) + self.collection_name = collection + self.path = None if path is None else Path(path) + + if self.path: + self.path.mkdir(parents=True, exist_ok=True) + self.client = chromadb.PersistentClient(path=str(self.path)) + else: + self.client = chromadb.Client() + + self.collection = self.client.get_or_create_collection( + name=self.collection_name + ) + + async def add_documents(self, documents: list[RagDocument]): + # Extract document data + ids = [doc["id"] for doc in documents] + texts = [doc["text"] for doc in documents] + + self.collection.add( + documents=texts, + ids=ids, + ) + + async def remove_documents(self, ids: list[str]): + self.collection.delete(ids=ids) + + async def fetch_results( + self, user_query: str, prepared_query: str, limit: int + ) -> list[RagResult]: + results = self.collection.query( + query_texts=[prepared_query], + n_results=limit, + ) + + documents = [] + for doc_id, text, distance in zip( + results["ids"][0], + results["documents"][0], + results["distances"][0], + ): + score = 1.0 - distance + + documents.append( + RagResult( + doc={"id": doc_id, "text": text}, + score=score, + ), + ) + + return documents + + async def count(self) -> int: + return self.collection.count() diff --git a/think/rag/pinecone_rag.py b/think/rag/pinecone_rag.py new file mode 100644 index 0000000..ac9e5d6 --- /dev/null +++ b/think/rag/pinecone_rag.py @@ -0,0 +1,139 @@ +from os import getenv +from typing import List, Optional + +from ..llm.base import LLM +from .base import RAG, RagDocument, RagResult + +try: + from pinecone.grpc import PineconeGRPC as Pinecone +except ImportError as err: + raise ImportError( + "Pinecone requires the pinecone-client library: pip install pinecone-client" + ) from err + + +class PineconeRag(RAG): + DEFAULT_EMBEDDING_MODEL = "multilingual-e5-large" + DEFAULT_EMBED_BATCH_SIZE = 96 + + def __init__( + self, + llm: LLM, + *, + index_name: str, + api_key: Optional[str] = None, + embedding_model: str = DEFAULT_EMBEDDING_MODEL, + embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, + ): + super().__init__(llm) + self.index_name = index_name + self.embedding_model = embedding_model + self.embed_batch_size = embed_batch_size + + # Initialize Pinecone client + self.api_key = api_key or getenv("PINECONE_API_KEY") + if not self.api_key: + raise ValueError( + "Pinecone API key must be provided either through constructor " + "or PINECONE_API_KEY environment variable" + ) + + try: + self.client = Pinecone(api_key=self.api_key) + self.index = self.client.Index(self.index_name) + except Exception as e: + raise RuntimeError(f"Failed to initialize Pinecone client: {e}") from e + + 
async def _embed_texts( + self, texts: List[str], is_query: bool = False + ) -> List[dict]: + """Helper method to embed texts using Pinecone's inference service.""" + try: + input_type = "query" if is_query else "passage" + + batches = [ + texts[i : i + self.embed_batch_size] + for i in range(0, len(texts), self.embed_batch_size) + ] + + embeddings = [] + for batch in batches: + # TODO: how to handle errors here? + batch_result = self.client.inference.embed( + model=self.embedding_model, + inputs=batch, + parameters={"input_type": input_type, "truncate": "END"}, + ) + embeddings.extend(batch_result) + + return embeddings + except Exception as e: + raise RuntimeError(f"Failed to generate embeddings: {e}") from e + + async def add_documents(self, documents: list[RagDocument]): + try: + # Generate embeddings for all documents + texts = [doc["text"] for doc in documents] + embeddings = await self._embed_texts(texts) + + # Prepare records for Pinecone + records = [] + for doc, embedding in zip(documents, embeddings): + records.append( + { + "id": doc["id"], + "values": embedding["values"], + "metadata": {"text": doc["text"]}, + } + ) + + # Upsert to Pinecone index + self.index.upsert(vectors=records) + except Exception as e: + raise RuntimeError(f"Failed to add documents to Pinecone: {e}") from e + + async def remove_documents(self, ids: list[str]): + try: + self.index.delete(ids=ids) + except Exception as e: + raise RuntimeError(f"Failed to remove documents from Pinecone: {e}") from e + + async def prepare_query(self, query: str) -> List[float]: + prepared_query: str = await super().prepare_query(query) + try: + embedding = await self._embed_texts([prepared_query], is_query=True) + return embedding[0]["values"] + except Exception as e: + raise RuntimeError(f"Failed to prepare query: {e}") from e + + async def fetch_results( + self, + user_query: str, + prepared_query: List[float], + limit: int, + ) -> list[RagResult]: + try: + response = self.index.query( + vector=prepared_query, + top_k=limit, + include_metadata=True, + ) + + results = [] + for match in response["matches"]: + results.append( + RagResult( + doc={"id": match["id"], "text": match["metadata"]["text"]}, + score=match["score"], + ) + ) + + return results + except Exception as e: + raise RuntimeError(f"Failed to fetch results from Pinecone: {e}") from e + + async def count(self) -> int: + try: + return self.index.describe_index_stats()["total_vector_count"] + except Exception as e: + raise RuntimeError(f"Failed to count records in Pinecone: {e}") from e diff --git a/think/rag/rag.py b/think/rag/rag.py new file mode 100644 index 0000000..e69de29 diff --git a/think/rag/txtai_rag.py b/think/rag/txtai_rag.py new file mode 100644 index 0000000..f2bbcc7 --- /dev/null +++ b/think/rag/txtai_rag.py @@ -0,0 +1,68 @@ +from pathlib import Path + +from ..llm.base import LLM +from .base import RAG, RagDocument, RagResult + +try: + from txtai import Embeddings +except ImportError as err: + raise ImportError( + "Txtai embeddings require the txtai library: pip install txtai" + ) from err + + +class TxtAiRag(RAG): + DEFAULT_EMBEDDINGS_MODEL = "sentence-transformers/all-MiniLM-L6-v2" + + def __init__( + self, + llm: LLM, + *, + model: str = DEFAULT_EMBEDDINGS_MODEL, + path: Path | str | None = None, + ): + super().__init__(llm) + self.model = model + self.path = None if path is None else Path(path) + + if self.path: + self.path.mkdir(parents=True, exist_ok=True) + + self.embeddings = Embeddings( + { + "path": self.model, + "content": True, + } + ) + 
if self.path: + if (self.path / "embeddings").exists(): + print("Loading from", self.path) + self.embeddings.load(str(self.path)) + else: + self.embeddings.save(str(self.path)) + + async def add_documents(self, documents: list[RagDocument]): + data = [(doc["id"], doc["text"]) for doc in documents] + self.embeddings.upsert(data) + if self.path: + self.embeddings.save(str(self.path)) + + async def remove_documents(self, ids: list[str]): + self.embeddings.delete(ids) + if self.path: + self.embeddings.save(str(self.path)) + + async def count(self) -> int: + return self.embeddings.count() + + async def fetch_results( + self, user_query: str, prepared_query: str, limit: int + ) -> list[RagResult]: + results = self.embeddings.search(prepared_query, limit=limit) + return [ + RagResult( + doc={"id": result["id"], "text": result["text"]}, + score=result["score"], + ) + for result in results + ] From e9aab406ee9b7d011e5e08fae6a264ddd864c9f1 Mon Sep 17 00:00:00 2001 From: Senko Rasic Date: Mon, 10 Feb 2025 15:37:54 +0100 Subject: [PATCH 3/9] add error handling to the bedrock client --- tests/integration/test_llm.py | 10 ++++-- think/llm/base.py | 8 ++++- think/llm/bedrock.py | 57 +++++++++++++++++++++++++++++++++-- 3 files changed, 68 insertions(+), 7 deletions(-) diff --git a/tests/integration/test_llm.py b/tests/integration/test_llm.py index 71a92cf..395dd34 100644 --- a/tests/integration/test_llm.py +++ b/tests/integration/test_llm.py @@ -1,5 +1,5 @@ import logging -from os import getenv +from os import environ, getenv import pytest from dotenv import load_dotenv @@ -168,9 +168,13 @@ async def test_custom_parser(url): @pytest.mark.parametrize("url", api_model_urls()) @pytest.mark.asyncio -async def test_auth_error(url): +async def test_auth_error(url, monkeypatch): + for key in environ.keys(): + if key.endswith("_API_KEY") or key.startswith("AWS_"): + monkeypatch.setenv(key, "") + c = Chat("You're a friendly assistant").user("Tell me a joke") - invalid_key_url = url.replace("///", "//testing-incorrect-key@/") + invalid_key_url = url.replace("///", "//testing-incorrect-key:abc@/") llm = LLM.from_url(invalid_key_url) with pytest.raises(ConfigError): diff --git a/think/llm/base.py b/think/llm/base.py index 7fef38d..8cbf9a1 100644 --- a/think/llm/base.py +++ b/think/llm/base.py @@ -255,9 +255,15 @@ def from_url(cls, url: str) -> "LLM": ) extra_params = {k: v[0] for k, v in query.items()} + if result.username and result.password: + api_key = f"{result.username}:{result.password}" + elif result.username: + api_key = result.username + else: + api_key = None return cls.for_provider(result.scheme)( model=model, - api_key=result.username, + api_key=api_key, base_url=base_url, **extra_params, ) diff --git a/think/llm/bedrock.py b/think/llm/bedrock.py index c59e9a5..c805df7 100644 --- a/think/llm/bedrock.py +++ b/think/llm/bedrock.py @@ -6,13 +6,19 @@ try: from aioboto3 import Session + from botocore.exceptions import ( + ClientError, + EndpointConnectionError, + NoCredentialsError, + ParamValidationError, + ) except ImportError as err: raise ImportError( "AWS Bedrock client requires the async Boto3 SDK: pip install aioboto3" ) from err -from .base import LLM, BaseAdapter, PydanticResultT +from .base import LLM, BadRequestError, BaseAdapter, ConfigError, PydanticResultT from .chat import Chat, ContentPart, ContentType, Message, Role from .tool import ToolCall, ToolDefinition, ToolResponse @@ -198,10 +204,23 @@ def __init__( **kwargs: str, ): super().__init__(model, api_key=api_key, base_url=base_url) + region 
= kwargs.get("region") if region is None: raise ValueError("AWS Bedrock client requires a region to be specified") - self.session = Session(region_name=region) + + key_id = None + key_secret = None + if api_key: + if ":" not in api_key: + raise ValueError("AWS Bedrock client requires key ID and secret") + else: + key_id, secret = api_key.split(":", 1) + self.session = Session( + region_name=region, + aws_access_key_id=key_id, + aws_secret_access_key=key_secret, + ) async def _internal_call( self, @@ -233,8 +252,24 @@ async def _internal_call( kwargs["toolConfig"] = adapter.spec raw_message = await client.converse(**kwargs) + except (NoCredentialsError, EndpointConnectionError) as err: + raise ConfigError(err.fmt) from err + except ClientError as err: + error = err.response.get("Error", {}) + error_code = error.get("Code") + error_message = error.get("Message") + if error_code in [ + "InvalidSignatureException", + "UnrecognizedClientException", + ]: + raise ConfigError( + error_message or "Unknown client/credentials error" + ) + raise + except ParamValidationError as err: + raise BadRequestError(err.fmt) from err except: - raise # FIXME + raise return adapter.parse_message(raw_message["output"]["message"]) @@ -267,5 +302,21 @@ async def _internal_stream( async for event in stream: if "contentBlockDelta" in event: yield event["contentBlockDelta"]["delta"]["text"] + except (NoCredentialsError, EndpointConnectionError) as err: + raise ConfigError(err.fmt) from err + except ClientError as err: + error = err.response.get("Error", {}) + error_code = error.get("Code") + error_message = error.get("Message") + if error_code in [ + "InvalidSignatureException", + "UnrecognizedClientException", + ]: + raise ConfigError( + error_message or "Unknown client/credentials error" + ) + raise + except ParamValidationError as err: + raise BadRequestError(err.fmt) from err except: raise From 7c844999bdef554ee9011ae9a6e9fc1a8f68b309 Mon Sep 17 00:00:00 2001 From: Senko Rasic Date: Mon, 10 Feb 2025 15:39:04 +0100 Subject: [PATCH 4/9] fix ollama vision/tool model tests + a few minor fixups --- tests/integration/test_llm.py | 11 +++++++---- think/llm/base.py | 5 ----- think/llm/openai.py | 4 +++- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/integration/test_llm.py b/tests/integration/test_llm.py index 395dd34..dcbf9d0 100644 --- a/tests/integration/test_llm.py +++ b/tests/integration/test_llm.py @@ -27,7 +27,7 @@ pytest.skip("Skipping integration tests", allow_module_level=True) -def model_urls() -> list[str]: +def model_urls(vision: bool = False) -> list[str]: """ Returns a list of models to test with, based on available API keys. 
@@ -39,11 +39,14 @@ def model_urls() -> list[str]: if getenv("ANTHROPIC_API_KEY"): retval.append("anthropic:///claude-3-haiku-20240307") if getenv("GEMINI_API_KEY"): - retval.append("google:///gemini-2.0-flash-exp") + retval.append("google:///gemini-2.0-flash-lite-preview-02-05") if getenv("GROQ_API_KEY"): retval.append("groq:///llama-3.2-90b-vision-preview") if getenv("OLLAMA_MODEL"): - retval.append(f"ollama:///{getenv('OLLAMA_MODEL')}") + if vision: + retval.append(f"ollama:///{getenv('OLLAMA_VISION_MODEL')}") + else: + retval.append(f"ollama:///{getenv('OLLAMA_MODEL')}") if getenv("AWS_SECRET_ACCESS_KEY"): retval.append("bedrock:///amazon.nova-lite-v1:0?region=us-east-1") if retval == []: @@ -104,7 +107,7 @@ def get_temperature(city: str) -> str: assert tool_called, "Expected 'get_temperature' tool to be called" -@pytest.mark.parametrize("url", model_urls()) +@pytest.mark.parametrize("url", model_urls(vision=True)) @pytest.mark.asyncio async def test_vision(url): c = Chat("You're a friendly assistant").user( diff --git a/think/llm/base.py b/think/llm/base.py index 8cbf9a1..b866e6a 100644 --- a/think/llm/base.py +++ b/think/llm/base.py @@ -211,11 +211,6 @@ def from_url(cls, url: str) -> "LLM": """ Initialize an LLM client from a URL. - Arguments: - - `url`: The URL to initialize the client from - - Returns the LLM client instance. - :param url: The URL to initialize the client from :return: The LLM client instance diff --git a/think/llm/openai.py b/think/llm/openai.py index a205e2e..a640bc2 100644 --- a/think/llm/openai.py +++ b/think/llm/openai.py @@ -11,6 +11,8 @@ AsyncStream, AuthenticationError, NotFoundError, + ) + from openai import ( BadRequestError as OpenAIBadRequestError, ) from openai.types.chat import ChatCompletionChunk @@ -20,7 +22,7 @@ "OpenAI client requires the OpenAI Python SDK: pip install openai" ) from err -from .base import LLM, BaseAdapter, ConfigError, BadRequestError, PydanticResultT +from .base import LLM, BadRequestError, BaseAdapter, ConfigError, PydanticResultT from .chat import Chat, ContentPart, ContentType, Message, Role from .tool import ToolCall, ToolDefinition, ToolResponse From 35ef07fe43b70dec0eb4c78caca2034a094496fc Mon Sep 17 00:00:00 2001 From: Senko Rasic Date: Wed, 12 Feb 2025 18:29:56 +0100 Subject: [PATCH 5/9] rag: add similarity calculation, reranking support, and a few fixes --- think/rag/base.py | 62 ++++++++++++++++++++++++++++++++++----- think/rag/chroma_rag.py | 31 +++++++++++++++++--- think/rag/pinecone_rag.py | 32 +++++++++++++------- think/rag/txtai_rag.py | 3 ++ 4 files changed, 105 insertions(+), 23 deletions(-) diff --git a/think/rag/base.py b/think/rag/base.py index d183655..d9fd96d 100644 --- a/think/rag/base.py +++ b/think/rag/base.py @@ -1,11 +1,10 @@ +import math from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TypedDict, TypeVar +from typing import TypedDict, Any -from ..ai import ask -from ..llm.base import LLM - -PreparedQueryT = TypeVar("PreparedQueryT") +from think.ai import ask +from think.llm.base import LLM class RagDocument(TypedDict): @@ -45,6 +44,7 @@ class RAG(ABC): def __init__( self, llm: LLM, + **kwargs: Any, ): self.llm = llm @@ -64,7 +64,7 @@ async def remove_documents(self, ids: list[str]): :param ids: Document IDs to remove. """ - async def prepare_query(self, query: str) -> PreparedQueryT: + async def prepare_query(self, query: str) -> str: """ Process user input into query suitable for semantic search. 
@@ -79,7 +79,7 @@ async def prepare_query(self, query: str) -> PreparedQueryT: async def fetch_results( self, user_query: str, - prepared_query: PreparedQueryT, + prepared_query: str, limit: int, ) -> list[RagResult]: """ @@ -102,10 +102,30 @@ async def get_answer(self, query: str, results: list[RagResult]) -> str: """ return await ask(self.llm, self.ANSWER_PROMPT, results=results, query=query) + async def rerank(self, results: list[RagResult]) -> list[RagResult]: + """ + Rerank the search results based on the user query. + + :param results: Search results. + :return: Reranked search results + """ + return results + async def __call__(self, query: str, limit: int = 10) -> str: prepared_query = await self.prepare_query(query) results = await self.fetch_results(query, prepared_query, limit) - return await self.get_answer(query, results) + reranked_results = await self.rerank(results) + return await self.get_answer(query, reranked_results) + + @abstractmethod + async def calculate_similarity(self, query: str, docs: list[str]) -> list[float]: + """ + Calculate the similarity between the query and the text. + + :param query: User input. + :param text: Text to compare. + :return: Similarity score. + """ @abstractmethod async def count(self) -> int: @@ -144,3 +164,29 @@ def for_provider(cls, provider: str) -> type["RAG"]: else: raise ValueError(f"Unknown provider: {provider}") + + @staticmethod + def _cosine_similarity(a: list[float], b: list[float]): + """ + Compute cosine similarity between two vectors of equal length. + + :param a: First vector + :param b: Second vector + :return: Cosine similarity between the two vectors + """ + if len(a) != len(b): + raise ValueError("Vectors must be of equal length") + + # Compute dot product + dot_product = sum(x * y for x, y in zip(a, b)) + + # Compute magnitudes + magnitude1 = math.sqrt(sum(x * x for x in a)) + magnitude2 = math.sqrt(sum(x * x for x in b)) + + # Prevent division by zero + if magnitude1 == 0 or magnitude2 == 0: + return 0.0 + + # Compute cosine similarity + return dot_product / (magnitude1 * magnitude2) diff --git a/think/rag/chroma_rag.py b/think/rag/chroma_rag.py index 9d59881..1142886 100644 --- a/think/rag/chroma_rag.py +++ b/think/rag/chroma_rag.py @@ -54,17 +54,23 @@ async def fetch_results( n_results=limit, ) + ids = results.get("ids") + docs = results.get("documents") + distances = results.get("distances") + if not ids or not docs or not distances: + return [] + documents = [] for doc_id, text, distance in zip( - results["ids"][0], - results["documents"][0], - results["distances"][0], + ids[0], + docs[0], + distances[0], ): score = 1.0 - distance documents.append( RagResult( - doc={"id": doc_id, "text": text}, + doc=RagDocument(id=str(doc_id), text=text), score=score, ), ) @@ -73,3 +79,20 @@ async def fetch_results( async def count(self) -> int: return self.collection.count() + + async def calculate_similarity(self, query: str, docs: list[str]) -> list[float]: + inputs = [query] + docs + assert ( + self.collection._embedding_function + ), "Cannot calculate similarity without an embedding function" + vectors = self.collection._embedding_function(inputs) + query_vector, *doc_vectors = vectors + similarities = [] + for doc_vector in doc_vectors: + similarities.append( + self._cosine_similarity( + query_vector["values"], + doc_vector["values"], + ) + ) + return similarities diff --git a/think/rag/pinecone_rag.py b/think/rag/pinecone_rag.py index ac9e5d6..2a03ee3 100644 --- a/think/rag/pinecone_rag.py +++ b/think/rag/pinecone_rag.py @@ 
-87,7 +87,6 @@ async def add_documents(self, documents: list[RagDocument]): } ) - # Upsert to Pinecone index self.index.upsert(vectors=records) except Exception as e: raise RuntimeError(f"Failed to add documents to Pinecone: {e}") from e @@ -98,23 +97,21 @@ async def remove_documents(self, ids: list[str]): except Exception as e: raise RuntimeError(f"Failed to remove documents from Pinecone: {e}") from e - async def prepare_query(self, query: str) -> List[float]: - prepared_query: str = await super().prepare_query(query) - try: - embedding = await self._embed_texts([prepared_query], is_query=True) - return embedding[0]["values"] - except Exception as e: - raise RuntimeError(f"Failed to prepare query: {e}") from e - async def fetch_results( self, user_query: str, - prepared_query: List[float], + prepared_query: str, limit: int, ) -> list[RagResult]: + try: + embedding = await self._embed_texts([prepared_query], is_query=True) + vector = embedding[0]["values"] + except Exception as e: + raise RuntimeError(f"Failed to prepare query: {e}") from e + try: response = self.index.query( - vector=prepared_query, + vector=vector, top_k=limit, include_metadata=True, ) @@ -137,3 +134,16 @@ async def count(self) -> int: return self.index.describe_index_stats()["total_vector_count"] except Exception as e: raise RuntimeError(f"Failed to count records in Pinecone: {e}") from e + + async def calculate_similarity(self, query: str, docs: list[str]) -> list[float]: + vectors = await self._embed_texts([query] + docs) + query_vector, *doc_vectors = vectors + similarities = [] + for doc_vector in doc_vectors: + similarities.append( + self._cosine_similarity( + query_vector["values"], + doc_vector["values"], + ) + ) + return similarities diff --git a/think/rag/txtai_rag.py b/think/rag/txtai_rag.py index f2bbcc7..5c704bd 100644 --- a/think/rag/txtai_rag.py +++ b/think/rag/txtai_rag.py @@ -66,3 +66,6 @@ async def fetch_results( ) for result in results ] + + async def calculate_similarity(self, query: str, docs: list[str]) -> list[float]: + return self.embeddings.similarity(query, docs) From fa815833fa9ac1a578afa188f1c5a7ba388b98f8 Mon Sep 17 00:00:00 2001 From: Senko Rasic Date: Wed, 12 Feb 2025 18:31:13 +0100 Subject: [PATCH 6/9] add rag eval --- think/rag/eval.py | 270 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 270 insertions(+) create mode 100644 think/rag/eval.py diff --git a/think/rag/eval.py b/think/rag/eval.py new file mode 100644 index 0000000..aeb5082 --- /dev/null +++ b/think/rag/eval.py @@ -0,0 +1,270 @@ +from ..ai import ask +from ..llm.base import LLM +from .base import RAG + + +class RagEval: + CONTEXT_PRECISION_PROMPT = """ + You're tasked with evaluating a knowledge retrieval system. For a user query, you're + given a document retrieved by the system. Based on the document alone, you need to + determine if it's relevant to the query. + + User query: {{ query }} + + Document: {{ result }} + + Answer with "yes" if the document is relevant to the query, or "no" otherwise. + """ + + CLAIM_SPLIT_PROMPT = """ + You're tasked with evaluating a knowledge retrieval system. Given a + {% if is_reference %}ground truth (reference){% else %}system output (answer){% endif %} text, + your task is to split it into individual claims. + + Example: + > Text: "The quick brown fox jumps over the lazy dog." + > Claims: + > The fox is quick and brown. + > The fox jumps over the dog. + > The dog is lazy. + + Note: do not extract trivial claims like "the fox exists". 
+ + Here's the {% if is_reference %}ground truth (reference){% else %}system output (answer){% endif %} text: + {{ text }} + + Please split it into individual claims. Separate each claim with a newline. Do not include + any comments, explanations, or additional information - you must only output the claims themselves. + """ + + CONTEXT_RECALL_PROMPT = """ + You're tasked with evaluating a knowledge retrieval system. For a specific fact or claim, + you're given a set of documents retrieved by the system. Based on the documents alone, + you need to determine if this claim is supported by the documents. + + Claim: {{ claim }} + + Supporting documents: + + {% for result in results %} + * {{ result.doc.text }} + {% endfor %} + + Answer with "yes" if the claim is supported by the documents, or "no" otherwise. + """ + + GENERATE_QUESTIONS_PROMPT = """ + You're tasked with evaluating a knowledge retrieval system. Given a system output (answer), + your task is to generate a set of {{ n_questions }} questions + that the answer is a suitable response for. + + Example: + > Answer: "Paris is the capital of France." + > Questions: + > Which city is the capital of France? + > Paris is the capital of which country? + > ... + + Here is the system output (answer): {{ answer }} + + Please generate {{ n_questions }} questions that this answer is a suitable response for. + + Please output each question on a separate line (ie. separate them by newlines). Do not + include any comments, explanations, or additional information - you must only output the + generated questions. + """ + + def __init__(self, rag: RAG, llm: LLM): + self.rag = rag + self.llm = llm + + async def context_precision( + self, + query: str, + n_results: int = 10, + ) -> float: + """ + Calculate Precision@k and average context precision @k. + + This metric indicates how well the system is able to retrieve relevant documents + for a given query. + + Precision at rank k is the number of relevant documents retrieved in the top k + results divided by k. Context precision (or average precision) at rank k is + the average precision at each rank k. + + More info: https://docs.ragas.io/en/stable/concepts/metrics/available_metrics/context_precision/ + + :param query: The query to evaluate. + :param n_results: The number of results to evaluate. + :return: The average context precision at rank k, in the range [0, 1]. + """ + prepared_query = await self.rag.prepare_query(query) + results = await self.rag.fetch_results(query, prepared_query, n_results) + + n_relevant = 0 + ctx_precision = 0 + + for i, result in enumerate(results): + r = await ask( + self.llm, + self.CONTEXT_PRECISION_PROMPT, + query=query, + result=result.doc["text"], + ) + if "yes" in r.lower(): + n_relevant += 1 + + k = i + 1 # rank + precision_k = n_relevant / k + ctx_precision += precision_k + + return ctx_precision / n_results + + async def split_into_claims( + self, + text: str, + is_reference: bool = False, + ) -> list[str]: + """ + Split the text into individual constituent claims. + + The text can be reference text (ground truth) or system output (answer). 
+ """ + r = await ask( + self.llm, + self.CLAIM_SPLIT_PROMPT, + text=text, + is_reference=is_reference, + ) + + return [line.strip() for line in r.split("\n") if line.strip()] + + async def _supported_by_claims( + self, + query: str, + reference: str | list[str], + is_reference: bool, + n_results: int, + ) -> float: + if isinstance(reference, str): + reference_claims = await self.split_into_claims( + reference, + is_reference=is_reference, + ) + elif isinstance(reference, list) and len(reference) > 0: + reference_claims = reference + else: + raise ValueError( + "Reference must be a string or a non-empty list of strings." + ) + + prepared_query = await self.rag.prepare_query(query) + results = await self.rag.fetch_results(query, prepared_query, n_results) + + n_supported = 0 + + for claim in reference_claims: + r = await ask( + self.llm, + self.CONTEXT_RECALL_PROMPT, + claim=claim, + results=results, + ) + if "yes" in r.lower(): + n_supported += 1 + + return n_supported / len(reference_claims) + + async def context_recall( + self, + query: str, + reference: str | list[str], + n_results: int = 10, + ) -> float: + """ + Calculate Context Recall + + This metric indicates how well the system is able to retrieve documents that + support the ground truth (reference). + + This is estimated by checking if all the ground truth claims (list of reference claims) + is supported by the retrieved documents. + + More info: https://docs.ragas.io/en/stable/concepts/metrics/available_metrics/context_recall/ + + :param query: The query to evaluate. + :param reference: The ground truth claims. + :param n_results: The number of results to use in the context. + :return: The context recall score, in the range [0, 1]. + """ + return await self._supported_by_claims( + query, + reference, + is_reference=True, + n_results=n_results, + ) + + async def faithfulness( + self, + query: str, + answer: str, + n_results: int = 10, + ) -> float: + """ + Calculate answer Faithfulness + + This metric indicates how well the system is able to provide answers that are + supported by the retrieved documents. + + The answer is split into constituent claims, and each claim is checked to see if + it is supported by the retrieved documents. + + More info: https://docs.ragas.io/en/stable/concepts/metrics/available_metrics/faithfulness/ + + :param query: The query to evaluate. + :param answer: The system output / answer. + :param n_results: The number of results to use in the context. + :return: The faithfulness score, in the range [0, 1]. + """ + return await self._supported_by_claims( + query, + answer, + is_reference=False, + n_results=n_results, + ) + + async def answer_relevance( + self, + query: str, + answer: str, + n_questions: int = 3, + ) -> float: + """ + Calculate Answer (Response) Relevance + + This metrics indicates how relevant the answer is to the user query. + + The answer is used to generate a set of artificial questions the answer + is a suitable response for. The questions are embedded and the cosine + similarity between the query and each question is calculated. + + The score is the average similarity across all the generated questions. + + More info: https://docs.ragas.io/en/stable/concepts/metrics/available_metrics/answer_relevance/ + + :param query: The user query. + :param answer: The system output / answer. + :param n_questions: The number of questions to generate. 
+ """ + + r = await ask( + self.llm, + self.GENERATE_QUESTIONS_PROMPT, + answer=answer, + n_questions=n_questions, + ) + questions = [line.strip() for line in r.split("\n") if line.strip()] + + similarities = await self.rag.calculate_similarity(query, questions) + return sum(similarities) / n_questions From a40e1323790d495711b5edbd6771cb2ccce459cc Mon Sep 17 00:00:00 2001 From: Senko Rasic Date: Wed, 12 Feb 2025 18:31:46 +0100 Subject: [PATCH 7/9] add tests for rag & rag eval --- tests/conftest.py | 39 +++++++ tests/integration/test_llm.py | 33 +----- tests/rag/test_eval.py | 103 ++++++++++++++++++ tests/rag/test_rag.py | 194 ++++++++++++++++++++++++++++++++++ 4 files changed, 338 insertions(+), 31 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/rag/test_eval.py create mode 100644 tests/rag/test_rag.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..3dd7819 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,39 @@ +from os import getenv + + +def model_urls(vision: bool = False) -> list[str]: + """ + Returns a list of models to test with, based on available API keys. + + :return: A list of model URLs based on the available API keys. + """ + retval = [] + if getenv("OPENAI_API_KEY"): + retval.append("openai:///gpt-4o-mini") + if getenv("ANTHROPIC_API_KEY"): + retval.append("anthropic:///claude-3-haiku-20240307") + if getenv("GEMINI_API_KEY"): + retval.append("google:///gemini-2.0-flash-lite-preview-02-05") + if getenv("GROQ_API_KEY"): + retval.append("groq:///llama-3.2-90b-vision-preview") + if getenv("OLLAMA_MODEL"): + if vision: + retval.append(f"ollama:///{getenv('OLLAMA_VISION_MODEL')}") + else: + retval.append(f"ollama:///{getenv('OLLAMA_MODEL')}") + if getenv("AWS_SECRET_ACCESS_KEY"): + retval.append("bedrock:///amazon.nova-lite-v1:0?region=us-east-1") + if retval == []: + raise RuntimeError("No LLM API keys found in environment") + return retval + + +def api_model_urls() -> list[str]: + return [url for url in model_urls() if not url.startswith("ollama:")] + + +def first_model_url() -> str: + urls = model_urls() + if urls: + return urls[0] + raise RuntimeError("No LLM API keys found in environment") diff --git a/tests/integration/test_llm.py b/tests/integration/test_llm.py index dcbf9d0..9af1076 100644 --- a/tests/integration/test_llm.py +++ b/tests/integration/test_llm.py @@ -9,6 +9,8 @@ from think.llm.base import BadRequestError, ConfigError from think.llm.chat import Chat +from conftest import api_model_urls, model_urls + load_dotenv() logging.basicConfig(level=logging.DEBUG) @@ -27,37 +29,6 @@ pytest.skip("Skipping integration tests", allow_module_level=True) -def model_urls(vision: bool = False) -> list[str]: - """ - Returns a list of models to test with, based on available API keys. - - :return: A list of model URLs based on the available API keys. 
- """ - retval = [] - if getenv("OPENAI_API_KEY"): - retval.append("openai:///gpt-4o-mini") - if getenv("ANTHROPIC_API_KEY"): - retval.append("anthropic:///claude-3-haiku-20240307") - if getenv("GEMINI_API_KEY"): - retval.append("google:///gemini-2.0-flash-lite-preview-02-05") - if getenv("GROQ_API_KEY"): - retval.append("groq:///llama-3.2-90b-vision-preview") - if getenv("OLLAMA_MODEL"): - if vision: - retval.append(f"ollama:///{getenv('OLLAMA_VISION_MODEL')}") - else: - retval.append(f"ollama:///{getenv('OLLAMA_MODEL')}") - if getenv("AWS_SECRET_ACCESS_KEY"): - retval.append("bedrock:///amazon.nova-lite-v1:0?region=us-east-1") - if retval == []: - raise RuntimeError("No LLM API keys found in environment") - return retval - - -def api_model_urls() -> list[str]: - return [url for url in model_urls() if not url.startswith("ollama:")] - - @pytest.mark.parametrize("url", model_urls()) @pytest.mark.asyncio async def test_basic_request(url): diff --git a/tests/rag/test_eval.py b/tests/rag/test_eval.py new file mode 100644 index 0000000..83dc515 --- /dev/null +++ b/tests/rag/test_eval.py @@ -0,0 +1,103 @@ +import pytest +from unittest.mock import patch, MagicMock, AsyncMock + +from think.rag.base import RagResult, RagDocument +from think.rag.eval import RagEval + + +@pytest.mark.asyncio +@patch("think.rag.eval.ask", new_callable=AsyncMock) +async def test_context_precision(ask): + llm = MagicMock() + rag = AsyncMock() + rag.fetch_results.return_value = [ + RagResult(doc=RagDocument(id="A", text="A"), score=0.5), + RagResult(doc=RagDocument(id="B", text="B"), score=0.5), + RagResult(doc=RagDocument(id="C", text="C"), score=0.5), + RagResult(doc=RagDocument(id="D", text="D"), score=0.5), + ] + ask.side_effect = ["yes", "no", "yes", "no"] + + eval = RagEval(rag, llm) + + query = "A movie about a ship that sinks" + ctx_precision = await eval.context_precision(query, 2) + + assert ctx_precision == pytest.approx(1.33, rel=1e-2) + + +@pytest.mark.asyncio +@patch("think.rag.eval.ask", new_callable=AsyncMock) +async def test_context_recall(ask): + llm = MagicMock() + rag = AsyncMock() + rag.fetch_results.return_value = [ + RagResult(doc=RagDocument(id="A", text="A"), score=0.5), + RagResult(doc=RagDocument(id="B", text="B"), score=0.5), + RagResult(doc=RagDocument(id="C", text="C"), score=0.5), + RagResult(doc=RagDocument(id="D", text="D"), score=0.5), + ] + ask.side_effect = ["yes", "no", "yes", "no"] + + eval = RagEval(rag, llm) + + query = "A movie about a ship that sinks" + reference = [ + "A ship sinks", + "A love story", + "A historical event", + "A tragedy", + ] + ctx_recall = await eval.context_recall(query, reference, 4) + + assert ctx_recall == pytest.approx(0.5, rel=1e-2) + + +@pytest.mark.asyncio +@patch("think.rag.eval.ask", new_callable=AsyncMock) +async def test_faithfulness(ask): + llm = MagicMock() + rag = AsyncMock() + rag.fetch_results.return_value = [ + RagResult(doc=RagDocument(id="A", text="A"), score=0.5), + RagResult(doc=RagDocument(id="B", text="B"), score=0.5), + RagResult(doc=RagDocument(id="C", text="C"), score=0.5), + RagResult(doc=RagDocument(id="D", text="D"), score=0.5), + ] + ask.side_effect = [ + "The Titanic sank after hitting an iceberg.\n" + + "Many people died.\n" + + "It was a tragic event\n", + "yes", + "no", + "yes", + ] + + eval = RagEval(rag, llm) + + query = "What happened to the Titanic?" + answer = "The Titanic sank after hitting an iceberg. Many people died. It was a tragic event." 
+ faithfulness_score = await eval.faithfulness(query, answer, 4) + + assert faithfulness_score == pytest.approx(0.67, rel=1e-2) + + +@pytest.mark.asyncio +@patch("think.rag.eval.ask", new_callable=AsyncMock) +async def test_answer_relevance(ask): + llm = MagicMock() + rag = AsyncMock() + rag.calculate_similarity.return_value = [0.8, 0.75, 0.85] + ask.side_effect = [ + "What is the fate of the Titanic?\n" + + "How did the Titanic sink?\n" + + "What happened to the Titanic?" + ] + + eval = RagEval(rag, llm) + + query = "What happened to the Titanic?" + answer = "The Titanic sank after hitting an iceberg. Many people died. It was a tragic event." + relevance_score = await eval.answer_relevance(query, answer, 3) + + assert relevance_score == pytest.approx(0.8, rel=1e-2) diff --git a/tests/rag/test_rag.py b/tests/rag/test_rag.py new file mode 100644 index 0000000..8da2db8 --- /dev/null +++ b/tests/rag/test_rag.py @@ -0,0 +1,194 @@ +from dotenv import load_dotenv +from os import getenv +import pytest +import time + + +from think import LLM +from think.rag.base import RAG, RagDocument + +from conftest import first_model_url + +load_dotenv() + +if getenv("INTEGRATION_TESTS", "").lower() not in ["true", "yes", "1", "on"]: + pytest.skip("Skipping integration tests", allow_module_level=True) + + +LLM_URL = first_model_url() + +MOVIES = [ + "Titanic (1997): A sweeping romantic epic set against the backdrop of the " + "ill-fated Titanic, following the love story of Jack and Rose as they " + "navigate class divisions and impending disaster.", + "The Godfather (1972): A gripping mafia saga that follows the Corleone " + "crime family, led by patriarch Vito and his reluctant son Michael, as " + "they navigate power, loyalty, and betrayal.", + "Schindler's List (1993): A harrowing and deeply moving Holocaust drama " + "about Oskar Schindler, a businessman who saves over a thousand Jewish " + "lives by employing them in his factory.", + "The Lord of the Rings: The Return of the King (2003): The epic conclusion " + "to the fantasy trilogy, featuring the final battle for Middle-earth and " + "Frodo's journey to destroy the One Ring.", + "Ben-Hur (1959): A monumental historical drama following Judah Ben-Hur, " + "a nobleman betrayed into slavery who seeks revenge, culminating in a " + "legendary chariot race.", + "Forrest Gump (1994): A heartwarming tale of a simple man whose accidental " + "presence in key historical events shows how love and kindness shape his " + "extraordinary life.", + "Casablanca (1942): A timeless romance set in WWII, following Rick Blaine, " + "a cynical American expatriate, as he must choose between love and aiding " + "the resistance.", + "One Flew Over the Cuckoo's Nest (1975): A rebellious patient shakes up a " + "rigid mental institution, challenging authority and inspiring fellow " + "inmates to reclaim their dignity.", + "Gladiator (2000): A revenge-driven historical epic where a betrayed Roman " + "general fights as a gladiator to avenge his family and bring justice to a " + "corrupt emperor.", + "Gone with the Wind (1939): A sweeping Civil War-era romance that follows " + "the tumultuous life of the headstrong Scarlett O'Hara and her relationship " + "with Rhett Butler.", + "The Silence of the Lambs (1991): A chilling psychological thriller in which " + "young FBI agent Clarice Starling seeks the help of imprisoned cannibal " + "Hannibal Lecter to catch a serial killer.", + "No Country for Old Men (2007): A tense neo-Western thriller about a hunter " + "who stumbles upon a drug 
deal gone wrong, pursued by a relentless assassin " + "and a weary sheriff.", + "Parasite (2019): A sharp social satire about a poor family infiltrating a " + "wealthy household, exposing deep class divides through dark humor and " + "shocking twists.", + "Amadeus (1984): A dramatic retelling of the rivalry between composers " + "Mozart and Salieri, exploring genius, jealousy, and the price of artistic " + "brilliance.", + "Braveheart (1995): A brutal and inspiring historical epic about William " + "Wallace leading the Scottish rebellion against English tyranny in the " + "13th century.", + "12 Years a Slave (2013): A gut-wrenching true story of Solomon Northup, " + "a free Black man abducted and sold into slavery, depicting the horrors " + "and resilience of the human spirit.", + "The Shape of Water (2017): A unique fantasy romance between a mute woman " + "and a mysterious amphibious creature, set in Cold War-era America.", + "The Departed (2006): A gritty crime thriller about undercover agents on " + "opposite sides of the law, entangled in a dangerous game of deception and " + "survival.", + "Slumdog Millionaire (2008): A rags-to-riches tale of a Mumbai slum boy whose " + "life experiences help him succeed on a game show while searching for his " + "lost love.", + "The Artist (2011): A silent film homage about a 1920s movie star " + "struggling to adapt to the rise of talkies, capturing the magic and " + "pain of Hollywood's transition.", +] + +PINECONE_INDEX_NAME = "think-test" + + +@pytest.fixture +def pinecone_index(): + api_key = getenv("PINECONE_API_KEY") + if api_key is None: + yield None + return + + try: + from pinecone.grpc import PineconeGRPC as Pinecone + except ImportError: + yield None + return + + pc = Pinecone(api_key=api_key) + indices = [idx.name for idx in pc.list_indexes()] + + if PINECONE_INDEX_NAME not in indices: + yield None + return + + index = pc.Index(PINECONE_INDEX_NAME) + try: + index.delete(delete_all=True) + except: # noqa + pass + + yield PINECONE_INDEX_NAME + + try: + index.delete(delete_all=True) + except: # noqa + pass + + +@pytest.mark.asyncio +async def test_txtai_integration(): + llm = LLM.from_url(LLM_URL) + rag_class = RAG.for_provider("txtai") + + rag: RAG = rag_class(llm) + + data = [RagDocument(id=str(i), text=text) for i, text in enumerate(MOVIES)] + await rag.add_documents(data) + + n_docs = await rag.count() + assert n_docs == len(MOVIES) + + query = "A movie about a ship that sinks" + result = await rag(query) + assert "titanic" in result.lower() + + await rag.remove_documents([doc["id"] for doc in data]) + + no_docs = await rag.count() + assert no_docs == 0 + + +@pytest.mark.asyncio +async def test_chroma_integration(tmpdir): + llm = LLM.from_url(LLM_URL) + rag_class = RAG.for_provider("chroma") + + rag: RAG = rag_class(llm, collection="test", path=tmpdir) + + data = [RagDocument(id=str(i), text=text) for i, text in enumerate(MOVIES)] + await rag.add_documents(data) + + n_docs = await rag.count() + assert n_docs == len(MOVIES) + + query = "A movie about a ship that sinks" + result = await rag(query) + assert "titanic" in result.lower() + + await rag.remove_documents([doc["id"] for doc in data]) + + no_docs = await rag.count() + assert no_docs == 0 + + +@pytest.mark.asyncio +@pytest.mark.skip("Pinecone is flaky/slow, can't reliably count docs after insert") +async def test_pinecone_integration(tmpdir, pinecone_index): + if pinecone_index is None: + pytest.skip("Pinecone index not available") + + llm = LLM.from_url(LLM_URL) + rag_class = 
RAG.for_provider("pinecone") + + rag: RAG = rag_class(llm, index_name=pinecone_index) + + data = [RagDocument(id=str(i), text=text) for i, text in enumerate(MOVIES)] + await rag.add_documents(data) + + for i in range(60): + time.sleep(1) + n_docs = await rag.count() + if n_docs == len(MOVIES): + break + else: + assert False, "Failed to add documents to Pinecone index after 30s" + + query = "A movie about a ship that sinks" + result = await rag(query) + assert "titanic" in result.lower() + + await rag.remove_documents([doc["id"] for doc in data]) + + no_docs = await rag.count() + assert no_docs == 0 From a3b2d396a89ec0837d880a363a612e8dfb18583b Mon Sep 17 00:00:00 2001 From: Senko Rasic Date: Wed, 12 Feb 2025 20:25:11 +0100 Subject: [PATCH 8/9] document rag in README, update optional/dev deps for it --- README.md | 38 +++++++++++++++++++++++++++++++++++++- pyproject.toml | 50 +++++++++++++++----------------------------------- 2 files changed, 52 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index d1dfe59..7464290 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,42 @@ print(run(generate_python_script("sort a list of numbers"))) For detailed documentation on usage and all available features, please refer to the code docstrings and the integration tests. +## Retrieval-Augmented Generation (RAG) support + +Think supports RAG using TxtAI, ChromaDB and Pinecone vector databases, and +provides scaffolding to integrate other RAG providers. + +Example usage: + +```python + from asyncio import run + + from think import LLM + from think.rag.base import RAG, RagDocument + + llm = LLM.from_url("openai:///gpt-4o-mini") + rag = RAG.for_provider("txtai")(llm) + + async def index_documents(): + data = [ + RagDocument(id="a", text="Titanic: A sweeping romantic epic"), + RagDocument(id="b", text="The Godfather: A gripping mafia saga"), + RagDocument(id="c", text="Forrest Gump: A heartwarming tale of a simple man"), + ] + await rag.add_documents(data) + + run(index_documents()) + query = "A movie about a ship that sinks" + result = run(rag(query)) + print(result) +``` + +You can extend the specific RAG provider classes to add custom functionality, +change LLM prompts, add reranking, etc. + +RAG evaluation is supported via the `think.rag.eval.RagEval` class, supporting Context Precision, +Context Recall, Faithfulness and Answer Relevance metrics. + ## Quickstart Install via `pip`: @@ -214,5 +250,5 @@ To ensure that your contribution is accepted, please follow these guidelines: ## Copyright -Copyright (C) 2023-2024. Senko Rasic and Think contributors. You may use and/or distribute +Copyright (C) 2023-2025. Senko Rasic and Think contributors. You may use and/or distribute this project under the terms of MIT license. See the LICENSE file for more details. diff --git a/pyproject.toml b/pyproject.toml index 4dfa9b3..529b3cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,47 +4,23 @@ version = "0.0.7" description = "Create programs that think, using LLMs." 
readme = "README.md" requires-python = ">=3.10" -dependencies = [ - "pydantic>=2.9.2", - "jinja2>=3.1.2", - "httpx>=0.27.2", -] -authors = [ - { name = "Senko Rasic", email = "senko@senko.net" }, -] +dependencies = ["pydantic>=2.9.2", "jinja2>=3.1.2", "httpx>=0.27.2"] +authors = [{ name = "Senko Rasic", email = "senko@senko.net" }] license = { text = "MIT" } homepage = "https://github.com/senko/think" repository = "https://github.com/senko/think" keywords = ["ai", "llm", "rag"] [project.optional-dependencies] -openai = [ - "openai>=1.53.0", -] -anthropic = [ - "anthropic>=0.37.1", -] -gemini = [ - "google-generativeai>=0.8.3", -] -groq = [ - "groq>=0.12.0", -] -ollama = [ - "ollama>=0.3.3", -] -bedrock = [ - "aioboto3>=13.2.0", -] -txtai = [ - "txtai>=8.1.0", -] -chromadb = [ - "chromadb>=0.6.2", -] -pinecone = [ - "pinecone>=5.4.2", -] +openai = ["openai>=1.53.0"] +anthropic = ["anthropic>=0.37.1"] +gemini = ["google-generativeai>=0.8.3"] +groq = ["groq>=0.12.0"] +ollama = ["ollama>=0.3.3"] +bedrock = ["aioboto3>=13.2.0"] +txtai = ["txtai>=8.1.0"] +chromadb = ["chromadb>=0.6.2"] +pinecone = ["pinecone>=5.4.2"] all = [ "openai>=1.53.0", @@ -71,6 +47,10 @@ dev = [ "google-generativeai>=0.8.3", "groq>=0.12.0", "ollama>=0.3.3", + "txtai>=8.1.0", + "chromadb>=0.6.2", + "pinecone>=5.4.2", + "pinecone-client>=4.1.2", ] [build-system] From 7bdc8b179cee226b0f38bb0285c8f26a97fb15fb Mon Sep 17 00:00:00 2001 From: Senko Rasic Date: Wed, 12 Feb 2025 20:28:24 +0100 Subject: [PATCH 9/9] update ruff to latest release --- .pre-commit-config.yaml | 4 ++-- pyproject.toml | 2 +- think/rag/chroma_rag.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cf68075..803deb1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,11 +1,11 @@ fail_fast: true repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.4 + rev: v0.9.6 hooks: # Run the linter. - id: ruff - args: [ --fix ] + args: [--fix] # Run the formatter. - id: ruff-format - repo: local diff --git a/pyproject.toml b/pyproject.toml index 529b3cd..b2e48bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ all = [ [dependency-groups] dev = [ - "ruff>=0.8.1", + "ruff>=0.9.6", "pytest>=8.3.2", "pytest-coverage>=0.0", "pytest-asyncio>=0.23.8", diff --git a/think/rag/chroma_rag.py b/think/rag/chroma_rag.py index 1142886..2aa4075 100644 --- a/think/rag/chroma_rag.py +++ b/think/rag/chroma_rag.py @@ -82,9 +82,9 @@ async def count(self) -> int: async def calculate_similarity(self, query: str, docs: list[str]) -> list[float]: inputs = [query] + docs - assert ( - self.collection._embedding_function - ), "Cannot calculate similarity without an embedding function" + assert self.collection._embedding_function, ( + "Cannot calculate similarity without an embedding function" + ) vectors = self.collection._embedding_function(inputs) query_vector, *doc_vectors = vectors similarities = []