diff --git a/libs/community/langchain_community/tools/__init__.py b/libs/community/langchain_community/tools/__init__.py
index de486cfbb..eb90e148f 100644
--- a/libs/community/langchain_community/tools/__init__.py
+++ b/libs/community/langchain_community/tools/__init__.py
@@ -216,6 +216,10 @@
     from langchain_community.tools.openweathermap.tool import (
         OpenWeatherMapQueryRun,
     )
+    from langchain_community.tools.parallel_search.tool import (
+        ParallelSearchResults,
+        ParallelSearchRun,
+    )
     from langchain_community.tools.playwright import (
         ClickTool,
         CurrentWebPageTool,
@@ -446,6 +450,8 @@
     "O365SendMessage",
     "OpenAPISpec",
     "OpenWeatherMapQueryRun",
+    "ParallelSearchResults",
+    "ParallelSearchRun",
     "PolygonAggregates",
     "PolygonFinancials",
     "PolygonLastQuote",
@@ -600,6 +606,8 @@
     "O365SendMessage": "langchain_community.tools.office365.send_message",
     "OpenAPISpec": "langchain_community.tools.openapi.utils.openapi_utils",
     "OpenWeatherMapQueryRun": "langchain_community.tools.openweathermap.tool",
+    "ParallelSearchResults": "langchain_community.tools.parallel_search.tool",
+    "ParallelSearchRun": "langchain_community.tools.parallel_search.tool",
     "PolygonAggregates": "langchain_community.tools.polygon.aggregates",
     "PolygonFinancials": "langchain_community.tools.polygon.financials",
     "PolygonLastQuote": "langchain_community.tools.polygon.last_quote",
diff --git a/libs/community/langchain_community/tools/parallel_search/__init__.py b/libs/community/langchain_community/tools/parallel_search/__init__.py
new file mode 100644
index 000000000..54d8c39f8
--- /dev/null
+++ b/libs/community/langchain_community/tools/parallel_search/__init__.py
@@ -0,0 +1,11 @@
+"""Parallel Search tool."""
+
+from langchain_community.tools.parallel_search.tool import (
+    ParallelSearchResults,
+    ParallelSearchRun,
+)
+
+__all__ = [
+    "ParallelSearchRun",
+    "ParallelSearchResults",
+]
diff --git a/libs/community/langchain_community/tools/parallel_search/tool.py b/libs/community/langchain_community/tools/parallel_search/tool.py
new file mode 100644
index 000000000..9727cb5e8
--- /dev/null
+++ b/libs/community/langchain_community/tools/parallel_search/tool.py
@@ -0,0 +1,261 @@
+"""Tool for the Parallel Search API."""
+
+from typing import Any, Dict, List, Literal, Optional, Tuple, Type
+
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForToolRun,
+    CallbackManagerForToolRun,
+)
+from langchain_core.tools import BaseTool
+from pydantic import BaseModel, Field
+
+from langchain_community.utilities.parallel_search import ParallelSearchAPIWrapper
+
+
+class ParallelSearchInput(BaseModel):
+    """Input for the Parallel Search tool."""
+
+    objective: Optional[str] = Field(
+        default=None,
+        description="Natural-language description of the web research goal. "
+        "Maximum 5000 characters.",
+    )
+    search_queries: Optional[List[str]] = Field(
+        default=None,
+        description="Optional list of search queries to supplement the objective. "
+        "Maximum 200 characters per query. "
+        "At least one of 'objective' or 'search_queries' must be provided.",
+    )
+
+
+class ParallelSearchRun(BaseTool):
+    """Tool that queries the Parallel Search API and gets back text results.
+
+    Setup:
+        Install ``langchain-community`` and set environment variable
+        ``PARALLEL_API_KEY``.
+
+        .. code-block:: bash
+
+            pip install -U langchain-community
+            export PARALLEL_API_KEY="your-api-key"
+
+    Instantiate:
+
+        .. code-block:: python
+
+            from langchain_community.tools import ParallelSearchRun
+
+            tool = ParallelSearchRun(
+                processor="base",
+                max_results=10,
+                max_chars_per_result=6000,
+            )
+
+    Invoke directly with args:
+
+        .. code-block:: python
+
+            tool.invoke({
+                'objective': 'When was the United Nations established?',
+                'search_queries': [
+                    'Founding year UN', 'Year of founding United Nations'
+                ]
+            })
+    """
+
+    name: str = "parallel_search"
+    description: str = (
+        "A web search API optimized for AI agents. "
+        "Useful for when you need to answer questions about current events or "
+        "find information on the web. "
+        "Input should be an objective (natural language description) and/or "
+        "search queries (keywords)."
+    )
+    args_schema: Type[BaseModel] = ParallelSearchInput
+
+    processor: str = "base"
+    """The processor to use ("base" or "pro"). Defaults to "base"."""
+    max_results: int = 10
+    """Maximum number of search results to return (1-20). Defaults to 10."""
+    max_chars_per_result: int = 6000
+    """Maximum characters per search result (100-30000). Defaults to 6000."""
+    source_policy: Optional[Dict[str, Any]] = None
+    """Optional source policy to include/exclude domains."""
+    api_wrapper: ParallelSearchAPIWrapper = Field(
+        default_factory=ParallelSearchAPIWrapper
+    )
+
+    def __init__(self, **kwargs: Any) -> None:
+        # Build api_wrapper from an explicit key; pop it so BaseTool never sees it.
+        if "parallel_api_key" in kwargs:
+            kwargs["api_wrapper"] = ParallelSearchAPIWrapper(
+                parallel_api_key=kwargs.pop("parallel_api_key")
+            )
+        super().__init__(**kwargs)
+
+    def _run(
+        self,
+        objective: Optional[str] = None,
+        search_queries: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+    ) -> str:
+        """Use the tool."""
+        try:
+            results = self.api_wrapper.results(
+                objective=objective,
+                search_queries=search_queries,
+                processor=self.processor,
+                max_results=self.max_results,
+                max_chars_per_result=self.max_chars_per_result,
+                source_policy=self.source_policy,
+            )
+            return self._format_results(results)
+        except Exception as e:
+            return repr(e)
+
+    async def _arun(
+        self,
+        objective: Optional[str] = None,
+        search_queries: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+    ) -> str:
+        """Use the tool asynchronously."""
+        try:
+            results = await self.api_wrapper.results_async(
+                objective=objective,
+                search_queries=search_queries,
+                processor=self.processor,
+                max_results=self.max_results,
+                max_chars_per_result=self.max_chars_per_result,
+                source_policy=self.source_policy,
+            )
+            return self._format_results(results)
+        except Exception as e:
+            return repr(e)
+
+    def _format_results(self, results: List[Dict[str, Any]]) -> str:
+        """Format search results as a string."""
+        if not results:
+            return "No results found."
+
+        formatted = []
+        for i, result in enumerate(results, 1):
+            title = result.get("title", "No title")
+            url = result.get("url", "")
+            excerpts = result.get("excerpts", [])
+            excerpt_text = "\n".join(excerpts) if isinstance(excerpts, list) else ""
+
+            formatted.append(f"Result {i}: {title}\nURL: {url}")
+            if excerpt_text:
+                formatted.append(f"Content: {excerpt_text}")
+            formatted.append("")
+
+        return "\n".join(formatted)
+
+
+class ParallelSearchResults(BaseTool):
+    """Tool that queries the Parallel Search API and gets back structured results.
+
+    Setup:
+        Install ``langchain-community`` and set environment variable
+        ``PARALLEL_API_KEY``.
+
+        .. code-block:: bash
+
+            pip install -U langchain-community
+            export PARALLEL_API_KEY="your-api-key"
+
+    Instantiate:
+
+        .. code-block:: python
+
+            from langchain_community.tools import ParallelSearchResults
+
+            tool = ParallelSearchResults(
+                processor="base",
+                max_results=10,
+                max_chars_per_result=6000,
+            )
+
+    Invoke directly with args:
+
+        .. code-block:: python
+
+            tool.invoke({
+                'objective': 'When was the United Nations established?',
+                'search_queries': ['Founding year UN']
+            })
+    """
+
+    name: str = "parallel_search_results_json"
+    description: str = (
+        "A web search API optimized for AI agents. "
+        "Useful for when you need to answer questions about current events or "
+        "find information on the web. "
+        "Input should be an objective (natural language description) and/or "
+        "search queries (keywords). "
+        "Output is a structured list of search results with URLs, titles, and excerpts."
+    )
+    args_schema: Type[BaseModel] = ParallelSearchInput
+
+    processor: str = "base"
+    """The processor to use ("base" or "pro"). Defaults to "base"."""
+    max_results: int = 10
+    """Maximum number of search results to return (1-20). Defaults to 10."""
+    max_chars_per_result: int = 6000
+    """Maximum characters per search result (100-30000). Defaults to 6000."""
+    source_policy: Optional[Dict[str, Any]] = None
+    """Optional source policy to include/exclude domains."""
+    api_wrapper: ParallelSearchAPIWrapper = Field(
+        default_factory=ParallelSearchAPIWrapper
+    )
+    response_format: Literal["content_and_artifact"] = "content_and_artifact"
+
+    def __init__(self, **kwargs: Any) -> None:
+        # Build api_wrapper from an explicit key; pop it so BaseTool never sees it.
+        if "parallel_api_key" in kwargs:
+            kwargs["api_wrapper"] = ParallelSearchAPIWrapper(
+                parallel_api_key=kwargs.pop("parallel_api_key")
+            )
+        super().__init__(**kwargs)
+
+    def _run(
+        self,
+        objective: Optional[str] = None,
+        search_queries: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+    ) -> Tuple[str, List[Dict[str, Any]]]:
+        """Use the tool."""
+        try:
+            results = self.api_wrapper.results(
+                objective=objective,
+                search_queries=search_queries,
+                processor=self.processor,
+                max_results=self.max_results,
+                max_chars_per_result=self.max_chars_per_result,
+                source_policy=self.source_policy,
+            )
+            return str(results), results
+        except Exception as e:
+            return repr(e), []
+
+    async def _arun(
+        self,
+        objective: Optional[str] = None,
+        search_queries: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
+    ) -> Tuple[str, List[Dict[str, Any]]]:
+        """Use the tool asynchronously."""
+        try:
+            results = await self.api_wrapper.results_async(
+                objective=objective,
+                search_queries=search_queries,
+                processor=self.processor,
+                max_results=self.max_results,
+                max_chars_per_result=self.max_chars_per_result,
+                source_policy=self.source_policy,
+            )
+            return str(results), results
+        except Exception as e:
+            return repr(e), []
diff --git a/libs/community/langchain_community/utilities/__init__.py b/libs/community/langchain_community/utilities/__init__.py
index 0174d37c0..f10b00393 100644
--- a/libs/community/langchain_community/utilities/__init__.py
+++ b/libs/community/langchain_community/utilities/__init__.py
@@ -113,6 +116,9 @@
     from langchain_community.utilities.outline import (
         OutlineAPIWrapper,
     )
