Changes from 4 commits
246 changes: 159 additions & 87 deletions packages/lmi/src/lmi/llms.py
@@ -17,7 +17,7 @@
import logging
from abc import ABC
from collections.abc import (
AsyncIterable,
AsyncGenerator,
Awaitable,
Callable,
Coroutine,
@@ -201,7 +201,7 @@ async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult]:

async def acompletion_iter(
self, messages: list[Message], **kwargs
) -> AsyncIterable[LLMResult]:
) -> AsyncGenerator[LLMResult]:
"""Return an async generator that yields completions.

Only the last tuple will be non-zero.
@@ -224,18 +224,49 @@ def __str__(self) -> str:
# None means we won't provide a tool_choice to the LLM API
UNSPECIFIED_TOOL_CHOICE: ClassVar[None] = None

@overload
async def call(
self,
messages: list[Message] | str,
callbacks: (
Sequence[Callable[..., Any] | Callable[..., Awaitable]] | None
) = ...,
name: str | None = ...,
output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = ...,
tools: list[Tool] | None = ...,
tool_choice: Tool | str | None = ...,
stream: bool = False,
**kwargs,
) -> list[LLMResult]: ...

@overload
async def call( # type: ignore[overload-cannot-match]
self,
messages: list[Message] | str,
callbacks: (
Sequence[Callable[..., Any] | Callable[..., Awaitable]] | None
) = ...,
name: str | None = ...,
output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = ...,
tools: list[Tool] | None = ...,
tool_choice: Tool | str | None = ...,
stream: bool = True,
**kwargs,
) -> AsyncGenerator[LLMResult]: ...

async def call( # noqa: C901, PLR0915
self,
messages: list[Message],
messages: list[Message] | str,
callbacks: (
Sequence[Callable[..., Any] | Callable[..., Awaitable]] | None
) = None,
name: str | None = None,
output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = None,
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
stream: bool = False,
**kwargs,
) -> list[LLMResult]:
) -> list[LLMResult] | AsyncGenerator[LLMResult]:
"""Call the LLM model with the given messages and configuration.

Args:
@@ -245,22 +276,30 @@ async def call( # noqa: C901, PLR0915
output_type: The type of the output.
tools: A list of tools to use.
tool_choice: The tool choice to use.
stream: Whether to stream the response or return all results at once.
kwargs: Additional keyword arguments for the chat completion.

Returns:
A list of LLMResult objects containing the result of the call.
A list of LLMResult objects containing the result of the call when stream=False,
or an AsyncGenerator[LLMResult] when stream=True.
Collaborator comment:
Two comments:

  • Indent the later lines by 4
  • Consider not restating the type hints/variable names, as that can lead to drift over time

For example:

When not streaming, it's a list of result objects for each call, otherwise
    it's an async generator of result objects.


Raises:
ValueError: If the LLM type is unknown.
"""
if isinstance(messages, str):
Collaborator comment:
If we want to do this here, can we remove it from call_single?

# convenience for single message
messages = [Message(content=messages)]
chat_kwargs = copy.deepcopy(kwargs)
# if using the config for an LLMModel,
# there may be a nested 'config' key
# that can't be used by chat
chat_kwargs.pop("config", None)
chat_kwargs.pop("stream", None)
n = chat_kwargs.get("n") or self.config.get("n", 1)
if n < 1:
raise ValueError("Number of completions (n) must be >= 1.")
if stream and n > 1:
raise ValueError("Number of completions (n) must be 1 when streaming.")
if "fallbacks" not in chat_kwargs and "fallbacks" in self.config:
chat_kwargs["fallbacks"] = self.config.get("fallbacks", [])

@@ -328,48 +367,83 @@ async def call( # noqa: C901, PLR0915
)
for m in messages
]
results: list[LLMResult] = []

start_clock = asyncio.get_running_loop().time()
if callbacks is None:

# If not streaming, simply return the results
if not stream:
sync_callbacks = [
f for f in (callbacks or []) if not is_coroutine_callable(f)
]
async_callbacks = [f for f in (callbacks or []) if is_coroutine_callable(f)]
results = await self.acompletion(messages, **chat_kwargs)
else:
if tools:
raise NotImplementedError("Using tools with callbacks is not supported")
n = chat_kwargs.get("n") or self.config.get("n", 1)
if n > 1:
raise NotImplementedError(
"Multiple completions with callbacks is not supported"
for result in results:
text = cast("str", result.text)
await do_callbacks(async_callbacks, sync_callbacks, text, name)
usage = result.prompt_count, result.completion_count
if not sum(usage):
result.completion_count = self.count_tokens(text)
result.seconds_to_last_token = (
asyncio.get_running_loop().time() - start_clock
)
sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)]
async_callbacks = [f for f in callbacks if is_coroutine_callable(f)]
stream_results = await self.acompletion_iter(messages, **chat_kwargs)
text_result = []
async for result in stream_results:
if result.text:
if result.seconds_to_first_token == 0:
result.seconds_to_first_token = (
asyncio.get_running_loop().time() - start_clock
)
text_result.append(result.text)
await do_callbacks(
async_callbacks, sync_callbacks, result.text, name
result.name = name
if self.llm_result_callback:
possibly_awaitable_result = self.llm_result_callback(result)
if isawaitable(possibly_awaitable_result):
await possibly_awaitable_result
return results

# If streaming, return an AsyncGenerator[LLMResult]
if tools:
raise NotImplementedError("Using tools with streaming is not supported")
if callbacks:
raise NotImplementedError("Using callbacks with streaming is not supported")
Comment on lines +396 to +399
Collaborator comment:
Can you reword to say "not yet supported"?

Also, can you add a comment to each of these saying why they're not supported?


async def process_stream() -> AsyncGenerator[LLMResult]:
async_iterable = await self.acompletion_iter(messages, **chat_kwargs)
async for result in async_iterable:
usage = result.prompt_count, result.completion_count
if not sum(usage):
result.completion_count = self.count_tokens(
cast("str", result.text)
)
results.append(result)

for result in results:
usage = result.prompt_count, result.completion_count
if not sum(usage):
result.completion_count = self.count_tokens(cast("str", result.text))
result.seconds_to_last_token = (
asyncio.get_running_loop().time() - start_clock
)
result.name = name
if self.llm_result_callback:
possibly_awaitable_result = self.llm_result_callback(result)
if isawaitable(possibly_awaitable_result):
await possibly_awaitable_result
return results
result.seconds_to_last_token = (
asyncio.get_running_loop().time() - start_clock
)
result.name = name
yield result

return process_stream()
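
For reviewers, a minimal usage sketch of the new stream flag and the widened return type. This is an illustration only: it assumes a concrete model class such as lmi's LiteLLMModel and a hypothetical model name; adjust the import and model name to your environment.

import asyncio

from lmi import LiteLLMModel  # assumed import path; adjust to your install


async def main() -> None:
    model = LiteLLMModel(name="gpt-4o-mini")  # hypothetical model name

    # stream=False (the default): call() returns list[LLMResult]
    results = await model.call("Say hello in one word.")
    print(results[0].text)

    # stream=True: call() returns AsyncGenerator[LLMResult], one result per chunk
    chunks = await model.call("Say hello in one word.", stream=True)
    async for chunk in chunks:
        print(chunk.text or "", end="", flush=True)


asyncio.run(main())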

@overload
async def call_single(
self,
messages: list[Message] | str,
callbacks: (
Sequence[Callable[..., Any] | Callable[..., Awaitable]] | None
) = ...,
name: str | None = ...,
output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = ...,
tools: list[Tool] | None = ...,
tool_choice: Tool | str | None = ...,
stream: bool = False,
**kwargs,
) -> LLMResult: ...

@overload
async def call_single( # type: ignore[overload-cannot-match]
Collaborator comment (@jamesbraza, Jul 18, 2025):
This type ignore -- why do you have it? We should have no type ignores here imo, otherwise it means the typing is wrong

self,
messages: list[Message] | str,
callbacks: (
Sequence[Callable[..., Any] | Callable[..., Awaitable]] | None
) = ...,
name: str | None = ...,
output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = ...,
tools: list[Tool] | None = ...,
tool_choice: Tool | str | None = ...,
stream: bool = True,
**kwargs,
) -> AsyncGenerator[LLMResult]: ...

async def call_single(
self,
@@ -381,21 +455,31 @@ async def call_single(
output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = None,
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
stream: bool = False,
**kwargs,
) -> LLMResult:
) -> LLMResult | AsyncGenerator[LLMResult]:
if isinstance(messages, str):
# convenience for single message
messages = [Message(content=messages)]
kwargs = {**kwargs, "n": 1}
results = await self.call(
messages,
callbacks,
name,
output_type,
tools,
tool_choice,
n=1,
stream,
**kwargs,
)

if stream:
if not isinstance(results, AsyncGenerator):
raise TypeError("Expected AsyncGenerator of results when streaming")
return results

if not isinstance(results, list):
raise TypeError("Expected list of results when not streaming")
if len(results) != 1:
# Can be caused by issues like https://github.com/BerriAI/litellm/issues/12298
raise ValueError(f"Got {len(results)} results when expecting just one.")
@@ -413,8 +497,8 @@ def rate_limited(

@overload
def rate_limited(
func: Callable[P, AsyncIterable[LLMResult]],
) -> Callable[P, Coroutine[Any, Any, AsyncIterable[LLMResult]]]: ...
func: Callable[P, AsyncGenerator[LLMResult]],
) -> Callable[P, Coroutine[Any, Any, AsyncGenerator[LLMResult]]]: ...


def rate_limited(func):
@@ -440,7 +524,7 @@ async def wrapper(self, *args, **kwargs):
# portion before yielding
if isasyncgenfunction(func):

async def rate_limited_generator() -> AsyncIterable[LLMResult]:
async def rate_limited_generator() -> AsyncGenerator[LLMResult]:
async for item in func(self, *args, **kwargs):
token_count = 0
if isinstance(item, LLMResult):
@@ -469,8 +553,8 @@ def request_limited(

@overload
def request_limited(
func: Callable[P, Coroutine[Any, Any, AsyncIterable[LLMResult]]],
) -> Callable[P, Coroutine[Any, Any, AsyncIterable[LLMResult]]]: ...
func: Callable[P, Coroutine[Any, Any, AsyncGenerator[LLMResult]]],
) -> Callable[P, Coroutine[Any, Any, AsyncGenerator[LLMResult]]]: ...


def request_limited(func):
@@ -487,7 +571,7 @@ async def wrapper(self, *args, **kwargs):

if isasyncgenfunction(func):

async def request_limited_generator() -> AsyncIterable[LLMResult]:
async def request_limited_generator() -> AsyncGenerator[LLMResult]:
first_item = True
async for item in func(self, *args, **kwargs):
# Skip rate limit check for first item since we already checked at generator start
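
The two decorators above use inspect.isasyncgenfunction to give async-generator functions a different wrapper than plain coroutine functions, with @overload declarations advertising the matching return type. A generic, standalone sketch of that pattern (not the library's actual decorator; the print calls stand in for rate-limit bookkeeping):

import asyncio
from collections.abc import AsyncGenerator, Callable, Coroutine
from inspect import isasyncgenfunction
from typing import Any, ParamSpec, TypeVar, overload

P = ParamSpec("P")
T = TypeVar("T")


@overload
def traced(func: Callable[P, AsyncGenerator[T, None]]) -> Callable[P, AsyncGenerator[T, None]]: ...
@overload
def traced(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Coroutine[Any, Any, T]]: ...
def traced(func):
    if isasyncgenfunction(func):
        # Async-generator functions get another async generator,
        # so callers can still iterate the yielded items.
        async def gen_wrapper(*args, **kwargs):
            async for item in func(*args, **kwargs):
                print("yielded:", item)  # stand-in for rate-limit bookkeeping
                yield item

        return gen_wrapper

    # Plain coroutine functions get an awaitable wrapper.
    async def wrapper(*args, **kwargs):
        result = await func(*args, **kwargs)
        print("returned:", result)
        return result

    return wrapper


@traced
async def countdown(n: int) -> AsyncGenerator[int, None]:
    for i in range(n, 0, -1):
        yield i


async def main() -> None:
    async for i in countdown(3):
        print(i)


asyncio.run(main())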
@@ -608,16 +692,6 @@ def maybe_set_config_attribute(cls, input_data: dict[str, Any]) -> dict[str, Any]:
_DeploymentTypedDictValidator.validate_python(model_list)
return data

# SEE: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# > `none` means the model will not call any tool and instead generates a message.
# > `auto` means the model can pick between generating a message or calling one or more tools.
# > `required` means the model must call one or more tools.
NO_TOOL_CHOICE: ClassVar[str] = "none"
MODEL_CHOOSES_TOOL: ClassVar[str] = "auto"
TOOL_CHOICE_REQUIRED: ClassVar[str] = "required"
# None means we won't provide a tool_choice to the LLM API
UNSPECIFIED_TOOL_CHOICE: ClassVar[None] = None

def __getstate__(self):
# Prevent _router from being pickled, SEE: https://stackoverflow.com/a/2345953
state = super().__getstate__()
@@ -719,7 +793,7 @@ async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult]:
@rate_limited
async def acompletion_iter(
self, messages: list[Message], **kwargs
) -> AsyncIterable[LLMResult]:
) -> AsyncGenerator[LLMResult]:
# cast is necessary for LiteLLM typing bug: https://github.com/BerriAI/litellm/issues/7641
prompts = cast(
"list[litellm.types.llms.openai.AllMessageValues]",
@@ -728,7 +802,7 @@ async def acompletion_iter(
stream_options = {
"include_usage": True,
}
# NOTE: Specifically requesting reasoning for deepseek-r1 models

if kwargs.get("include_reasoning"):
stream_options["include_reasoning"] = True

@@ -740,43 +814,41 @@ async def acompletion_iter(
**kwargs,
)
start_clock = asyncio.get_running_loop().time()
outputs = []
accumulated_text = ""
logprobs = []
role = None
reasoning_content = []
used_model = None
first_token_time = None

async for completion in stream_completions:
if not used_model:
used_model = completion.model or self.name
choice = completion.choices[0]
delta = choice.delta
# logprobs can be None, or missing a content attribute,
# or a ChoiceLogprobs object with a NoneType/empty content attribute

if first_token_time is None and delta.content:
first_token_time = asyncio.get_running_loop().time()

if logprob_content := getattr(choice.logprobs, "content", None):
logprobs.append(logprob_content[0].logprob or 0)
outputs.append(delta.content or "")
role = delta.role or role
if hasattr(delta, "reasoning_content"):
reasoning_content.append(delta.reasoning_content or "")
text = "".join(outputs)
result = LLMResult(
model=used_model,
text=text,
prompt=messages,
messages=[Message(role=role, content=text)],
logprob=sum_logprobs(logprobs),
reasoning_content="".join(reasoning_content),
)

if text:
result.seconds_to_first_token = (
asyncio.get_running_loop().time() - start_clock
)
if hasattr(completion, "usage"):
result.prompt_count = completion.usage.prompt_tokens
result.completion_count = completion.usage.completion_tokens
if delta.content:
accumulated_text += delta.content
role = delta.role or role
if hasattr(delta, "reasoning_content"):
reasoning_content.append(delta.reasoning_content or "")

yield result
yield LLMResult(
model=used_model,
text=delta.content,
prompt=messages,
messages=[Message(role=role, content=accumulated_text)],
logprob=sum_logprobs(logprobs),
reasoning_content="".join(reasoning_content),
seconds_to_first_token=(
first_token_time - start_clock if first_token_time else None
),
)
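
Note the change in what each yielded LLMResult carries: text now holds only the new delta, while messages accumulates the full text so far. A small consumer sketch, under the same assumptions as the earlier call() example (hypothetical model class and name):

import asyncio

from lmi import LiteLLMModel  # assumed import path; adjust to your install


async def main() -> None:
    model = LiteLLMModel(name="gpt-4o-mini")  # hypothetical model name
    deltas: list[str] = []
    last = None
    async for chunk in await model.call("Name three primes.", stream=True):
        if chunk.text:  # per-chunk delta, not the accumulated text
            deltas.append(chunk.text)
        last = chunk
    print("".join(deltas))  # full completion rebuilt from deltas
    if last is not None and last.messages:
        print(last.messages[0].content)  # accumulated text carried on messages


asyncio.run(main())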

def count_tokens(self, text: str) -> int:
return litellm.token_counter(model=self.name, text=text)