diff --git a/biomni/llm.py b/biomni/llm.py index 61747b685..d28e6d657 100644 --- a/biomni/llm.py +++ b/biomni/llm.py @@ -6,7 +6,9 @@ if TYPE_CHECKING: from biomni.config import BiomniConfig -SourceType = Literal["OpenAI", "AzureOpenAI", "Anthropic", "Ollama", "Gemini", "Bedrock", "Groq", "Custom"] +SourceType = Literal[ + "OpenAI", "AzureOpenAI", "Anthropic", "Ollama", "Gemini", "Bedrock", "Groq", "HuggingFace", "Custom" +] ALLOWED_SOURCES: set[str] = set(SourceType.__args__) @@ -26,7 +28,7 @@ def get_llm( model (str): The model name to use temperature (float): Temperature setting for generation stop_sequences (list): Sequences that will stop generation - source (str): Source provider: "OpenAI", "AzureOpenAI", "Anthropic", "Ollama", "Gemini", "Bedrock", or "Custom" + source (str): Source provider: "OpenAI", "AzureOpenAI", "Anthropic", "Ollama", "Gemini", "Bedrock", "HuggingFace", or "Custom" If None, will attempt to auto-detect from model name base_url (str): The base URL for custom model serving (e.g., "http://localhost:8000/v1"), default is None api_key (str): The API key for the custom llm @@ -196,7 +198,21 @@ def get_llm( stop_sequences=stop_sequences, region_name=os.getenv("AWS_REGION", "us-east-1"), ) - + elif source == "HuggingFace": + try: + from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint + except ImportError: + raise ImportError( # noqa: B904 + "langchain-huggingface package is required for HuggingFace models. Install with: pip install langchain-huggingface" + ) + return ChatHuggingFace( + llm=HuggingFaceEndpoint( + repo_id=model, + temperature=temperature, + stop_sequences=stop_sequences, + huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_KEY"), + ) + ) elif source == "Custom": try: from langchain_openai import ChatOpenAI diff --git a/biomni_env/new_software_v007.sh b/biomni_env/new_software_v007.sh index 4cdcb4ae0..6bdae5c2c 100644 --- a/biomni_env/new_software_v007.sh +++ b/biomni_env/new_software_v007.sh @@ -8,4 +8,5 @@ pip install fair-esm pip install nnunet nibabel nilearn pip install mi-googlesearch-python pip install git+https://github.com/pylabrobot/pylabrobot.git +pip install langchain-huggingface conda install weasyprint