+    from langchain_community.utilities.parallel_search import (
+        ParallelSearchAPIWrapper,
+    )
     from langchain_community.utilities.passio_nutrition_ai import (
         NutritionAIAPI,
     )
@@ -185,6 +188,7 @@
     "BingSearchAPIWrapper",
     "BraveSearchWrapper",
     "DataheraldAPIWrapper",
+    "ParallelSearchAPIWrapper",
     "DriaAPIWrapper",
     "DuckDuckGoSearchAPIWrapper",
     "GoldenQueryAPIWrapper",
@@ -249,6 +253,7 @@
     "BingSearchAPIWrapper": "langchain_community.utilities.bing_search",
     "BraveSearchWrapper": "langchain_community.utilities.brave_search",
     "DataheraldAPIWrapper": "langchain_community.utilities.dataherald",
+    "ParallelSearchAPIWrapper": "langchain_community.utilities.parallel_search",
     "DriaAPIWrapper": "langchain_community.utilities.dria_index",
     "DuckDuckGoSearchAPIWrapper": "langchain_community.utilities.duckduckgo_search",
     "GoldenQueryAPIWrapper": "langchain_community.utilities.golden_query",
diff --git a/libs/community/langchain_community/utilities/parallel_search.py b/libs/community/langchain_community/utilities/parallel_search.py
new file mode 100644
index 000000000..0ebc6ddc6
--- /dev/null
+++ b/libs/community/langchain_community/utilities/parallel_search.py
@@ -0,0 +1,238 @@
+"""Util that calls Parallel Search API.
+
+In order to set this up, follow instructions at:
+https://docs.parallel.ai/search/search-quickstart
+"""
+
+from typing import Any, Dict, List, Optional
+
+import aiohttp
+import requests
+from langchain_core.utils import get_from_dict_or_env
+from pydantic import BaseModel, ConfigDict, SecretStr, model_validator
+
+PARALLEL_API_URL = "https://api.parallel.ai"
+
+
+class ParallelSearchAPIWrapper(BaseModel):
+    """Wrapper for Parallel Search API."""
+
+    parallel_api_key: SecretStr
+    """The API key to use for the Parallel search engine."""
+    base_url: str = PARALLEL_API_URL
+    """The base URL for the Parallel API."""
+    api_version: str = "v1beta"
+    """The API version to use."""
+
+    model_config = ConfigDict(
+        extra="forbid",
+    )
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_environment(cls, values: Dict) -> Any:
+        """Validate that api key exists in environment."""
+        parallel_api_key = get_from_dict_or_env(
+            values, "parallel_api_key", "PARALLEL_API_KEY"
+        )
+        values["parallel_api_key"] = parallel_api_key
+        return values
+
+    def raw_results(
+        self,
+        objective: Optional[str] = None,
+        search_queries: Optional[List[str]] = None,
+        processor: str = "base",
+        max_results: int = 10,
+        max_chars_per_result: int = 6000,
+        source_policy: Optional[Dict[str, Any]] = None,
+    ) -> Dict[str, Any]:
+        """Get raw results from the Parallel Search API.
+
+        Args:
+            objective: Natural-language description of the web research goal.
+            search_queries: Optional list of search queries to supplement the objective.
+            processor: The processor to use ("base" or "pro"). Defaults to "base".
+            max_results: Maximum number of search results to return (1-20).
+                Defaults to 10.
+            max_chars_per_result: Maximum characters per search result (100-30000).
+                Defaults to 6000.
+            source_policy: Optional source policy to include/exclude domains.
+
+        Returns:
+            Raw API response as a dictionary.
+        """
+        if not objective and not search_queries:
+            raise ValueError("Either 'objective' or 'search_queries' must be provided.")
+
+        payload: Dict[str, Any] = {
+            "processor": processor,
+            "max_results": max_results,
+            "excerpts": {
+                "max_chars_per_result": max_chars_per_result,
+            },
+        }
+
+        if objective:
+            payload["objective"] = objective
+        if search_queries:
+            payload["search_queries"] = search_queries
+        if source_policy:
+            payload["source_policy"] = source_policy
+
+        headers = {
+            "Content-Type": "application/json",
+            "x-api-key": self.parallel_api_key.get_secret_value(),
+            "parallel-beta": "search-extract-2025-10-10",
+        }
+
+        url = f"{self.base_url}/{self.api_version}/search"
+        response = requests.post(url, json=payload, headers=headers, timeout=60)
+        response.raise_for_status()
+        return response.json()
+
+    def results(
+        self,
+        objective: Optional[str] = None,
+        search_queries: Optional[List[str]] = None,
+        processor: str = "base",
+        max_results: int = 10,
+        max_chars_per_result: int = 6000,
+        source_policy: Optional[Dict[str, Any]] = None,
+    ) -> List[Dict[str, Any]]:
+        """Run query through Parallel Search and return cleaned results.
+
+        Args:
+            objective: Natural-language description of the web research goal.
+            search_queries: Optional list of search queries to supplement the objective.
+            processor: The processor to use ("base" or "pro"). Defaults to "base".
+            max_results: Maximum number of search results to return (1-20).
+                Defaults to 10.
+            max_chars_per_result: Maximum characters per search result (100-30000).
+                Defaults to 6000.
+            source_policy: Optional source policy to include/exclude domains.
+
+        Returns:
+            List of cleaned search results.
+        """
+        raw_results = self.raw_results(
+            objective=objective,
+            search_queries=search_queries,
+            processor=processor,
+            max_results=max_results,
+            max_chars_per_result=max_chars_per_result,
+            source_policy=source_policy,
+        )
+        return self.clean_results(raw_results.get("results", []))
+
+    async def raw_results_async(
+        self,
+        objective: Optional[str] = None,
+        search_queries: Optional[List[str]] = None,
+        processor: str = "base",
+        max_results: int = 10,
+        max_chars_per_result: int = 6000,
+        source_policy: Optional[Dict[str, Any]] = None,
+    ) -> Dict[str, Any]:
+        """Get raw results from the Parallel Search API asynchronously.
+
+        Args:
+            objective: Natural-language description of the web research goal.
+            search_queries: Optional list of search queries to supplement the objective.
+            processor: The processor to use ("base" or "pro"). Defaults to "base".
+            max_results: Maximum number of search results to return (1-20).
+                Defaults to 10.
+            max_chars_per_result: Maximum characters per search result (100-30000).
+                Defaults to 6000.
+            source_policy: Optional source policy to include/exclude domains.
+
+        Returns:
+            Raw API response as a dictionary.
+        """
+        if not objective and not search_queries:
+            raise ValueError("Either 'objective' or 'search_queries' must be provided.")
+
+        payload: Dict[str, Any] = {
+            "processor": processor,
+            "max_results": max_results,
+            "excerpts": {
+                "max_chars_per_result": max_chars_per_result,
+            },
+        }
+
+        if objective:
+            payload["objective"] = objective
+        if search_queries:
+            payload["search_queries"] = search_queries
+        if source_policy:
+            payload["source_policy"] = source_policy
+
+        headers = {
+            "Content-Type": "application/json",
+            "x-api-key": self.parallel_api_key.get_secret_value(),
+            "parallel-beta": "search-extract-2025-10-10",
+        }
+
+        url = f"{self.base_url}/{self.api_version}/search"
+        timeout = aiohttp.ClientTimeout(total=60)
+        async with aiohttp.ClientSession(timeout=timeout) as session:
+            async with session.post(url, json=payload, headers=headers) as response:
+                if response.status != 200:
+                    error_text = await response.text()
+                    raise Exception(f"Error {response.status}: {error_text}")
+                return await response.json()
+
+    async def results_async(
+        self,
+        objective: Optional[str] = None,
+        search_queries: Optional[List[str]] = None,
+        processor: str = "base",
+        max_results: int = 10,
+        max_chars_per_result: int = 6000,
+        source_policy: Optional[Dict[str, Any]] = None,
+    ) -> List[Dict[str, Any]]:
+        """Run query through Parallel Search and return cleaned results asynchronously.
+
+        Args:
+            objective: Natural-language description of the web research goal.
+            search_queries: Optional list of search queries to supplement the objective.
+            processor: The processor to use ("base" or "pro"). Defaults to "base".
+            max_results: Maximum number of search results to return (1-20).
+                Defaults to 10.
+            max_chars_per_result: Maximum characters per search result (100-30000).
+                Defaults to 6000.
+            source_policy: Optional source policy to include/exclude domains.
+
+        Returns:
+            List of cleaned search results.
+        """
+        raw_results = await self.raw_results_async(
+            objective=objective,
+            search_queries=search_queries,
+            processor=processor,
+            max_results=max_results,
+            max_chars_per_result=max_chars_per_result,
+            source_policy=source_policy,
+        )
+        return self.clean_results(raw_results.get("results", []))
+
+    def clean_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Clean results from Parallel Search API.
+
+        Args:
+            results: Raw results from the API.
+
+        Returns:
+            List of cleaned result dictionaries.
+        """
+        cleaned = []
+        for result in results:
+            cleaned_result: Dict[str, Any] = {
+                "url": result.get("url"),
+                "title": result.get("title"),
+                "excerpts": result.get("excerpts", []),
+            }
+            if publish_date := result.get("publish_date"):
+                cleaned_result["publish_date"] = publish_date
+            cleaned.append(cleaned_result)
+        return cleaned
diff --git a/libs/community/tests/unit_tests/tools/parallel_search/test_parallel_search.py b/libs/community/tests/unit_tests/tools/parallel_search/test_parallel_search.py
new file mode 100644
index 000000000..1c1e00a7c
--- /dev/null
+++ b/libs/community/tests/unit_tests/tools/parallel_search/test_parallel_search.py
@@ -0,0 +1,188 @@
+"""Test Parallel Search tools."""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from langchain_community.tools.parallel_search.tool import (
+    ParallelSearchResults,
+    ParallelSearchRun,
+)
+
+
+def test_parallel_search_run_initialization() -> None:
+    """Test ParallelSearchRun initialization."""
+    tool = ParallelSearchRun(
+        parallel_api_key="test-key",
+        processor="base",
+        max_results=5,
+    )
+    assert tool.name == "parallel_search"
+    assert tool.processor == "base"
+    assert tool.max_results == 5
+
+
+def test_parallel_search_results_initialization() -> None:
+    """Test ParallelSearchResults initialization."""
+    tool = ParallelSearchResults(
+        parallel_api_key="test-key",
+        processor="pro",
+        max_results=10,
+    )
+    assert tool.name == "parallel_search_results_json"
+    assert tool.processor == "pro"
+    assert tool.max_results == 10
+
+
+@patch(
+    "langchain_community.tools.parallel_search.tool.ParallelSearchAPIWrapper.results"
+)
+def test_parallel_search_run_invoke(mock_results: MagicMock) -> None:
+    """Test ParallelSearchRun invoke method."""
+    mock_results.return_value = [
+        {
+            "url": "https://example.com",
+            "title": "Example",
+            "excerpts": ["This is an example excerpt"],
+        }
+    ]
+
+    tool = ParallelSearchRun(parallel_api_key="test-key")
+    result = tool.invoke(
+        {
+            "objective": "test objective",
+            "search_queries": ["test query"],
+        }
+    )
+
+    assert "Result 1" in result
+    assert "Example" in result
+    assert "https://example.com" in result
+    mock_results.assert_called_once()
+
+
+@patch(
+    "langchain_community.tools.parallel_search.tool.ParallelSearchAPIWrapper.results_async"
+)
+@pytest.mark.asyncio
+async def test_parallel_search_run_ainvoke(mock_results_async: AsyncMock) -> None:
+    """Test ParallelSearchRun async invoke method."""
+    mock_results_async.return_value = [
+        {
+            "url": "https://example.com",
+            "title": "Example",
+            "excerpts": ["This is an example excerpt"],
+        }
+    ]
+
+    tool = ParallelSearchRun(parallel_api_key="test-key")
+    result = await tool.ainvoke(
+        {
+            "objective": "test objective",
+            "search_queries": ["test query"],
+        }
+    )
+
+    assert "Result 1" in result
+    assert "Example" in result
+    mock_results_async.assert_called_once()
+
+
+@patch(
+    "langchain_community.tools.parallel_search.tool.ParallelSearchAPIWrapper.results"
+)
+def test_parallel_search_results_invoke(mock_results: MagicMock) -> None:
+    """Test ParallelSearchResults invoke method."""
+    mock_results.return_value = [
+        {
+            "url": "https://example.com",
+            "title": "Example",
+            "excerpts": ["This is an example excerpt"],
+        }
+    ]
+
+    tool = ParallelSearchResults(parallel_api_key="test-key")
+    # Call _run directly to test the tuple return
+    content, artifact = tool._run(
+        objective="test objective",
+        search_queries=["test query"],
+    )
+
+    assert isinstance(content, str)
+    assert isinstance(artifact, list)
+    assert len(artifact) == 1
+    assert artifact[0]["url"] == "https://example.com"
+    mock_results.assert_called_once()
+
+
+@patch(
+    "langchain_community.tools.parallel_search.tool.ParallelSearchAPIWrapper.results_async"
+)
+@pytest.mark.asyncio
+async def test_parallel_search_results_ainvoke(mock_results_async: AsyncMock) -> None:
+    """Test ParallelSearchResults async invoke method."""
+    mock_results_async.return_value = [
+        {
+            "url": "https://example.com",
+            "title": "Example",
+            "excerpts": ["This is an example excerpt"],
+        }
+    ]
+
+    tool = ParallelSearchResults(parallel_api_key="test-key")
+    # Call _arun directly to test the tuple return
+    content, artifact = await tool._arun(
+        objective="test objective",
+        search_queries=["test query"],
+    )
+
+    assert isinstance(content, str)
+    assert isinstance(artifact, list)
+    assert len(artifact) == 1
+    mock_results_async.assert_called_once()
+
+
+def test_parallel_search_run_format_results() -> None:
+    """Test ParallelSearchRun result formatting."""
+    tool = ParallelSearchRun(parallel_api_key="test-key")
+    results = [
+        {
+            "url": "https://example.com",
+            "title": "Example Title",
+            "excerpts": ["Excerpt 1", "Excerpt 2"],
+        },
+        {
+            "url": "https://example2.com",
+            "title": "Example 2",
+            "excerpts": [],
+        },
+    ]
+
+    formatted = tool._format_results(results)
+
+    assert "Result 1" in formatted
+    assert "Example Title" in formatted
+    assert "https://example.com" in formatted
+    assert "Excerpt 1" in formatted
+    assert "Result 2" in formatted
+    assert "Example 2" in formatted
+
+
+def test_parallel_search_run_no_results() -> None:
+    """Test ParallelSearchRun with no results."""
+    tool = ParallelSearchRun(parallel_api_key="test-key")
+    formatted = tool._format_results([])
+    assert formatted == "No results found."
+
+
+@patch(
+    "langchain_community.tools.parallel_search.tool.ParallelSearchAPIWrapper.results"
+)
+def test_parallel_search_run_error_handling(mock_results: MagicMock) -> None:
+    """Test ParallelSearchRun error handling."""
+    mock_results.side_effect = Exception("API Error")
+
+    tool = ParallelSearchRun(parallel_api_key="test-key")
+    result = tool.invoke({"objective": "test"})
+
+    assert "API Error" in result
diff --git a/libs/community/tests/unit_tests/utilities/test_parallel_search.py b/libs/community/tests/unit_tests/utilities/test_parallel_search.py
new file mode 100644
index 000000000..783e60803
--- /dev/null
+++ b/libs/community/tests/unit_tests/utilities/test_parallel_search.py
@@ -0,0 +1,122 @@
+"""Test Parallel Search API wrapper."""
+
+from typing import Any
+from unittest.mock import MagicMock, patch
+
+import pytest
+from pydantic import SecretStr
+
+from langchain_community.utilities.parallel_search import ParallelSearchAPIWrapper
+
+
+def test_api_key_explicit() -> None:
+    """Test that the API key is correctly set when provided explicitly."""
+    explicit_key = "explicit-api-key"
+    wrapper = ParallelSearchAPIWrapper(parallel_api_key=SecretStr(explicit_key))
+    assert wrapper.parallel_api_key.get_secret_value() == explicit_key
+
+
+def test_api_key_from_env(monkeypatch: Any) -> None:
+    """Test that the API key is correctly obtained from the environment variable."""
+    env_key = "env-api-key"
+    monkeypatch.setenv("PARALLEL_API_KEY", env_key)
+    # Do not pass the api_key explicitly
+    wrapper = ParallelSearchAPIWrapper()
+    assert wrapper.parallel_api_key.get_secret_value() == env_key
+
+
+def test_api_key_missing(monkeypatch: Any) -> None:
+    """Test that instantiation fails when no API key is provided."""
+    # Ensure that the environment variable is not set
+    monkeypatch.delenv("PARALLEL_API_KEY", raising=False)
+    with pytest.raises(ValueError):
+        # This should raise an error because no api_key is available.
+        ParallelSearchAPIWrapper()
+
+
+def test_validate_requires_objective_or_search_queries() -> None:
+    """Test that either objective or search_queries must be provided."""
+    wrapper = ParallelSearchAPIWrapper(parallel_api_key=SecretStr("test-key"))
+    with pytest.raises(ValueError, match="Either 'objective' or 'search_queries'"):
+        wrapper.raw_results()
+
+
+@patch("langchain_community.utilities.parallel_search.requests.post")
+def test_raw_results_success(mock_post: MagicMock) -> None:
+    """Test successful raw_results call."""
+    mock_response = MagicMock()
+    mock_response.json.return_value = {
+        "results": [
+            {
+                "url": "https://example.com",
+                "title": "Example",
+                "excerpts": ["This is an example"],
+            }
+        ]
+    }
+    mock_response.raise_for_status = MagicMock()
+    mock_post.return_value = mock_response
+
+    wrapper = ParallelSearchAPIWrapper(parallel_api_key=SecretStr("test-key"))
+    result = wrapper.raw_results(objective="test objective")
+
+    assert "results" in result
+    mock_post.assert_called_once()
+    call_kwargs = mock_post.call_args
+    assert call_kwargs[1]["headers"]["x-api-key"] == "test-key"
+    assert call_kwargs[1]["json"]["objective"] == "test objective"
+
+
+@patch("langchain_community.utilities.parallel_search.requests.post")
+def test_results_cleans_output(mock_post: MagicMock) -> None:
+    """Test that results method cleans the output correctly."""
+    mock_response = MagicMock()
+    mock_response.json.return_value = {
+        "results": [
+            {
+                "url": "https://example.com",
+                "title": "Example Title",
+                "excerpts": ["Excerpt 1", "Excerpt 2"],
+                "publish_date": "2025-01-01",
+            }
+        ]
+    }
+    mock_response.raise_for_status = MagicMock()
+    mock_post.return_value = mock_response
+
+    wrapper = ParallelSearchAPIWrapper(parallel_api_key=SecretStr("test-key"))
+    results = wrapper.results(search_queries=["test query"])
+
+    assert len(results) == 1
+    assert results[0]["url"] == "https://example.com"
+    assert results[0]["title"] == "Example Title"
+    assert results[0]["excerpts"] == ["Excerpt 1", "Excerpt 2"]
+    assert results[0]["publish_date"] == "2025-01-01"
+
+
+def test_clean_results() -> None:
+    """Test clean_results method."""
+    wrapper = ParallelSearchAPIWrapper(parallel_api_key=SecretStr("test-key"))
+    raw_results = [
+        {
+            "url": "https://example.com",
+            "title": "Example",
+            "excerpts": ["Excerpt 1"],
+            "publish_date": "2025-01-01",
+        },
+        {
+            "url": "https://example2.com",
+            "title": "Example 2",
+            "excerpts": [],
+        },
+    ]
+
+    cleaned = wrapper.clean_results(raw_results)
+
+    assert len(cleaned) == 2
+    assert cleaned[0]["url"] == "https://example.com"
+    assert cleaned[0]["title"] == "Example"
+    assert cleaned[0]["excerpts"] == ["Excerpt 1"]
+    assert cleaned[0]["publish_date"] == "2025-01-01"
+    assert cleaned[1]["url"] == "https://example2.com"
+    assert "publish_date" not in cleaned[1]