Enabled stream on lmi #316
base: main
Changes from 4 commits
bd629d0
9d5778a
d028e48
129063d
baed6ae
e699df4
c4770f6
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,7 @@ | |
| import logging | ||
| from abc import ABC | ||
| from collections.abc import ( | ||
| AsyncIterable, | ||
| AsyncGenerator, | ||
| Awaitable, | ||
| Callable, | ||
| Coroutine, | ||
|
|
@@ -201,7 +201,7 @@ async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult | |
|
|
||
| async def acompletion_iter( | ||
| self, messages: list[Message], **kwargs | ||
| ) -> AsyncIterable[LLMResult]: | ||
| ) -> AsyncGenerator[LLMResult]: | ||
| """Return an async generator that yields completions. | ||
|
|
||
| Only the last tuple will be non-zero. | ||
|
|
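The hunk above swaps `AsyncIterable` for `AsyncGenerator` in the return annotation of `acompletion_iter`, which is implemented as an async generator. A minimal stdlib-only sketch of the distinction, independent of this repository's classes (the one-parameter `AsyncGenerator[...]` form matches the diff and relies on the newer default for the send type):

```python
import asyncio
from collections.abc import AsyncGenerator, AsyncIterable


# An `async def` containing `yield` produces an async generator, so
# AsyncGenerator is the more precise return annotation; every AsyncGenerator
# is still an AsyncIterable, so existing `async for` call sites keep working.
async def count_up(limit: int) -> AsyncGenerator[int]:
    for i in range(limit):
        yield i


async def main() -> None:
    gen = count_up(3)
    assert isinstance(gen, AsyncGenerator)
    assert isinstance(gen, AsyncIterable)  # subtype relationship holds
    async for value in gen:
        print(value)


if __name__ == "__main__":
    asyncio.run(main())
```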
@@ -224,18 +224,49 @@ def __str__(self) -> str: | |
| # None means we won't provide a tool_choice to the LLM API | ||
| UNSPECIFIED_TOOL_CHOICE: ClassVar[None] = None | ||
|
|
||
| @overload | ||
| async def call( | ||
| self, | ||
| messages: list[Message] | str, | ||
| callbacks: ( | ||
| Sequence[Callable[..., Any] | Callable[..., Awaitable]] | None | ||
| ) = ..., | ||
| name: str | None = ..., | ||
| output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = ..., | ||
| tools: list[Tool] | None = ..., | ||
| tool_choice: Tool | str | None = ..., | ||
| stream: bool = False, | ||
| **kwargs, | ||
| ) -> list[LLMResult]: ... | ||
|
|
||
| @overload | ||
| async def call( # type: ignore[overload-cannot-match] | ||
| self, | ||
| messages: list[Message] | str, | ||
| callbacks: ( | ||
| Sequence[Callable[..., Any] | Callable[..., Awaitable]] | None | ||
| ) = ..., | ||
| name: str | None = ..., | ||
| output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = ..., | ||
| tools: list[Tool] | None = ..., | ||
| tool_choice: Tool | str | None = ..., | ||
| stream: bool = True, | ||
| **kwargs, | ||
| ) -> AsyncGenerator[LLMResult]: ... | ||
|
|
||
| async def call( # noqa: C901, PLR0915 | ||
| self, | ||
| messages: list[Message], | ||
| messages: list[Message] | str, | ||
| callbacks: ( | ||
| Sequence[Callable[..., Any] | Callable[..., Awaitable]] | None | ||
| ) = None, | ||
| name: str | None = None, | ||
| output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = None, | ||
| tools: list[Tool] | None = None, | ||
| tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED, | ||
| stream: bool = False, | ||
| **kwargs, | ||
| ) -> list[LLMResult]: | ||
| ) -> list[LLMResult] | AsyncGenerator[LLMResult]: | ||
| """Call the LLM model with the given messages and configuration. | ||
|
|
||
| Args: | ||
|
|
@@ -245,22 +276,30 @@ async def call( # noqa: C901, PLR0915 | |
| output_type: The type of the output. | ||
| tools: A list of tools to use. | ||
| tool_choice: The tool choice to use. | ||
| stream: Whether to stream the response or return all results at once. | ||
| kwargs: Additional keyword arguments for the chat completion. | ||
|
|
||
| Returns: | ||
| A list of LLMResult objects containing the result of the call. | ||
| A list of LLMResult objects containing the result of the call when stream=False, | ||
| or an AsyncGenerator[LLMResult] when stream=True. | ||
|
|
||
| Raises: | ||
| ValueError: If the LLM type is unknown. | ||
| """ | ||
| if isinstance(messages, str): | ||
|
Review comment: If we want to do this here, can we remove it from …
||
| # convenience for single message | ||
| messages = [Message(content=messages)] | ||
| chat_kwargs = copy.deepcopy(kwargs) | ||
| # if using the config for an LLMModel, | ||
| # there may be a nested 'config' key | ||
| # that can't be used by chat | ||
| chat_kwargs.pop("config", None) | ||
| chat_kwargs.pop("stream", None) | ||
|
||
| n = chat_kwargs.get("n") or self.config.get("n", 1) | ||
| if n < 1: | ||
| raise ValueError("Number of completions (n) must be >= 1.") | ||
| if stream and n > 1: | ||
| raise ValueError("Number of completions (n) must be 1 when streaming.") | ||
| if "fallbacks" not in chat_kwargs and "fallbacks" in self.config: | ||
| chat_kwargs["fallbacks"] = self.config.get("fallbacks", []) | ||
|
|
||
|
|
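The `@overload` pair plus the new `stream` flag is what lets callers get a precise return type out of the single `call` implementation. Below is a self-contained sketch of that pattern with toy names (not this repo's API); it uses `Literal` overloads for clean narrowing, whereas the diff keeps plain `bool` defaults and silences the checker with `# type: ignore[overload-cannot-match]`:

```python
import asyncio
from collections.abc import AsyncGenerator
from typing import Literal, overload


@overload
async def fetch(stream: Literal[False] = False) -> list[str]: ...
@overload
async def fetch(stream: Literal[True]) -> AsyncGenerator[str]: ...


async def fetch(stream: bool = False) -> list[str] | AsyncGenerator[str]:
    async def generate() -> AsyncGenerator[str]:
        for chunk in ("hello", " ", "world"):
            yield chunk

    if stream:
        # Hand the generator back unconsumed; the caller iterates it.
        return generate()
    # Non-streaming path: drain the generator and return everything at once.
    return [chunk async for chunk in generate()]


async def main() -> None:
    print(await fetch())                 # list[str]
    stream = await fetch(stream=True)    # AsyncGenerator[str]
    async for chunk in stream:
        print(chunk, end="")
    print()


if __name__ == "__main__":
    asyncio.run(main())
```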
@@ -328,48 +367,83 @@ async def call( # noqa: C901, PLR0915 | |
| ) | ||
| for m in messages | ||
| ] | ||
| results: list[LLMResult] = [] | ||
|
|
||
| start_clock = asyncio.get_running_loop().time() | ||
| if callbacks is None: | ||
|
|
||
| # If not streaming, simply return the results | ||
| if not stream: | ||
| sync_callbacks = [ | ||
| f for f in (callbacks or []) if not is_coroutine_callable(f) | ||
| ] | ||
| async_callbacks = [f for f in (callbacks or []) if is_coroutine_callable(f)] | ||
| results = await self.acompletion(messages, **chat_kwargs) | ||
| else: | ||
| if tools: | ||
| raise NotImplementedError("Using tools with callbacks is not supported") | ||
| n = chat_kwargs.get("n") or self.config.get("n", 1) | ||
| if n > 1: | ||
| raise NotImplementedError( | ||
| "Multiple completions with callbacks is not supported" | ||
| for result in results: | ||
| text = cast("str", result.text) | ||
| await do_callbacks(async_callbacks, sync_callbacks, text, name) | ||
| usage = result.prompt_count, result.completion_count | ||
| if not sum(usage): | ||
| result.completion_count = self.count_tokens(text) | ||
| result.seconds_to_last_token = ( | ||
| asyncio.get_running_loop().time() - start_clock | ||
| ) | ||
| sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)] | ||
| async_callbacks = [f for f in callbacks if is_coroutine_callable(f)] | ||
| stream_results = await self.acompletion_iter(messages, **chat_kwargs) | ||
| text_result = [] | ||
| async for result in stream_results: | ||
| if result.text: | ||
| if result.seconds_to_first_token == 0: | ||
| result.seconds_to_first_token = ( | ||
| asyncio.get_running_loop().time() - start_clock | ||
| ) | ||
| text_result.append(result.text) | ||
| await do_callbacks( | ||
| async_callbacks, sync_callbacks, result.text, name | ||
| result.name = name | ||
| if self.llm_result_callback: | ||
| possibly_awaitable_result = self.llm_result_callback(result) | ||
| if isawaitable(possibly_awaitable_result): | ||
| await possibly_awaitable_result | ||
| return results | ||
|
|
||
| # If streaming, return an AsyncGenerator[LLMResult] | ||
| if tools: | ||
| raise NotImplementedError("Using tools with streaming is not supported") | ||
| if callbacks: | ||
| raise NotImplementedError("Using callbacks with streaming is not supported") | ||
|
Review comment on lines +396 to +399: Can you reword to say "not yet supported"? Also, can you add a comment to each of these saying why they're not supported?
||
|
|
||
| async def process_stream() -> AsyncGenerator[LLMResult]: | ||
| async_iterable = await self.acompletion_iter(messages, **chat_kwargs) | ||
| async for result in async_iterable: | ||
| usage = result.prompt_count, result.completion_count | ||
| if not sum(usage): | ||
| result.completion_count = self.count_tokens( | ||
| cast("str", result.text) | ||
| ) | ||
| results.append(result) | ||
|
|
||
| for result in results: | ||
| usage = result.prompt_count, result.completion_count | ||
| if not sum(usage): | ||
| result.completion_count = self.count_tokens(cast("str", result.text)) | ||
| result.seconds_to_last_token = ( | ||
| asyncio.get_running_loop().time() - start_clock | ||
| ) | ||
| result.name = name | ||
| if self.llm_result_callback: | ||
| possibly_awaitable_result = self.llm_result_callback(result) | ||
| if isawaitable(possibly_awaitable_result): | ||
| await possibly_awaitable_result | ||
| return results | ||
| result.seconds_to_last_token = ( | ||
| asyncio.get_running_loop().time() - start_clock | ||
| ) | ||
| result.name = name | ||
| yield result | ||
|
|
||
| return process_stream() | ||
|
|
||
| @overload | ||
| async def call_single( | ||
| self, | ||
| messages: list[Message] | str, | ||
| callbacks: ( | ||
| Sequence[Callable[..., Any] | Callable[..., Awaitable]] | None | ||
| ) = ..., | ||
| name: str | None = ..., | ||
| output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = ..., | ||
| tools: list[Tool] | None = ..., | ||
| tool_choice: Tool | str | None = ..., | ||
| stream: bool = False, | ||
| **kwargs, | ||
| ) -> LLMResult: ... | ||
|
|
||
| @overload | ||
| async def call_single( # type: ignore[overload-cannot-match] | ||
|
||
| self, | ||
| messages: list[Message] | str, | ||
| callbacks: ( | ||
| Sequence[Callable[..., Any] | Callable[..., Awaitable]] | None | ||
| ) = ..., | ||
| name: str | None = ..., | ||
| output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = ..., | ||
| tools: list[Tool] | None = ..., | ||
| tool_choice: Tool | str | None = ..., | ||
| stream: bool = True, | ||
| **kwargs, | ||
| ) -> AsyncGenerator[LLMResult]: ... | ||
|
|
||
| async def call_single( | ||
| self, | ||
|
|
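With `stream=True`, `call` now builds an inner `process_stream()` async generator and returns the generator object to the caller without consuming it. A stdlib-only sketch of that wrapper pattern (the `Chunk` dataclass and the timing fields are stand-ins, not this repo's `LLMResult`):

```python
import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass


@dataclass
class Chunk:  # stand-in for a per-chunk result object
    text: str
    seconds_to_last_token: float = 0.0
    name: str | None = None


async def raw_chunks() -> AsyncGenerator[Chunk]:
    # Stand-in for the underlying completion stream.
    for word in ("streamed", " ", "response"):
        await asyncio.sleep(0)
        yield Chunk(text=word)


async def stream_call(name: str | None = None) -> AsyncGenerator[Chunk]:
    start_clock = asyncio.get_running_loop().time()

    async def process_stream() -> AsyncGenerator[Chunk]:
        # Post-process every chunk (timing, naming) before yielding it onward.
        async for chunk in raw_chunks():
            chunk.seconds_to_last_token = (
                asyncio.get_running_loop().time() - start_clock
            )
            chunk.name = name
            yield chunk

    return process_stream()  # return the generator itself, unconsumed


async def main() -> None:
    async for chunk in await stream_call(name="demo"):
        print(chunk)


if __name__ == "__main__":
    asyncio.run(main())
```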
@@ -381,21 +455,31 @@ async def call_single( | |
| output_type: type[BaseModel] | TypeAdapter | JSONSchema | None = None, | ||
| tools: list[Tool] | None = None, | ||
| tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED, | ||
| stream: bool = False, | ||
| **kwargs, | ||
| ) -> LLMResult: | ||
| ) -> LLMResult | AsyncGenerator[LLMResult]: | ||
| if isinstance(messages, str): | ||
| # convenience for single message | ||
| messages = [Message(content=messages)] | ||
| kwargs = {**kwargs, "n": 1} | ||
| results = await self.call( | ||
| messages, | ||
| callbacks, | ||
| name, | ||
| output_type, | ||
| tools, | ||
| tool_choice, | ||
| n=1, | ||
| stream, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| if stream: | ||
| if not isinstance(results, AsyncGenerator): | ||
| raise TypeError("Expected AsyncGenerator of results when streaming") | ||
| return results | ||
|
|
||
| if not isinstance(results, list): | ||
| raise TypeError("Expected list of results when not streaming") | ||
| if len(results) != 1: | ||
| # Can be caused by issues like https://github.com/BerriAI/litellm/issues/12298 | ||
| raise ValueError(f"Got {len(results)} results when expecting just one.") | ||
|
|
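Because the combined signature advertises `LLMResult | AsyncGenerator[LLMResult]`, `call_single` narrows at runtime with `isinstance` checks before returning. A minimal illustration of that narrowing against `collections.abc.AsyncGenerator`, using toy types rather than the repo's:

```python
import asyncio
from collections.abc import AsyncGenerator


async def numbers() -> AsyncGenerator[int]:
    for i in range(3):
        yield i


async def consume(result: list[int] | AsyncGenerator[int]) -> list[int]:
    # Native async generators satisfy isinstance(..., AsyncGenerator),
    # which is how the union can be narrowed at runtime.
    if isinstance(result, AsyncGenerator):
        return [i async for i in result]
    return result


async def main() -> None:
    print(await consume(numbers()))      # streamed path
    print(await consume([10, 20, 30]))   # list path


if __name__ == "__main__":
    asyncio.run(main())
```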
@@ -413,8 +497,8 @@ def rate_limited( | |
|
|
||
| @overload | ||
| def rate_limited( | ||
| func: Callable[P, AsyncIterable[LLMResult]], | ||
| ) -> Callable[P, Coroutine[Any, Any, AsyncIterable[LLMResult]]]: ... | ||
| func: Callable[P, AsyncGenerator[LLMResult]], | ||
| ) -> Callable[P, Coroutine[Any, Any, AsyncGenerator[LLMResult]]]: ... | ||
|
|
||
|
|
||
| def rate_limited(func): | ||
|
|
@@ -440,7 +524,7 @@ async def wrapper(self, *args, **kwargs): | |
| # portion before yielding | ||
| if isasyncgenfunction(func): | ||
|
|
||
| async def rate_limited_generator() -> AsyncIterable[LLMResult]: | ||
| async def rate_limited_generator() -> AsyncGenerator[LLMResult]: | ||
| async for item in func(self, *args, **kwargs): | ||
| token_count = 0 | ||
| if isinstance(item, LLMResult): | ||
|
|
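These decorators already special-case async-generator functions via `isasyncgenfunction`; the annotation change just makes them advertise `AsyncGenerator[LLMResult]`. A stdlib-only sketch of that decorator shape, with `asyncio.Semaphore` standing in for the real rate limiter (the repo's wrapper returns the generator from a coroutine; this sketch inlines it for brevity):

```python
import asyncio
import functools
from collections.abc import AsyncGenerator
from inspect import isasyncgenfunction

_limiter = asyncio.Semaphore(2)  # stand-in for the real token/request limiter


def rate_limited(func):
    """Apply the limiter per awaited call, or per yielded item for async generators."""
    if isasyncgenfunction(func):

        @functools.wraps(func)
        async def gen_wrapper(*args, **kwargs) -> AsyncGenerator:
            async for item in func(*args, **kwargs):
                async with _limiter:  # acquire before handing out each item
                    yield item

        return gen_wrapper

    @functools.wraps(func)
    async def coro_wrapper(*args, **kwargs):
        async with _limiter:
            return await func(*args, **kwargs)

    return coro_wrapper


@rate_limited
async def stream_words() -> AsyncGenerator[str]:
    for word in ("a", "b", "c"):
        yield word


async def main() -> None:
    async for word in stream_words():
        print(word)


if __name__ == "__main__":
    asyncio.run(main())
```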
@@ -469,8 +553,8 @@ def request_limited( | |
|
|
||
| @overload | ||
| def request_limited( | ||
| func: Callable[P, Coroutine[Any, Any, AsyncIterable[LLMResult]]], | ||
| ) -> Callable[P, Coroutine[Any, Any, AsyncIterable[LLMResult]]]: ... | ||
| func: Callable[P, Coroutine[Any, Any, AsyncGenerator[LLMResult]]], | ||
| ) -> Callable[P, Coroutine[Any, Any, AsyncGenerator[LLMResult]]]: ... | ||
|
|
||
|
|
||
| def request_limited(func): | ||
|
|
@@ -487,7 +571,7 @@ async def wrapper(self, *args, **kwargs): | |
|
|
||
| if isasyncgenfunction(func): | ||
|
|
||
| async def request_limited_generator() -> AsyncIterable[LLMResult]: | ||
| async def request_limited_generator() -> AsyncGenerator[LLMResult]: | ||
| first_item = True | ||
| async for item in func(self, *args, **kwargs): | ||
| # Skip rate limit check for first item since we already checked at generator start | ||
|
|
@@ -608,16 +692,6 @@ def maybe_set_config_attribute(cls, input_data: dict[str, Any]) -> dict[str, Any | |
| _DeploymentTypedDictValidator.validate_python(model_list) | ||
| return data | ||
|
|
||
| # SEE: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice | ||
| # > `none` means the model will not call any tool and instead generates a message. | ||
| # > `auto` means the model can pick between generating a message or calling one or more tools. | ||
| # > `required` means the model must call one or more tools. | ||
| NO_TOOL_CHOICE: ClassVar[str] = "none" | ||
| MODEL_CHOOSES_TOOL: ClassVar[str] = "auto" | ||
| TOOL_CHOICE_REQUIRED: ClassVar[str] = "required" | ||
| # None means we won't provide a tool_choice to the LLM API | ||
| UNSPECIFIED_TOOL_CHOICE: ClassVar[None] = None | ||
|
|
||
| def __getstate__(self): | ||
| # Prevent _router from being pickled, SEE: https://stackoverflow.com/a/2345953 | ||
| state = super().__getstate__() | ||
|
|
@@ -719,7 +793,7 @@ async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult | |
| @rate_limited | ||
| async def acompletion_iter( | ||
| self, messages: list[Message], **kwargs | ||
| ) -> AsyncIterable[LLMResult]: | ||
| ) -> AsyncGenerator[LLMResult]: | ||
| # cast is necessary for LiteLLM typing bug: https://github.com/BerriAI/litellm/issues/7641 | ||
| prompts = cast( | ||
| "list[litellm.types.llms.openai.AllMessageValues]", | ||
|
|
@@ -728,7 +802,7 @@ async def acompletion_iter( | |
| stream_options = { | ||
| "include_usage": True, | ||
| } | ||
| # NOTE: Specifically requesting reasoning for deepseek-r1 models | ||
|
|
||
| if kwargs.get("include_reasoning"): | ||
| stream_options["include_reasoning"] = True | ||
|
|
||
|
|
@@ -740,43 +814,41 @@ async def acompletion_iter( | |
| **kwargs, | ||
| ) | ||
| start_clock = asyncio.get_running_loop().time() | ||
| outputs = [] | ||
| accumulated_text = "" | ||
| logprobs = [] | ||
| role = None | ||
| reasoning_content = [] | ||
| used_model = None | ||
| first_token_time = None | ||
|
|
||
| async for completion in stream_completions: | ||
| if not used_model: | ||
| used_model = completion.model or self.name | ||
| choice = completion.choices[0] | ||
| delta = choice.delta | ||
| # logprobs can be None, or missing a content attribute, | ||
| # or a ChoiceLogprobs object with a NoneType/empty content attribute | ||
|
|
||
| if first_token_time is None and delta.content: | ||
| first_token_time = asyncio.get_running_loop().time() | ||
|
|
||
| if logprob_content := getattr(choice.logprobs, "content", None): | ||
| logprobs.append(logprob_content[0].logprob or 0) | ||
| outputs.append(delta.content or "") | ||
| role = delta.role or role | ||
| if hasattr(delta, "reasoning_content"): | ||
| reasoning_content.append(delta.reasoning_content or "") | ||
| text = "".join(outputs) | ||
| result = LLMResult( | ||
| model=used_model, | ||
| text=text, | ||
| prompt=messages, | ||
| messages=[Message(role=role, content=text)], | ||
| logprob=sum_logprobs(logprobs), | ||
| reasoning_content="".join(reasoning_content), | ||
| ) | ||
|
|
||
| if text: | ||
| result.seconds_to_first_token = ( | ||
| asyncio.get_running_loop().time() - start_clock | ||
| ) | ||
| if hasattr(completion, "usage"): | ||
| result.prompt_count = completion.usage.prompt_tokens | ||
| result.completion_count = completion.usage.completion_tokens | ||
| if delta.content: | ||
| accumulated_text += delta.content | ||
| role = delta.role or role | ||
| if hasattr(delta, "reasoning_content"): | ||
| reasoning_content.append(delta.reasoning_content or "") | ||
|
|
||
| yield result | ||
| yield LLMResult( | ||
| model=used_model, | ||
| text=delta.content, | ||
|
||
| prompt=messages, | ||
| messages=[Message(role=role, content=accumulated_text)], | ||
| logprob=sum_logprobs(logprobs), | ||
| reasoning_content="".join(reasoning_content), | ||
| seconds_to_first_token=( | ||
| first_token_time - start_clock if first_token_time else None | ||
| ), | ||
|
||
| ) | ||
|
|
||
|
||
| def count_tokens(self, text: str) -> int: | ||
| return litellm.token_counter(model=self.name, text=text) | ||
|
|
||
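With this rewrite, `acompletion_iter` yields one `LLMResult` per received delta (carrying only that chunk's text) instead of a single aggregated result at the end, so consumers do the accumulation themselves. A stdlib-only sketch of the consumer side under that assumption (the `DeltaResult` dataclass is a stand-in, not the repo's `LLMResult`):

```python
import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass


@dataclass
class DeltaResult:  # stand-in for the per-chunk result yielded above
    text: str | None
    completion_count: int = 0


async def fake_stream() -> AsyncGenerator[DeltaResult]:
    for word in ("Stream", "ing ", "works"):
        yield DeltaResult(text=word, completion_count=1)


async def main() -> None:
    accumulated_text, total_completion_tokens = "", 0
    async for delta in fake_stream():
        if delta.text:
            print(delta.text, end="", flush=True)  # live output as chunks arrive
            accumulated_text += delta.text
        total_completion_tokens += delta.completion_count
    print(f"\n{total_completion_tokens} completion tokens, text={accumulated_text!r}")


if __name__ == "__main__":
    asyncio.run(main())
```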
Review comment (truncated): Two comments: … For example: …