4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -1,11 +1,11 @@
fail_fast: true
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.6.4
    rev: v0.9.6
    hooks:
      # Run the linter.
      - id: ruff
        args: [ --fix ]
        args: [--fix]
      # Run the formatter.
      - id: ruff-format
  - repo: local
38 changes: 37 additions & 1 deletion README.md
@@ -130,6 +130,42 @@ print(run(generate_python_script("sort a list of numbers")))
For detailed documentation on usage and all available features, please refer to the
code docstrings and the integration tests.

## Retrieval-Augmented Generation (RAG) support

Think supports RAG using the TxtAI, ChromaDB, and Pinecone vector databases, and
provides scaffolding for integrating other RAG providers.

Example usage:

```python
from asyncio import run

from think import LLM
from think.rag.base import RAG, RagDocument

llm = LLM.from_url("openai:///gpt-4o-mini")
rag = RAG.for_provider("txtai")(llm)

async def index_documents():
    data = [
        RagDocument(id="a", text="Titanic: A sweeping romantic epic"),
        RagDocument(id="b", text="The Godfather: A gripping mafia saga"),
        RagDocument(id="c", text="Forrest Gump: A heartwarming tale of a simple man"),
    ]
    await rag.add_documents(data)

run(index_documents())
query = "A movie about a ship that sinks"
result = run(rag(query))
print(result)
```

You can extend the specific RAG provider classes to add custom functionality,
change LLM prompts, add reranking, etc.
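
For example, here is a minimal sketch of a provider subclass that reranks retrieved
passages before the answer is generated. The `fetch_results(query, limit)` hook and
its exact signature are assumptions inferred from the test suite, so check
`think.rag.base` for the actual method to override:

```python
from think import LLM
from think.rag.base import RAG, RagResult

# RAG.for_provider() returns the provider class, which can be subclassed.
TxtAiRag = RAG.for_provider("txtai")


class RerankedRag(TxtAiRag):
    async def fetch_results(self, query: str, limit: int) -> list[RagResult]:
        # Retrieve with the stock provider logic, then favor passages that
        # share literal terms with the query (a deliberately naive reranker).
        # NOTE: the method name and signature here are assumed, not confirmed.
        results = await super().fetch_results(query, limit)
        terms = set(query.lower().split())

        def overlap(result: RagResult) -> int:
            return len(terms & set(result.doc.text.lower().split()))

        return sorted(results, key=lambda r: (overlap(r), r.score), reverse=True)


llm = LLM.from_url("openai:///gpt-4o-mini")
rag = RerankedRag(llm)
```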

RAG evaluation is available via the `think.rag.eval.RagEval` class, which implements the
Context Precision, Context Recall, Faithfulness, and Answer Relevance metrics.
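
A minimal sketch of running the evaluator, following the call patterns used in the
evaluation tests; the final numeric argument (how many retrieved results, statements,
or generated questions each metric considers) is an assumption, so see
`think.rag.eval` for details:

```python
from asyncio import run

from think import LLM
from think.rag.base import RAG
from think.rag.eval import RagEval

llm = LLM.from_url("openai:///gpt-4o-mini")
rag = RAG.for_provider("txtai")(llm)
evaluator = RagEval(rag, llm)


async def evaluate():
    # Assumes documents have already been indexed, as shown above.
    query = "A movie about a ship that sinks"
    answer = await rag(query)
    reference = ["Titanic: A sweeping romantic epic", "A ship sinks"]

    # Each metric returns a score derived from the retrieved context
    # and/or the generated answer.
    print(await evaluator.context_precision(query, 2))
    print(await evaluator.context_recall(query, reference, 4))
    print(await evaluator.faithfulness(query, answer, 4))
    print(await evaluator.answer_relevance(query, answer, 3))


run(evaluate())
```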

## Quickstart

Install via `pip`:
@@ -214,5 +250,5 @@ To ensure that your contribution is accepted, please follow these guidelines:

## Copyright

Copyright (C) 2023-2024. Senko Rasic and Think contributors. You may use and/or distribute
Copyright (C) 2023-2025. Senko Rasic and Think contributors. You may use and/or distribute
this project under the terms of MIT license. See the LICENSE file for more details.
49 changes: 21 additions & 28 deletions pyproject.toml
@@ -4,50 +4,39 @@ version = "0.0.7"
description = "Create programs that think, using LLMs."
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "pydantic>=2.9.2",
    "jinja2>=3.1.2",
    "httpx>=0.27.2",
]
authors = [
    { name = "Senko Rasic", email = "[email protected]" },
]
dependencies = ["pydantic>=2.9.2", "jinja2>=3.1.2", "httpx>=0.27.2"]
authors = [{ name = "Senko Rasic", email = "[email protected]" }]
license = { text = "MIT" }
homepage = "https://github.com/senko/think"
repository = "https://github.com/senko/think"
keywords = ["ai", "llm"]
keywords = ["ai", "llm", "rag"]

[project.optional-dependencies]
openai = [
    "openai>=1.53.0",
]
anthropic = [
    "anthropic>=0.37.1",
]
gemini = [
    "google-generativeai>=0.8.3",
]
groq = [
    "groq>=0.12.0",
]
ollama = [
    "ollama>=0.3.3",
]
bedrock = [
    "aioboto3>=13.2.0",
]
openai = ["openai>=1.53.0"]
anthropic = ["anthropic>=0.37.1"]
gemini = ["google-generativeai>=0.8.3"]
groq = ["groq>=0.12.0"]
ollama = ["ollama>=0.3.3"]
bedrock = ["aioboto3>=13.2.0"]
txtai = ["txtai>=8.1.0"]
chromadb = ["chromadb>=0.6.2"]
pinecone = ["pinecone>=5.4.2"]

all = [
"openai>=1.53.0",
"anthropic>=0.37.1",
"google-generativeai>=0.8.3",
"groq>=0.12.0",
"ollama>=0.3.3",
"txtai>=8.1.0",
"chromadb>=0.6.2",
"pinecone>=5.4.2",
"pinecone-client>=4.1.2",
]

[dependency-groups]
dev = [
"ruff>=0.8.1",
"ruff>=0.9.6",
"pytest>=8.3.2",
"pytest-coverage>=0.0",
"pytest-asyncio>=0.23.8",
@@ -58,6 +47,10 @@ dev = [
"google-generativeai>=0.8.3",
"groq>=0.12.0",
"ollama>=0.3.3",
"txtai>=8.1.0",
"chromadb>=0.6.2",
"pinecone>=5.4.2",
"pinecone-client>=4.1.2",
]

[build-system]
39 changes: 39 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,39 @@
from os import getenv


def model_urls(vision: bool = False) -> list[str]:
"""
Returns a list of models to test with, based on available API keys.

:return: A list of model URLs based on the available API keys.
"""
retval = []
if getenv("OPENAI_API_KEY"):
retval.append("openai:///gpt-4o-mini")
if getenv("ANTHROPIC_API_KEY"):
retval.append("anthropic:///claude-3-haiku-20240307")
if getenv("GEMINI_API_KEY"):
retval.append("google:///gemini-2.0-flash-lite-preview-02-05")
if getenv("GROQ_API_KEY"):
retval.append("groq:///llama-3.2-90b-vision-preview")
if getenv("OLLAMA_MODEL"):
if vision:
retval.append(f"ollama:///{getenv('OLLAMA_VISION_MODEL')}")
else:
retval.append(f"ollama:///{getenv('OLLAMA_MODEL')}")
if getenv("AWS_SECRET_ACCESS_KEY"):
retval.append("bedrock:///amazon.nova-lite-v1:0?region=us-east-1")
if retval == []:
raise RuntimeError("No LLM API keys found in environment")
return retval


def api_model_urls() -> list[str]:
return [url for url in model_urls() if not url.startswith("ollama:")]


def first_model_url() -> str:
urls = model_urls()
if urls:
return urls[0]
raise RuntimeError("No LLM API keys found in environment")
42 changes: 10 additions & 32 deletions tests/integration/test_llm.py
@@ -1,5 +1,5 @@
import logging
from os import getenv
from os import environ, getenv

import pytest
from dotenv import load_dotenv
@@ -9,6 +9,8 @@
from think.llm.base import BadRequestError, ConfigError
from think.llm.chat import Chat

from conftest import api_model_urls, model_urls

load_dotenv()
logging.basicConfig(level=logging.DEBUG)

@@ -27,34 +29,6 @@
    pytest.skip("Skipping integration tests", allow_module_level=True)


def model_urls() -> list[str]:
    """
    Returns a list of models to test with, based on available API keys.

    :return: A list of model URLs based on the available API keys.
    """
    retval = []
    if getenv("OPENAI_API_KEY"):
        retval.append("openai:///gpt-4o-mini")
    if getenv("ANTHROPIC_API_KEY"):
        retval.append("anthropic:///claude-3-haiku-20240307")
    if getenv("GEMINI_API_KEY"):
        retval.append("google:///gemini-1.5-pro-latest")
    if getenv("GROQ_API_KEY"):
        retval.append("groq:///llama-3.2-90b-vision-preview")
    if getenv("OLLAMA_MODEL"):
        retval.append(f"ollama:///{getenv('OLLAMA_MODEL')}")
    if getenv("AWS_SECRET_ACCESS_KEY"):
        retval.append("bedrock:///amazon.nova-lite-v1:0?region=us-east-1")
    if retval == []:
        raise RuntimeError("No LLM API keys found in environment")
    return retval


def api_model_urls() -> list[str]:
    return [url for url in model_urls() if not url.startswith("ollama:")]


@pytest.mark.parametrize("url", model_urls())
@pytest.mark.asyncio
async def test_basic_request(url):
@@ -104,7 +78,7 @@ def get_temperature(city: str) -> str:
    assert tool_called, "Expected 'get_temperature' tool to be called"


@pytest.mark.parametrize("url", model_urls())
@pytest.mark.parametrize("url", model_urls(vision=True))
@pytest.mark.asyncio
async def test_vision(url):
    c = Chat("You're a friendly assistant").user(
@@ -168,9 +142,13 @@ async def test_custom_parser(url):

@pytest.mark.parametrize("url", api_model_urls())
@pytest.mark.asyncio
async def test_auth_error(url):
async def test_auth_error(url, monkeypatch):
    for key in environ.keys():
        if key.endswith("_API_KEY") or key.startswith("AWS_"):
            monkeypatch.setenv(key, "")

    c = Chat("You're a friendly assistant").user("Tell me a joke")
    invalid_key_url = url.replace("///", "//testing-incorrect-key@/")
    invalid_key_url = url.replace("///", "//testing-incorrect-key:abc@/")
    llm = LLM.from_url(invalid_key_url)

    with pytest.raises(ConfigError):
103 changes: 103 additions & 0 deletions tests/rag/test_eval.py
@@ -0,0 +1,103 @@
import pytest
from unittest.mock import patch, MagicMock, AsyncMock

from think.rag.base import RagResult, RagDocument
from think.rag.eval import RagEval


@pytest.mark.asyncio
@patch("think.rag.eval.ask", new_callable=AsyncMock)
async def test_context_precision(ask):
    llm = MagicMock()
    rag = AsyncMock()
    rag.fetch_results.return_value = [
        RagResult(doc=RagDocument(id="A", text="A"), score=0.5),
        RagResult(doc=RagDocument(id="B", text="B"), score=0.5),
        RagResult(doc=RagDocument(id="C", text="C"), score=0.5),
        RagResult(doc=RagDocument(id="D", text="D"), score=0.5),
    ]
    ask.side_effect = ["yes", "no", "yes", "no"]

    eval = RagEval(rag, llm)

    query = "A movie about a ship that sinks"
    ctx_precision = await eval.context_precision(query, 2)

    assert ctx_precision == pytest.approx(1.33, rel=1e-2)


@pytest.mark.asyncio
@patch("think.rag.eval.ask", new_callable=AsyncMock)
async def test_context_recall(ask):
    llm = MagicMock()
    rag = AsyncMock()
    rag.fetch_results.return_value = [
        RagResult(doc=RagDocument(id="A", text="A"), score=0.5),
        RagResult(doc=RagDocument(id="B", text="B"), score=0.5),
        RagResult(doc=RagDocument(id="C", text="C"), score=0.5),
        RagResult(doc=RagDocument(id="D", text="D"), score=0.5),
    ]
    ask.side_effect = ["yes", "no", "yes", "no"]

    eval = RagEval(rag, llm)

    query = "A movie about a ship that sinks"
    reference = [
        "A ship sinks",
        "A love story",
        "A historical event",
        "A tragedy",
    ]
    ctx_recall = await eval.context_recall(query, reference, 4)

    assert ctx_recall == pytest.approx(0.5, rel=1e-2)


@pytest.mark.asyncio
@patch("think.rag.eval.ask", new_callable=AsyncMock)
async def test_faithfulness(ask):
    llm = MagicMock()
    rag = AsyncMock()
    rag.fetch_results.return_value = [
        RagResult(doc=RagDocument(id="A", text="A"), score=0.5),
        RagResult(doc=RagDocument(id="B", text="B"), score=0.5),
        RagResult(doc=RagDocument(id="C", text="C"), score=0.5),
        RagResult(doc=RagDocument(id="D", text="D"), score=0.5),
    ]
    ask.side_effect = [
        "The Titanic sank after hitting an iceberg.\n"
        + "Many people died.\n"
        + "It was a tragic event\n",
        "yes",
        "no",
        "yes",
    ]

    eval = RagEval(rag, llm)

    query = "What happened to the Titanic?"
    answer = "The Titanic sank after hitting an iceberg. Many people died. It was a tragic event."
    faithfulness_score = await eval.faithfulness(query, answer, 4)

    assert faithfulness_score == pytest.approx(0.67, rel=1e-2)


@pytest.mark.asyncio
@patch("think.rag.eval.ask", new_callable=AsyncMock)
async def test_answer_relevance(ask):
    llm = MagicMock()
    rag = AsyncMock()
    rag.calculate_similarity.return_value = [0.8, 0.75, 0.85]
    ask.side_effect = [
        "What is the fate of the Titanic?\n"
        + "How did the Titanic sink?\n"
        + "What happened to the Titanic?"
    ]

    eval = RagEval(rag, llm)

    query = "What happened to the Titanic?"
    answer = "The Titanic sank after hitting an iceberg. Many people died. It was a tragic event."
    relevance_score = await eval.answer_relevance(query, answer, 3)

    assert relevance_score == pytest.approx(0.8, rel=1e-2)