diff --git a/nemoguardrails/integrations/langchain/runnable_rails.py b/nemoguardrails/integrations/langchain/runnable_rails.py index 54b68d383..74ad50937 100644 --- a/nemoguardrails/integrations/langchain/runnable_rails.py +++ b/nemoguardrails/integrations/langchain/runnable_rails.py @@ -16,9 +16,11 @@ from __future__ import annotations import logging -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast -from langchain_core.language_models import BaseLanguageModel +from langchain_core.language_models import BaseChatModel, BaseLanguageModel +from langchain_core.language_models.llms import BaseLLM +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.prompt_values import ChatPromptValue, StringPromptValue from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.runnables.utils import Input, Output, gather_with_concurrency @@ -33,7 +35,7 @@ message_to_dict, ) from nemoguardrails.integrations.langchain.utils import async_wrap -from nemoguardrails.rails.llm.options import GenerationOptions +from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse logger = logging.getLogger(__name__) @@ -62,7 +64,7 @@ class RunnableRails(Runnable[Input, Output]): def __init__( self, config: RailsConfig, - llm: Optional[BaseLanguageModel] = None, + llm: Optional[Union[BaseLLM, BaseChatModel]] = None, tools: Optional[List[Tool]] = None, passthrough: bool = True, runnable: Optional[Runnable] = None, @@ -110,7 +112,7 @@ def __init__( if self.passthrough_runnable: self._init_passthrough_fn() - def _init_passthrough_fn(self): + def _init_passthrough_fn(self) -> None: """Initialize the passthrough function for the LLM rails instance.""" async def passthrough_fn(context: dict, events: List[dict]): @@ -134,7 +136,8 @@ async def passthrough_fn(context: dict, events: List[dict]): return text, _output - self.rails.llm_generation_actions.passthrough_fn = passthrough_fn + # Dynamically assign passthrough_fn to avoid type checker issues + setattr(self.rails.llm_generation_actions, "passthrough_fn", passthrough_fn) def __or__( self, other: Union[BaseLanguageModel, Runnable[Any, Any]] @@ -687,6 +690,9 @@ def _full_rails_invoke( res = self.rails.generate( messages=input_messages, options=GenerationOptions(output_vars=True) ) + # When using output_vars=True, rails.generate returns a GenerationResponse + if not isinstance(res, GenerationResponse): + raise Exception(f"Expected GenerationResponse, got {type(res)}") context = res.output_data result = res.response