Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions stagehand/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,11 @@ async def connect_local_browser(
except Exception as e:
logger.error(f"Failed to create downloads_path {downloads_path}: {e}")

executable_path_option = local_browser_launch_options.get("executablePath")

# Prepare Launch Options (translate keys if needed)
launch_options = {
"executable_path": executable_path_option,
"headless": local_browser_launch_options.get(
"headless", stagehand_instance.config.headless
),
Expand All @@ -208,6 +211,8 @@ async def connect_local_browser(
"ignoreHTTPSErrors", True
),
}
if executable_path_option:
launch_options["executable_path"] = executable_path_option
launch_options = {k: v for k, v in launch_options.items() if v is not None}

# Launch Context
Expand Down
194 changes: 194 additions & 0 deletions stagehand/llm/qwenclient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
from typing import Any, Dict, Optional

import json
import aiohttp
from pydantic import BaseModel
from stagehand.llm.client import LLMClient # 继承项目现有的 LLMClient 基类
from stagehand.metrics import start_inference_timer, get_inference_time_ms


class HybridDict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    Nested dictionaries are converted to ``HybridDict`` recursively, including
    dictionaries found inside lists at any nesting depth, so a decoded JSON
    payload can be walked with dotted access (``resp.choices[0].message``).

    Note: attribute lookup only falls back to keys when normal attribute
    resolution fails, so keys that collide with ``dict`` method names
    (``items``, ``keys``, ...) must still be read with ``obj["items"]``.
    """

    def __init__(self, data: dict):
        """Copy *data* into the mapping, wrapping nested dicts and lists."""
        super().__init__(data)
        for key, value in data.items():
            self[key] = self._wrap(value)

    @staticmethod
    def _wrap(value):
        """Return *value* with every nested dict replaced by a HybridDict."""
        if isinstance(value, dict):
            return HybridDict(value)
        if isinstance(value, list):
            # Recurse so that lists of lists of dicts are converted as well.
            return [HybridDict._wrap(item) for item in value]
        return value

    def __getattr__(self, name):
        """Resolve unknown attributes as dictionary keys."""
        try:
            return self[name]
        except KeyError:
            # Suppress the KeyError context: callers expect a plain
            # AttributeError (e.g. hasattr / getattr with a default).
            raise AttributeError(
                f"'HybridDict' object has no attribute '{name}'"
            ) from None

    def __setattr__(self, name, value):
        """Store attribute assignments as dictionary items."""
        self[name] = value


class QwenClient(LLMClient):
    """LLM client that sends chat completions to Qwen over HTTP.

    Requests are posted with ``aiohttp`` to the DashScope compatible-mode
    chat-completions URL stored in ``self.api_base``. When a
    ``response_format`` is requested, a JSON-formatting instruction is added
    to the system prompt and the returned content is decoded from JSON.
    """

    # Instruction appended to the system message whenever the caller asks for
    # a structured response. This text is sent to the model verbatim.
    _JSON_FORMAT_PROMPT = """
你是浏览器自动化的元素识别助手,需基于 Accessibility Tree(语义化节点树)返回符合要求的可操作元素,规则如下:
### 输入说明
1. Accessibility Tree 包含节点格式:[节点ID] 角色: 标签(如 [4] textbox: 请输入账号);
2. 目标操作是用户的自然语言指令(如「在请输入账号输入框中输入内容」)。
### 返回规则
1. 仅返回合法 JSON 字符串,无任何多余文字/解释/代码块;
2. 顶层为字典,仅包含 "elements" 键(值为数组);
3. 数组内每个元素必须包含以下字段:
- element_id:Accessibility Tree 中的原始数字 ID(如 [4] 中的 4,整数类型,用于定位元素);
- description:元素的描述(结合角色和标签,如「textbox: 请输入账号」);
- method:Playwright 支持的操作方法(如 textbox/textarea 用 "fill",button/link 用 "click");
- arguments:必须是list数组结构,操作的参数列表list,即使只有一个元素也应该用数组嵌套,没有元素时填入一个空字符串''。
### 返回格式示例
{
"elements": [
{
"element_id": 4,
"description": "textbox: 请输入账号",
"method": "fill",
"arguments": ["15211228071"]
}
]
}
### 强制要求
1. 仅返回上述格式的 JSON,无其他内容;
2. element_id 必须是 Accessibility Tree 中的原始数字,不允许自定义;
3. 若未找到匹配元素,返回 {"elements": []};
4. 若用户指令是动作(如输入/点击),优先返回最匹配的单个元素;若为观察(如"找到所有按钮"),返回所有符合条件的元素。
"""

    def __init__(self, stagehand_logger, api_key: str, model_name: str = "qwen-turbo", **kwargs):
        """Initialise the base client and remember the key and endpoint.

        Args:
            stagehand_logger: Logger forwarded to the ``LLMClient`` base class.
            api_key: Key sent as the Bearer token on every request.
            model_name: Default model used when ``create_response`` gets none.
            **kwargs: Forwarded unchanged to ``LLMClient.__init__``.
        """
        super().__init__(
            stagehand_logger=stagehand_logger,
            api_key=api_key,
            default_model=model_name,
            **kwargs,
        )
        self.api_key = api_key
        self.api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"

    @classmethod
    def _with_json_prompt(cls, messages: list) -> list:
        """Return copies of *messages* with the JSON instruction added to the system prompt."""
        # Shallow per-message copies: the caller's dicts are never rebound.
        prepared = [msg.copy() for msg in messages]
        for msg in prepared:
            if msg.get("role") == "system":
                msg["content"] += cls._JSON_FORMAT_PROMPT
                return prepared
        # No system message present: add one so the instruction is always sent.
        prepared.insert(0, {
            "role": "system",
            "content": f"你是一个专业的助手{cls._JSON_FORMAT_PROMPT}"
        })
        return prepared

    @staticmethod
    def _build_response_format(response_format: Any) -> dict:
        """Translate the caller's ``response_format`` into the request payload form."""
        # Case 1: a Pydantic model class -> strict JSON schema.
        if isinstance(response_format, type) and issubclass(response_format, BaseModel):
            return {
                "type": "json_schema",
                "json_schema": {
                    "name": response_format.__name__,
                    "strict": True,
                    "schema": response_format.model_json_schema(),
                }
            }
        if isinstance(response_format, dict):
            format_type = response_format.get("type")
            # Case 2: an explicit json_schema dict -> force strict mode on a
            # copy so the caller's dictionary is left untouched.
            if format_type == "json_schema":
                return {
                    **response_format,
                    "json_schema": {**response_format["json_schema"], "strict": True},
                }
            # Case 3: json_object is passed through as-is.
            if format_type == "json_object":
                return response_format
        # Anything else falls back to a generic JSON object.
        return {"type": "json_object"}

    @staticmethod
    def _normalize_element_arguments(parsed_content: Any) -> None:
        """Force every ``elements[*].arguments`` value to be a list, in place."""
        if not isinstance(parsed_content, dict):
            return
        elements = parsed_content.get("elements")
        if not isinstance(elements, list):
            return
        for elem in elements:
            if not isinstance(elem, dict):
                # Malformed entry from the model; leave it for the caller's
                # schema validation instead of crashing here.
                continue
            arguments = elem.get("arguments")
            if arguments is None:
                # Missing or null -> empty list.
                elem["arguments"] = []
            elif not isinstance(arguments, list):
                # Scalar (string / number / bool / dict) -> single-item list.
                elem["arguments"] = [arguments]

    async def create_response(
        self,
        *,
        messages: list[dict[str, str]],
        model: Optional[str] = None,
        function_name: Optional[str] = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Send a chat-completion request and return the decoded response.

        Args:
            messages: Chat messages; they are copied before being modified.
            model: Model name; falls back to ``self.default_model``.
            function_name: When set (and a metrics callback exists), the
                response and elapsed time are reported to the callback.
            **kwargs: Optional ``response_format``, ``temperature``
                (default 0.1), ``max_tokens`` (default 1024) and ``top_p``
                (default 0.9).

        Returns:
            The response body wrapped in ``HybridDict``. When
            ``response_format`` was given and the first choice has content,
            that content is replaced by the parsed JSON value.

        Raises:
            ValueError: If no model is available, the HTTP status is not 200,
                or the model returned content that is not valid JSON.
        """
        # 1. Basic argument validation.
        model = model or self.default_model
        if not model:
            raise ValueError("未指定模型名称")

        # 2. Request headers. Use the key stored by __init__; this class never
        #    sets a ``model_api_key`` attribute.
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

        # 3. Copy the messages and, for structured output, add the JSON prompt.
        response_format = kwargs.get("response_format")
        if response_format:
            processed_messages = self._with_json_prompt(messages)
        else:
            processed_messages = [msg.copy() for msg in messages]

        # 4. Request body.
        payload = {
            "model": model,
            "messages": processed_messages,
            "temperature": kwargs.get("temperature", 0.1),
            "max_tokens": kwargs.get("max_tokens", 1024),
            "top_p": kwargs.get("top_p", 0.9),
        }

        # 5. Convert response_format into the payload representation.
        if response_format:
            payload["response_format"] = self._build_response_format(response_format)

        # 6. Send the request.
        start_time = start_inference_timer()
        async with aiohttp.ClientSession() as session:
            async with session.post(self.api_base, json=payload, headers=headers) as response:
                if response.status != 200:
                    raise ValueError(f"通义千问 API 错误: {await response.text()}")
                response_data = await response.json()

        # 7. Decode the JSON content of the first choice, if any.
        if response_format:
            # ``or`` guards cover both a missing key and an empty/null value.
            choices = response_data.get("choices") or [{}]
            content = (choices[0].get("message") or {}).get("content") or ""
            if content:
                try:
                    parsed_content = json.loads(content)
                except json.JSONDecodeError as e:
                    raise ValueError(
                        f"千问模型返回非 JSON 内容:{content} | 错误:{str(e)}"
                    ) from e
                self._normalize_element_arguments(parsed_content)
                response_data["choices"][0]["message"]["content"] = parsed_content

        # 8. Report metrics; the response is wrapped so it supports attribute access.
        inference_time_ms = get_inference_time_ms(start_time)
        if self.metrics_callback and function_name:
            response_obj = HybridDict(response_data)
            self.metrics_callback(response_obj, inference_time_ms, function_name)

        return HybridDict(response_data)
26 changes: 19 additions & 7 deletions stagehand/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .metrics import StagehandFunctionName, StagehandMetrics
from .page import StagehandPage
from .utils import get_download_path, make_serializable
from .llm.qwenclient import QwenClient # 导入千问客户端

load_dotenv()

Expand Down Expand Up @@ -284,13 +285,24 @@ def __init__(
# Setup LLM client if LOCAL mode
self.llm = None
if not self.use_api:
self.llm = LLMClient(
stagehand_logger=self.logger,
api_key=self.model_api_key,
default_model=self.model_name,
metrics_callback=self._handle_llm_metrics,
**self.model_client_options,
)
# 检查是否为千问模型,使用自定义 QwenClient
if self.model_name in ["qwen-turbo", "qwen-plus", "qwen-max"]:
self.llm = QwenClient(
stagehand_logger=self.logger,
api_key=self.api_key,
model_name=self.model_name,
metrics_callback=self._handle_llm_metrics,
**self.model_client_options,
)
else:
# 其他模型使用默认 LLMClient
self.llm = LLMClient(
stagehand_logger=self.logger,
api_key=self.model_api_key,
default_model=self.model_name,
metrics_callback=self._handle_llm_metrics,
**self.model_client_options,
)

def _register_signal_handlers(self):
"""Register signal handlers for SIGINT and SIGTERM to ensure proper cleanup."""
Expand Down
4 changes: 4 additions & 0 deletions stagehand/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ class AvailableModel(str, Enum):
CLAUDE_3_7_SONNET_LATEST = "claude-3-7-sonnet-latest"
COMPUTER_USE_PREVIEW = "computer-use-preview"
GEMINI_2_0_FLASH = "gemini-2.0-flash"
# 添加通义千问模型
QWEN_TURBO = "qwen-turbo"
QWEN_PLUS = "qwen-plus"
QWEN_MAX = "qwen-max"


class StagehandBaseModel(BaseModel):
Expand Down