499 changes: 499 additions & 0 deletions INTERNALS.md

(Large diff not rendered by default.)

12 changes: 7 additions & 5 deletions README.md

```diff
@@ -319,14 +319,16 @@ Contributions are welcome!
 
 To ensure that your contribution is accepted, please follow these guidelines:
 
+- read [INTERNALS.md](INTERNALS.md) document to get familiar with the codebase
 - open an issue to discuss your idea before you start working on it, or if there's
   already an issue for your idea, join the conversation there and explain how you
   plan to implement it
-- make sure that your code is well documented (docstrings, type annotations, comments,
-  etc.) and tested (test coverage should only go up)
-- make sure that your code is formatted and type-checked with `ruff` (default settings)
+- make sure that your code is well documented (docstrings, type annotations, comments, etc.) and tested (test coverage should only go up)
+- install and use `pre-commit` hooks (`uv run pre-commit install`) to ensure formatting, linting, type-checking and tests are run before committing
 
 ## Copyright
 
-Copyright (C) 2023-2025. Senko Rasic and Think contributors. You may use and/or distribute
-this project under the terms of MIT license. See the LICENSE file for more details.
+Copyright (C) 2023-2025. Senko Rasic and Think contributors.
+
+You may use and/or distribute this project under the terms of MIT license.
+See the [LICENSE](LICENSE) file for more details.
```
42 changes: 21 additions & 21 deletions pyproject.toml

```diff
@@ -34,27 +34,6 @@ all = [
     "pinecone-client>=4.1.2",
 ]
 
-[dependency-groups]
-dev = [
-    "ruff>=0.9.6",
-    "pytest>=8.3.2",
-    "pytest-coverage>=0.0",
-    "pytest-asyncio>=0.23.8",
-    "pre-commit>=3.8.0",
-    "python-dotenv>=1.0.1",
-    "openai>=1.53.0",
-    "anthropic>=0.37.1",
-    "google-generativeai>=0.8.3",
-    "groq>=0.12.0",
-    "ollama>=0.3.3",
-    "txtai>=8.1.0",
-    "chromadb>=0.6.2",
-    "pinecone>=5.4.2",
-    "pinecone-client>=4.1.2",
-    "aioboto3>=13.2.0",
-    "ty>=0.0.1a1",
-]
-
 [build-system]
 requires = ["hatchling"]
 build-backend = "hatchling.build"
@@ -72,3 +51,24 @@ exclude_lines = ["if TYPE_CHECKING:"]
 
 [tool.pyright]
 typeCheckingMode = "off"
+
+[tool.uv]
+dev-dependencies = [
+    "pytest-asyncio>=1.1.0",
+    "pytest-coverage>=0.0",
+    "pytest>=8.4.1",
+    "ty>=0.0.1a16",
+    "ruff>=0.9.6",
+    "pre-commit>=3.8.0",
+    "python-dotenv>=1.0.1",
+    "openai>=1.53.0",
+    "anthropic>=0.37.1",
+    "google-generativeai>=0.8.3",
+    "groq>=0.12.0",
+    "ollama>=0.3.3",
+    "txtai>=8.1.0",
+    "chromadb>=0.6.2",
+    "pinecone>=5.4.2",
+    "pinecone-client>=4.1.2",
+    "aioboto3>=13.2.0",
+]
```
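
Net effect of the pyproject.toml change: the development tool set moves from the standalone `[dependency-groups]` table into uv's `[tool.uv] dev-dependencies` setting, with `pytest`, `pytest-asyncio`, and `ty` bumped to newer versions along the way. As a quick sanity check that the list parses from its new location, a minimal sketch (assuming Python 3.11+ for the stdlib `tomllib`, run from the repository root):

```python
# Sketch: confirm the dev dependencies are readable from [tool.uv].
import tomllib

with open("pyproject.toml", "rb") as f:
    config = tomllib.load(f)

dev_deps = config["tool"]["uv"]["dev-dependencies"]
assert "ruff>=0.9.6" in dev_deps
print(f"{len(dev_deps)} dev dependencies configured")
```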
2 changes: 1 addition & 1 deletion tests/conftest.py

```diff
@@ -15,7 +15,7 @@ def model_urls(vision: bool = False) -> list[str]:
     if getenv("GEMINI_API_KEY"):
         retval.append("google:///gemini-2.0-flash-lite-preview-02-05")
     if getenv("GROQ_API_KEY"):
-        retval.append("groq:///llama-3.2-90b-vision-preview")
+        retval.append("groq:///?model=meta-llama/llama-4-scout-17b-16e-instruct")
     if getenv("OLLAMA_MODEL"):
         if vision:
             retval.append(f"ollama:///{getenv('OLLAMA_VISION_MODEL')}")
```
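
The Groq test model moves from a path-style URL to a query-parameter form, plausibly because the new model identifier contains a slash and would be ambiguous as a path segment. A minimal sketch of how the two shapes are consumed via `LLM.from_url`, the entry point the integration tests use (the exact parsing rules are internal to think, so treat this as illustrative only):

```python
from think import LLM

# (assumes provider API keys are available in the environment)

# Path-style URLs work when the model name itself contains no slashes:
llm_google = LLM.from_url("google:///gemini-2.0-flash-lite-preview-02-05")

# The query-parameter form keeps slash-containing model names unambiguous:
llm_groq = LLM.from_url("groq:///?model=meta-llama/llama-4-scout-17b-16e-instruct")
```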
9 changes: 6 additions & 3 deletions tests/integration/test_llm.py

```diff
@@ -6,7 +6,7 @@
 from pydantic import BaseModel
 
 from think import LLM
-from think.llm.base import BadRequestError, ConfigError
+from think.llm.base import BadRequestError, ConfigError, BaseAdapter
 from think.llm.chat import Chat
 
 from conftest import api_model_urls, model_urls
@@ -182,16 +182,19 @@ async def test_chat_error(url):
     c = Chat("You're a friendly assistant").user("Tell me a joke")
     llm = LLM.from_url(url)
 
-    class FakeAdapter:
+    class FakeAdapter(BaseAdapter):
         spec = None
 
         def __init__(self, *args, **kwargs):
             pass
 
+        def get_tool_spec(self, tool):
+            return {"name": tool.name}
+
         def dump_chat(self, chat: Chat):
             return "", {"messages": "invalid"}
 
-    llm.adapter_class = FakeAdapter
+    llm.adapter_class = FakeAdapter  # type: ignore
 
     with pytest.raises(BadRequestError):
         await llm(c)
```
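
The test double now subclasses `BaseAdapter` and gains a `get_tool_spec` stub, which suggests the base class declares abstract members that the old duck-typed fake no longer satisfied once the tests are type-checked. A hypothetical reconstruction of the contract implied by the stubs (inferred from this test alone, not the actual `think.llm.base` definitions):

```python
# Hypothetical shape of the adapter contract, inferred from FakeAdapter.
from abc import ABC, abstractmethod


class BaseAdapter(ABC):
    spec = None  # provider-specific spec container; real type unknown

    @abstractmethod
    def get_tool_spec(self, tool) -> dict:
        """Translate a tool definition into the provider's wire format."""

    @abstractmethod
    def dump_chat(self, chat) -> tuple[str, dict | list]:
        """Serialize a Chat into (system_prompt, provider_payload)."""
```

Under that reading, the surviving `# type: ignore` on the `llm.adapter_class = FakeAdapter` assignment covers a declared attribute type narrower than the test double, not missing methods.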
2 changes: 1 addition & 1 deletion tests/llm/test_anthropic_adapter.py

```diff
@@ -113,5 +113,5 @@ def test_adapter(chat, ex_system, expected):
     assert system == ex_system
     assert messages == expected
 
-    chat2 = adapter.load_chat(messages, system=system)
+    chat2 = adapter.load_chat(messages, system=None if system is NOT_GIVEN else system)  # type: ignore
     assert chat.messages == chat2.messages
```
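
`NOT_GIVEN` is the Anthropic SDK's sentinel for omitted arguments, so it has to be collapsed to a plain `None` before being passed where `load_chat` expects `str | None`. The pattern in isolation (a sketch; `normalize_system` is an illustrative helper, not part of the codebase):

```python
# NOT_GIVEN is exported by the anthropic SDK; the helper is illustrative.
from anthropic import NOT_GIVEN


def normalize_system(system) -> str | None:
    # Collapse the SDK's "argument omitted" sentinel to a real None.
    return None if system is NOT_GIVEN else system


assert normalize_system(NOT_GIVEN) is None
assert normalize_system("You are helpful") == "You are helpful"
```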
45 changes: 24 additions & 21 deletions tests/llm/test_base.py

```diff
@@ -27,15 +27,18 @@ async def _internal_call(
         max_tokens: int | None,
         adapter: BaseAdapter,
         response_format: PydanticResultT | None = None,
-    ) -> Message: ...
+    ) -> Message:
+        raise NotImplementedError()
 
     async def _internal_stream(
         self,
         chat: Chat,
         adapter: BaseAdapter,
         temperature: float | None,
         max_tokens: int | None,
-    ) -> AsyncGenerator[str, None]: ...
+    ) -> AsyncGenerator[str, None]:
+        raise NotImplementedError()
+        yield  # Make it a generator
 
 
 def text_message(text: str) -> Message:
```
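
The unreachable `yield` in the `_internal_stream` stub is load-bearing: Python decides whether an `async def` is an async generator function by the presence of `yield` in the body, not by the return annotation. Without it, calling the stub returns a coroutine instead of an async generator, and iterating it with `async for` fails before `NotImplementedError` is ever raised. A standalone demonstration:

```python
# Why the stub needs an unreachable `yield` to be an async generator.
import inspect


async def without_yield():
    raise NotImplementedError()


async def with_yield():
    raise NotImplementedError()
    yield  # unreachable, but changes how Python classifies the function


print(inspect.iscoroutinefunction(without_yield))  # True
print(inspect.isasyncgenfunction(without_yield))   # False
print(inspect.isasyncgenfunction(with_yield))      # True
```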
```diff
@@ -78,13 +81,13 @@ async def test_call_minimal():
     assert client.api_key == "fake-key"
     assert client.model == "fake-model"
 
-    client._internal_call = AsyncMock(return_value=response_msg)
+    client._internal_call = AsyncMock(return_value=response_msg)  # type: ignore
     response = await client(chat, temperature=0.5, max_tokens=10)
 
     assert response == "Hi!"
 
-    client._internal_call.assert_called_once()
-    args = client._internal_call.call_args
+    client._internal_call.assert_called_once()  # type: ignore
+    args = client._internal_call.call_args  # type: ignore
 
     assert args.args[0] == chat
     assert args.args[1] == 0.5  # temperature
@@ -100,7 +103,7 @@ async def test_call_with_tools():
     chat = Chat("system message").user("user message")
     client = MyClient(api_key="fake-key", model="fake-model")
 
-    client._internal_call = AsyncMock(
+    client._internal_call = AsyncMock(  # type: ignore
         side_effect=[
             tool_call_message(),
             text_message("Hi!"),
@@ -115,9 +118,9 @@ def fake_tool(a: int, b: str) -> str:
 
     response = await client(chat, tools=[fake_tool], max_steps=1)
 
-    client._internal_call.assert_called()
-    assert client._internal_call.call_count == 2
-    args = client._internal_call.call_args_list[1]
+    client._internal_call.assert_called()  # type: ignore
+    assert client._internal_call.call_count == 2  # type: ignore
+    args = client._internal_call.call_args_list[1]  # type: ignore
     assert args.args[0] == chat
     assert response == "Hi!"
 
@@ -134,7 +137,7 @@ async def test_call_with_tool_error():
     chat = Chat("system message").user("user message")
     client = MyClient(api_key="fake-key", model="fake-model")
 
-    client._internal_call = AsyncMock(
+    client._internal_call = AsyncMock(  # type: ignore
         side_effect=[
             tool_call_message(),
             text_message("Hi!"),
@@ -147,22 +150,22 @@ def fake_tool(a: int, b: str) -> str:
 
     response = await client(chat, tools=[fake_tool], max_steps=1)
 
-    client._internal_call.assert_called()
-    assert client._internal_call.call_count == 2
-    args = client._internal_call.call_args_list[1]
+    client._internal_call.assert_called()  # type: ignore
+    assert client._internal_call.call_count == 2  # type: ignore
+    args = client._internal_call.call_args_list[1]  # type: ignore
     assert args.args[0] == chat
     assert response == "Hi!"
 
     tc = chat.messages[-2].content[0].tool_response
     assert tc is not None
-    assert "some error" in tc.error
+    assert tc.error and "some error" in tc.error
 
 
 @pytest.mark.asyncio
 async def test_call_with_pydantic():
     chat = Chat("system message").user("user message")
     client = MyClient(api_key="fake-key", model="fake-model")
-    client._internal_call = AsyncMock(
+    client._internal_call = AsyncMock(  # type: ignore
         return_value=text_message(
             json.dumps(
                 {
@@ -181,8 +184,8 @@ class TestModel(BaseModel):
     assert response.text == "Hi!"
     assert chat.messages[-1].parsed == response
 
-    client._internal_call.assert_called_once()
-    args = client._internal_call.call_args
+    client._internal_call.assert_called_once()  # type: ignore
+    args = client._internal_call.call_args  # type: ignore
 
     assert args.args[0] == chat
     assert args.kwargs["response_format"] is TestModel
@@ -192,7 +195,7 @@
 async def test_call_with_custom_parser():
     chat = Chat("system message").user("user message")
     client = MyClient(api_key="fake-key", model="fake-model")
-    client._internal_call = AsyncMock(return_value=text_message("Hi!"))
+    client._internal_call = AsyncMock(return_value=text_message("Hi!"))  # type: ignore
 
     def custom_parser(val: str) -> float:
         assert val == "Hi!"
@@ -219,15 +222,15 @@ async def do_stream():
         for c in original_message:
             yield c
 
-    client._internal_stream = MagicMock(return_value=do_stream())
+    client._internal_stream = MagicMock(return_value=do_stream())  # type: ignore
     text = []
     async for word in client.stream(chat, temperature=0.5, max_tokens=10):
         text.append(word)
 
     assert "".join(text) == original_message
 
-    client._internal_stream.assert_called_once()
-    args = client._internal_stream.call_args
+    client._internal_stream.assert_called_once()  # type: ignore
+    args = client._internal_stream.call_args  # type: ignore
 
     assert args.args[0] == chat
     assert isinstance(args.args[1], MyAdapter)
```
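
The rest of this file's churn is one recurring pattern: methods are swapped out for `AsyncMock`/`MagicMock` instances at runtime, and every such assignment or mock-only attribute access gets a `# type: ignore`, because a static checker (presumably the newly pinned `ty`) sees the declared method type rather than the mock. Reduced to a self-contained example:

```python
# Minimal reproduction of the mock-plus-ignore pattern used above.
import asyncio
from unittest.mock import AsyncMock


class Client:
    async def _internal_call(self) -> str:
        raise NotImplementedError()


async def main() -> None:
    client = Client()
    # Fine at runtime, a type error statically: a mock is not the
    # declared method type, hence the ignore.
    client._internal_call = AsyncMock(return_value="Hi!")  # type: ignore
    assert await client._internal_call() == "Hi!"
    client._internal_call.assert_awaited_once()  # type: ignore


asyncio.run(main())
```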