diff --git a/astrbot/core/backup/__init__.py b/astrbot/core/backup/__init__.py new file mode 100644 index 000000000..8e33ef970 --- /dev/null +++ b/astrbot/core/backup/__init__.py @@ -0,0 +1,26 @@ +"""AstrBot 备份与恢复模块 + +提供数据导出和导入功能,支持用户在服务器迁移时一键备份和恢复所有数据。 +""" + +# 从 constants 模块导入共享常量 +from .constants import ( + BACKUP_MANIFEST_VERSION, + KB_METADATA_MODELS, + MAIN_DB_MODELS, + get_backup_directories, +) + +# 导入导出器和导入器 +from .exporter import AstrBotExporter +from .importer import AstrBotImporter, ImportPreCheckResult + +__all__ = [ + "AstrBotExporter", + "AstrBotImporter", + "ImportPreCheckResult", + "MAIN_DB_MODELS", + "KB_METADATA_MODELS", + "get_backup_directories", + "BACKUP_MANIFEST_VERSION", +] diff --git a/astrbot/core/backup/constants.py b/astrbot/core/backup/constants.py new file mode 100644 index 000000000..b45b702e7 --- /dev/null +++ b/astrbot/core/backup/constants.py @@ -0,0 +1,77 @@ +"""AstrBot 备份模块共享常量 + +此文件定义了导出器和导入器共享的常量,确保两端配置一致。 +""" + +from sqlmodel import SQLModel + +from astrbot.core.db.po import ( + Attachment, + CommandConfig, + CommandConflict, + ConversationV2, + Persona, + PlatformMessageHistory, + PlatformSession, + PlatformStat, + Preference, +) +from astrbot.core.knowledge_base.models import ( + KBDocument, + KBMedia, + KnowledgeBase, +) +from astrbot.core.utils.astrbot_path import ( + get_astrbot_config_path, + get_astrbot_plugin_data_path, + get_astrbot_plugin_path, + get_astrbot_t2i_templates_path, + get_astrbot_temp_path, + get_astrbot_webchat_path, +) + +# ============================================================ +# 共享常量 - 确保导出和导入端配置一致 +# ============================================================ + +# 主数据库模型类映射 +MAIN_DB_MODELS: dict[str, type[SQLModel]] = { + "platform_stats": PlatformStat, + "conversations": ConversationV2, + "personas": Persona, + "preferences": Preference, + "platform_message_history": PlatformMessageHistory, + "platform_sessions": PlatformSession, + "attachments": Attachment, + "command_configs": CommandConfig, + "command_conflicts": CommandConflict, +} + +# 知识库元数据模型类映射 +KB_METADATA_MODELS: dict[str, type[SQLModel]] = { + "knowledge_bases": KnowledgeBase, + "kb_documents": KBDocument, + "kb_media": KBMedia, +} + + +def get_backup_directories() -> dict[str, str]: + """获取需要备份的目录列表 + + 使用 astrbot_path 模块动态获取路径,支持通过环境变量 ASTRBOT_ROOT 自定义根目录。 + + Returns: + dict: 键为备份文件中的目录名称,值为目录的绝对路径 + """ + return { + "plugins": get_astrbot_plugin_path(), # 插件本体 + "plugin_data": get_astrbot_plugin_data_path(), # 插件数据 + "config": get_astrbot_config_path(), # 配置目录 + "t2i_templates": get_astrbot_t2i_templates_path(), # T2I 模板 + "webchat": get_astrbot_webchat_path(), # WebChat 数据 + "temp": get_astrbot_temp_path(), # 临时文件 + } + + +# 备份清单版本号 +BACKUP_MANIFEST_VERSION = "1.1" diff --git a/astrbot/core/backup/exporter.py b/astrbot/core/backup/exporter.py new file mode 100644 index 000000000..bd98124ce --- /dev/null +++ b/astrbot/core/backup/exporter.py @@ -0,0 +1,471 @@ +"""AstrBot 数据导出器 + +负责将所有数据导出为 ZIP 备份文件。 +导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移。 +""" + +import hashlib +import json +import os +import zipfile +from datetime import datetime, timezone +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from sqlalchemy import select + +from astrbot.core import logger +from astrbot.core.config.default import VERSION +from astrbot.core.db import BaseDatabase + +# 从共享常量模块导入 +from .constants import ( + BACKUP_MANIFEST_VERSION, + KB_METADATA_MODELS, + MAIN_DB_MODELS, + get_backup_directories, +) + +if TYPE_CHECKING: + from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager + + +class AstrBotExporter: + """AstrBot 数据导出器 + + 导出内容: + - 主数据库所有表(data/data_v4.db) + - 知识库元数据(data/knowledge_base/kb.db) + - 每个知识库的向量文档数据 + - 配置文件(data/cmd_config.json) + - 附件文件 + - 知识库多媒体文件 + - 插件目录(data/plugins) + - 插件数据目录(data/plugin_data) + - 配置目录(data/config) + - T2I 模板目录(data/t2i_templates) + - WebChat 数据目录(data/webchat) + - 临时文件目录(data/temp) + """ + + def __init__( + self, + main_db: BaseDatabase, + kb_manager: "KnowledgeBaseManager | None" = None, + config_path: str = "data/cmd_config.json", + attachments_dir: str = "data/attachments", + data_root: str = "data", + ): + self.main_db = main_db + self.kb_manager = kb_manager + self.config_path = config_path + self.attachments_dir = attachments_dir + self.data_root = data_root + self._checksums: dict[str, str] = {} + + async def export_all( + self, + output_dir: str = "data/backups", + progress_callback: Any | None = None, + ) -> str: + """导出所有数据到 ZIP 文件 + + Args: + output_dir: 输出目录 + progress_callback: 进度回调函数,接收参数 (stage, current, total, message) + + Returns: + str: 生成的 ZIP 文件路径 + """ + # 确保输出目录存在 + Path(output_dir).mkdir(parents=True, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + zip_filename = f"astrbot_backup_{timestamp}.zip" + zip_path = os.path.join(output_dir, zip_filename) + + logger.info(f"开始导出备份到 {zip_path}") + + try: + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: + # 1. 导出主数据库 + if progress_callback: + await progress_callback("main_db", 0, 100, "正在导出主数据库...") + main_data = await self._export_main_database() + main_db_json = json.dumps( + main_data, ensure_ascii=False, indent=2, default=str + ) + zf.writestr("databases/main_db.json", main_db_json) + self._add_checksum("databases/main_db.json", main_db_json) + if progress_callback: + await progress_callback("main_db", 100, 100, "主数据库导出完成") + + # 2. 导出知识库数据 + kb_meta_data: dict[str, Any] = { + "knowledge_bases": [], + "kb_documents": [], + "kb_media": [], + } + if self.kb_manager: + if progress_callback: + await progress_callback( + "kb_metadata", 0, 100, "正在导出知识库元数据..." + ) + kb_meta_data = await self._export_kb_metadata() + kb_meta_json = json.dumps( + kb_meta_data, ensure_ascii=False, indent=2, default=str + ) + zf.writestr("databases/kb_metadata.json", kb_meta_json) + self._add_checksum("databases/kb_metadata.json", kb_meta_json) + if progress_callback: + await progress_callback( + "kb_metadata", 100, 100, "知识库元数据导出完成" + ) + + # 导出每个知识库的文档数据 + kb_insts = self.kb_manager.kb_insts + total_kbs = len(kb_insts) + for idx, (kb_id, kb_helper) in enumerate(kb_insts.items()): + if progress_callback: + await progress_callback( + "kb_documents", + idx, + total_kbs, + f"正在导出知识库 {kb_helper.kb.kb_name} 的文档数据...", + ) + doc_data = await self._export_kb_documents(kb_helper) + doc_json = json.dumps( + doc_data, ensure_ascii=False, indent=2, default=str + ) + doc_path = f"databases/kb_{kb_id}/documents.json" + zf.writestr(doc_path, doc_json) + self._add_checksum(doc_path, doc_json) + + # 导出 FAISS 索引文件 + await self._export_faiss_index(zf, kb_helper, kb_id) + + # 导出知识库多媒体文件 + await self._export_kb_media_files(zf, kb_helper, kb_id) + + if progress_callback: + await progress_callback( + "kb_documents", total_kbs, total_kbs, "知识库文档导出完成" + ) + + # 3. 导出配置文件 + if progress_callback: + await progress_callback("config", 0, 100, "正在导出配置文件...") + if os.path.exists(self.config_path): + with open(self.config_path, encoding="utf-8") as f: + config_content = f.read() + zf.writestr("config/cmd_config.json", config_content) + self._add_checksum("config/cmd_config.json", config_content) + if progress_callback: + await progress_callback("config", 100, 100, "配置文件导出完成") + + # 4. 导出附件文件 + if progress_callback: + await progress_callback("attachments", 0, 100, "正在导出附件...") + await self._export_attachments(zf, main_data.get("attachments", [])) + if progress_callback: + await progress_callback("attachments", 100, 100, "附件导出完成") + + # 5. 导出插件和其他目录 + if progress_callback: + await progress_callback( + "directories", 0, 100, "正在导出插件和数据目录..." + ) + dir_stats = await self._export_directories(zf) + if progress_callback: + await progress_callback("directories", 100, 100, "目录导出完成") + + # 6. 生成 manifest + if progress_callback: + await progress_callback("manifest", 0, 100, "正在生成清单...") + manifest = self._generate_manifest(main_data, kb_meta_data, dir_stats) + manifest_json = json.dumps(manifest, ensure_ascii=False, indent=2) + zf.writestr("manifest.json", manifest_json) + if progress_callback: + await progress_callback("manifest", 100, 100, "清单生成完成") + + logger.info(f"备份导出完成: {zip_path}") + return zip_path + + except Exception as e: + logger.error(f"备份导出失败: {e}") + # 清理失败的文件 + if os.path.exists(zip_path): + os.remove(zip_path) + raise + + async def _export_main_database(self) -> dict[str, list[dict]]: + """导出主数据库所有表""" + export_data: dict[str, list[dict]] = {} + + async with self.main_db.get_db() as session: + for table_name, model_class in MAIN_DB_MODELS.items(): + try: + result = await session.execute(select(model_class)) + records = result.scalars().all() + export_data[table_name] = [ + self._model_to_dict(record) for record in records + ] + logger.debug( + f"导出表 {table_name}: {len(export_data[table_name])} 条记录" + ) + except Exception as e: + logger.warning(f"导出表 {table_name} 失败: {e}") + export_data[table_name] = [] + + return export_data + + async def _export_kb_metadata(self) -> dict[str, list[dict]]: + """导出知识库元数据库""" + if not self.kb_manager: + return {"knowledge_bases": [], "kb_documents": [], "kb_media": []} + + export_data: dict[str, list[dict]] = {} + + async with self.kb_manager.kb_db.get_db() as session: + for table_name, model_class in KB_METADATA_MODELS.items(): + try: + result = await session.execute(select(model_class)) + records = result.scalars().all() + export_data[table_name] = [ + self._model_to_dict(record) for record in records + ] + logger.debug( + f"导出知识库表 {table_name}: {len(export_data[table_name])} 条记录" + ) + except Exception as e: + logger.warning(f"导出知识库表 {table_name} 失败: {e}") + export_data[table_name] = [] + + return export_data + + async def _export_kb_documents(self, kb_helper: Any) -> dict[str, Any]: + """导出知识库的文档块数据""" + try: + from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB + + vec_db: FaissVecDB = kb_helper.vec_db + if not vec_db or not vec_db.document_storage: + return {"documents": []} + + # 获取所有文档 + docs = await vec_db.document_storage.get_documents( + metadata_filters={}, + offset=0, + limit=None, # 获取全部 + ) + + return {"documents": docs} + except Exception as e: + logger.warning(f"导出知识库文档失败: {e}") + return {"documents": []} + + async def _export_faiss_index( + self, + zf: zipfile.ZipFile, + kb_helper: Any, + kb_id: str, + ) -> None: + """导出 FAISS 索引文件""" + try: + index_path = kb_helper.kb_dir / "index.faiss" + if index_path.exists(): + archive_path = f"databases/kb_{kb_id}/index.faiss" + zf.write(str(index_path), archive_path) + logger.debug(f"导出 FAISS 索引: {archive_path}") + except Exception as e: + logger.warning(f"导出 FAISS 索引失败: {e}") + + async def _export_kb_media_files( + self, zf: zipfile.ZipFile, kb_helper: Any, kb_id: str + ) -> None: + """导出知识库的多媒体文件""" + try: + media_dir = kb_helper.kb_medias_dir + if not media_dir.exists(): + return + + for root, _, files in os.walk(media_dir): + for file in files: + file_path = Path(root) / file + # 计算相对路径 + rel_path = file_path.relative_to(kb_helper.kb_dir) + archive_path = f"files/kb_media/{kb_id}/{rel_path}" + zf.write(str(file_path), archive_path) + except Exception as e: + logger.warning(f"导出知识库媒体文件失败: {e}") + + async def _export_directories( + self, zf: zipfile.ZipFile + ) -> dict[str, dict[str, int]]: + """导出插件和其他数据目录 + + Returns: + dict: 每个目录的统计信息 {dir_name: {"files": count, "size": bytes}} + """ + stats: dict[str, dict[str, int]] = {} + backup_directories = get_backup_directories() + + for dir_name, dir_path in backup_directories.items(): + full_path = Path(dir_path) + if not full_path.exists(): + logger.debug(f"目录不存在,跳过: {full_path}") + continue + + file_count = 0 + total_size = 0 + + try: + for root, dirs, files in os.walk(full_path): + # 跳过 __pycache__ 目录 + dirs[:] = [d for d in dirs if d != "__pycache__"] + + for file in files: + # 跳过 .pyc 文件 + if file.endswith(".pyc"): + continue + + file_path = Path(root) / file + try: + # 计算相对路径 + rel_path = file_path.relative_to(full_path) + archive_path = f"directories/{dir_name}/{rel_path}" + zf.write(str(file_path), archive_path) + file_count += 1 + total_size += file_path.stat().st_size + except Exception as e: + logger.warning(f"导出文件 {file_path} 失败: {e}") + + stats[dir_name] = {"files": file_count, "size": total_size} + logger.debug( + f"导出目录 {dir_name}: {file_count} 个文件, {total_size} 字节" + ) + except Exception as e: + logger.warning(f"导出目录 {dir_path} 失败: {e}") + stats[dir_name] = {"files": 0, "size": 0} + + return stats + + async def _export_attachments( + self, zf: zipfile.ZipFile, attachments: list[dict] + ) -> None: + """导出附件文件""" + for attachment in attachments: + try: + file_path = attachment.get("path", "") + if file_path and os.path.exists(file_path): + # 使用 attachment_id 作为文件名 + attachment_id = attachment.get("attachment_id", "") + ext = os.path.splitext(file_path)[1] + archive_path = f"files/attachments/{attachment_id}{ext}" + zf.write(file_path, archive_path) + except Exception as e: + logger.warning(f"导出附件失败: {e}") + + def _model_to_dict(self, record: Any) -> dict: + """将 SQLModel 实例转换为字典 + + 这是数据库无关的序列化方式,支持未来迁移到其他数据库。 + """ + # 使用 SQLModel 内置的 model_dump 方法(如果可用) + if hasattr(record, "model_dump"): + data = record.model_dump(mode="python") + # 处理 datetime 类型 + for key, value in data.items(): + if isinstance(value, datetime): + data[key] = value.isoformat() + return data + + # 回退到手动提取 + data = {} + # 使用 inspect 获取表信息 + from sqlalchemy import inspect as sa_inspect + + mapper = sa_inspect(record.__class__) + for column in mapper.columns: + value = getattr(record, column.name) + # 处理 datetime 类型 - 统一转为 ISO 格式字符串 + if isinstance(value, datetime): + value = value.isoformat() + data[column.name] = value + return data + + def _add_checksum(self, path: str, content: str | bytes) -> None: + """计算并添加文件校验和""" + if isinstance(content, str): + content = content.encode("utf-8") + checksum = hashlib.sha256(content).hexdigest() + self._checksums[path] = f"sha256:{checksum}" + + def _generate_manifest( + self, + main_data: dict[str, list[dict]], + kb_meta_data: dict[str, list[dict]], + dir_stats: dict[str, dict[str, int]] | None = None, + ) -> dict: + """生成备份清单""" + if dir_stats is None: + dir_stats = {} + # 收集知识库 ID + kb_document_tables = {} + if self.kb_manager: + for kb_id in self.kb_manager.kb_insts.keys(): + kb_document_tables[kb_id] = "documents" + + # 收集附件文件列表 + attachment_files = [] + for attachment in main_data.get("attachments", []): + attachment_id = attachment.get("attachment_id", "") + path = attachment.get("path", "") + if attachment_id and path: + ext = os.path.splitext(path)[1] + attachment_files.append(f"{attachment_id}{ext}") + + # 收集知识库媒体文件 + kb_media_files: dict[str, list[str]] = {} + if self.kb_manager: + for kb_id, kb_helper in self.kb_manager.kb_insts.items(): + media_files: list[str] = [] + media_dir = kb_helper.kb_medias_dir + if media_dir.exists(): + for root, _, files in os.walk(media_dir): + for file in files: + media_files.append(file) + if media_files: + kb_media_files[kb_id] = media_files + + manifest = { + "version": BACKUP_MANIFEST_VERSION, + "astrbot_version": VERSION, + "exported_at": datetime.now(timezone.utc).isoformat(), + "schema_version": { + "main_db": "v4", + "kb_db": "v1", + }, + "tables": { + "main_db": list(main_data.keys()), + "kb_metadata": list(kb_meta_data.keys()), + "kb_documents": kb_document_tables, + }, + "files": { + "attachments": attachment_files, + "kb_media": kb_media_files, + }, + "directories": list(dir_stats.keys()), + "checksums": self._checksums, + "statistics": { + "main_db": { + table: len(records) for table, records in main_data.items() + }, + "kb_metadata": { + table: len(records) for table, records in kb_meta_data.items() + }, + "directories": dir_stats, + }, + } + + return manifest diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py new file mode 100644 index 000000000..d129e8a33 --- /dev/null +++ b/astrbot/core/backup/importer.py @@ -0,0 +1,794 @@ +"""AstrBot 数据导入器 + +负责从 ZIP 备份文件恢复所有数据。 +导入时进行版本校验: +- 主版本(前两位)不同时直接拒绝导入 +- 小版本(第三位)不同时提示警告,用户可选择强制导入 +- 版本匹配时也需要用户确认 +""" + +import json +import os +import shutil +import zipfile +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from sqlalchemy import delete + +from astrbot.core import logger +from astrbot.core.config.default import VERSION +from astrbot.core.db import BaseDatabase + +# 从共享常量模块导入 +from .constants import ( + KB_METADATA_MODELS, + MAIN_DB_MODELS, + get_backup_directories, +) + +if TYPE_CHECKING: + from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager + + +def parse_version(version_str: str) -> tuple[int, ...]: + """将版本字符串解析为数值元组用于比较 + + Args: + version_str: 版本字符串,如 "1.0", "1.10", "2.0.1" + + Returns: + 数值元组,如 (1, 0), (1, 10), (2, 0, 1) + """ + try: + parts = version_str.split(".") + return tuple(int(p) for p in parts) + except (ValueError, AttributeError): + # 解析失败时返回 (0,),确保能够比较 + return (0,) + + +def compare_versions(v1: str, v2: str) -> int: + """比较两个版本号 + + Args: + v1: 第一个版本字符串 + v2: 第二个版本字符串 + + Returns: + -1 如果 v1 < v2 + 0 如果 v1 == v2 + 1 如果 v1 > v2 + """ + t1 = parse_version(v1) + t2 = parse_version(v2) + + # 补齐长度以便比较 + max_len = max(len(t1), len(t2)) + t1 = t1 + (0,) * (max_len - len(t1)) + t2 = t2 + (0,) * (max_len - len(t2)) + + if t1 < t2: + return -1 + elif t1 > t2: + return 1 + else: + return 0 + + +@dataclass +class ImportPreCheckResult: + """导入预检查结果 + + 用于在实际导入前检查备份文件的版本兼容性, + 并返回确认信息让用户决定是否继续导入。 + """ + + # 检查是否通过(文件有效且版本可导入) + valid: bool = False + # 是否可以导入(版本兼容) + can_import: bool = False + # 版本状态: match(完全匹配), minor_diff(小版本差异), major_diff(主版本不同,拒绝) + version_status: str = "" + # 备份文件中的 AstrBot 版本 + backup_version: str = "" + # 当前运行的 AstrBot 版本 + current_version: str = VERSION + # 备份创建时间 + backup_time: str = "" + # 确认消息(显示给用户) + confirm_message: str = "" + # 警告消息列表 + warnings: list[str] = field(default_factory=list) + # 错误消息(如果检查失败) + error: str = "" + # 备份包含的内容摘要 + backup_summary: dict = field(default_factory=dict) + + def to_dict(self) -> dict: + return { + "valid": self.valid, + "can_import": self.can_import, + "version_status": self.version_status, + "backup_version": self.backup_version, + "current_version": self.current_version, + "backup_time": self.backup_time, + "confirm_message": self.confirm_message, + "warnings": self.warnings, + "error": self.error, + "backup_summary": self.backup_summary, + } + + +class ImportResult: + """导入结果""" + + def __init__(self): + self.success = True + self.imported_tables: dict[str, int] = {} + self.imported_files: dict[str, int] = {} + self.imported_directories: dict[str, int] = {} + self.warnings: list[str] = [] + self.errors: list[str] = [] + + def add_warning(self, msg: str) -> None: + self.warnings.append(msg) + logger.warning(msg) + + def add_error(self, msg: str) -> None: + self.errors.append(msg) + self.success = False + logger.error(msg) + + def to_dict(self) -> dict: + return { + "success": self.success, + "imported_tables": self.imported_tables, + "imported_files": self.imported_files, + "imported_directories": self.imported_directories, + "warnings": self.warnings, + "errors": self.errors, + } + + +class AstrBotImporter: + """AstrBot 数据导入器 + + 导入备份文件中的所有数据,包括: + - 主数据库所有表 + - 知识库元数据和文档 + - 配置文件 + - 附件文件 + - 知识库多媒体文件 + - 插件目录(data/plugins) + - 插件数据目录(data/plugin_data) + - 配置目录(data/config) + - T2I 模板目录(data/t2i_templates) + - WebChat 数据目录(data/webchat) + - 临时文件目录(data/temp) + """ + + def __init__( + self, + main_db: BaseDatabase, + kb_manager: "KnowledgeBaseManager | None" = None, + config_path: str = "data/cmd_config.json", + attachments_dir: str = "data/attachments", + kb_root_dir: str = "data/knowledge_base", + data_root: str = "data", + ): + self.main_db = main_db + self.kb_manager = kb_manager + self.config_path = config_path + self.attachments_dir = attachments_dir + self.kb_root_dir = kb_root_dir + self.data_root = data_root + + def pre_check(self, zip_path: str) -> ImportPreCheckResult: + """预检查备份文件 + + 在实际导入前检查备份文件的有效性和版本兼容性。 + 返回检查结果供前端显示确认对话框。 + + Args: + zip_path: ZIP 备份文件路径 + + Returns: + ImportPreCheckResult: 预检查结果 + """ + result = ImportPreCheckResult() + result.current_version = VERSION + + if not os.path.exists(zip_path): + result.error = f"备份文件不存在: {zip_path}" + return result + + try: + with zipfile.ZipFile(zip_path, "r") as zf: + # 读取 manifest + try: + manifest_data = zf.read("manifest.json") + manifest = json.loads(manifest_data) + except KeyError: + result.error = "备份文件缺少 manifest.json,不是有效的 AstrBot 备份" + return result + except json.JSONDecodeError as e: + result.error = f"manifest.json 格式错误: {e}" + return result + + # 提取基本信息 + result.backup_version = manifest.get("astrbot_version", "未知") + result.backup_time = manifest.get("exported_at", "未知") + result.valid = True + + # 构建备份摘要 + result.backup_summary = { + "tables": list(manifest.get("tables", {}).keys()), + "has_knowledge_bases": manifest.get("has_knowledge_bases", False), + "has_config": manifest.get("has_config", False), + "directories": manifest.get("directories", []), + } + + # 检查版本兼容性 + version_check = self._check_version_compatibility(result.backup_version) + result.version_status = version_check["status"] + result.can_import = version_check["can_import"] + + # 版本信息由前端根据 version_status 和 i18n 生成显示 + # 不再将版本消息添加到 warnings 列表中,避免中文硬编码 + # warnings 列表保留用于其他非版本相关的警告 + + return result + + except zipfile.BadZipFile: + result.error = "无效的 ZIP 文件" + return result + except Exception as e: + result.error = f"检查备份文件失败: {e}" + return result + + def _check_version_compatibility(self, backup_version: str) -> dict: + """检查版本兼容性 + + 规则: + - 主版本(前两位,如 4.9)必须一致,否则拒绝 + - 小版本(第三位,如 4.9.1 vs 4.9.2)不同时,警告但允许导入 + + Returns: + dict: {status, can_import, message} + """ + if not backup_version: + return { + "status": "major_diff", + "can_import": False, + "message": "备份文件缺少版本信息", + } + + backup_parts = parse_version(backup_version) + current_parts = parse_version(VERSION) + + # 补齐到至少 2 位用于主版本比较 + backup_major = ( + backup_parts[:2] + if len(backup_parts) >= 2 + else backup_parts + (0,) * (2 - len(backup_parts)) + ) + current_major = ( + current_parts[:2] + if len(current_parts) >= 2 + else current_parts + (0,) * (2 - len(current_parts)) + ) + + if backup_major != current_major: + return { + "status": "major_diff", + "can_import": False, + "message": ( + f"主版本不兼容: 备份版本 {backup_version}, 当前版本 {VERSION}。" + f"跨主版本导入可能导致数据损坏,请使用相同主版本的 AstrBot。" + ), + } + + # 比较完整版本 + backup_full = backup_parts + (0,) * (3 - len(backup_parts)) + current_full = current_parts + (0,) * (3 - len(current_parts)) + + if backup_full != current_full: + return { + "status": "minor_diff", + "can_import": True, + "message": ( + f"小版本差异: 备份版本 {backup_version}, 当前版本 {VERSION}。" + ), + } + + return { + "status": "match", + "can_import": True, + "message": "版本匹配", + } + + async def import_all( + self, + zip_path: str, + mode: str = "replace", # "replace" 清空后导入 + progress_callback: Any | None = None, + ) -> ImportResult: + """从 ZIP 文件导入所有数据 + + Args: + zip_path: ZIP 备份文件路径 + mode: 导入模式,目前仅支持 "replace"(清空后导入) + progress_callback: 进度回调函数,接收参数 (stage, current, total, message) + + Returns: + ImportResult: 导入结果 + """ + result = ImportResult() + + if not os.path.exists(zip_path): + result.add_error(f"备份文件不存在: {zip_path}") + return result + + logger.info(f"开始从 {zip_path} 导入备份") + + try: + with zipfile.ZipFile(zip_path, "r") as zf: + # 1. 读取并验证 manifest + if progress_callback: + await progress_callback("validate", 0, 100, "正在验证备份文件...") + + try: + manifest_data = zf.read("manifest.json") + manifest = json.loads(manifest_data) + except KeyError: + result.add_error("备份文件缺少 manifest.json") + return result + except json.JSONDecodeError as e: + result.add_error(f"manifest.json 格式错误: {e}") + return result + + # 版本校验 + try: + self._validate_version(manifest) + except ValueError as e: + result.add_error(str(e)) + return result + + if progress_callback: + await progress_callback("validate", 100, 100, "验证完成") + + # 2. 导入主数据库 + if progress_callback: + await progress_callback("main_db", 0, 100, "正在导入主数据库...") + + try: + main_data_content = zf.read("databases/main_db.json") + main_data = json.loads(main_data_content) + + if mode == "replace": + await self._clear_main_db() + + imported = await self._import_main_database(main_data) + result.imported_tables.update(imported) + except Exception as e: + result.add_error(f"导入主数据库失败: {e}") + return result + + if progress_callback: + await progress_callback("main_db", 100, 100, "主数据库导入完成") + + # 3. 导入知识库 + if self.kb_manager and "databases/kb_metadata.json" in zf.namelist(): + if progress_callback: + await progress_callback("kb", 0, 100, "正在导入知识库...") + + try: + kb_meta_content = zf.read("databases/kb_metadata.json") + kb_meta_data = json.loads(kb_meta_content) + + if mode == "replace": + await self._clear_kb_data() + + await self._import_knowledge_bases(zf, kb_meta_data, result) + except Exception as e: + result.add_warning(f"导入知识库失败: {e}") + + if progress_callback: + await progress_callback("kb", 100, 100, "知识库导入完成") + + # 4. 导入配置文件 + if progress_callback: + await progress_callback("config", 0, 100, "正在导入配置文件...") + + if "config/cmd_config.json" in zf.namelist(): + try: + config_content = zf.read("config/cmd_config.json") + # 备份现有配置 + if os.path.exists(self.config_path): + backup_path = f"{self.config_path}.bak" + shutil.copy2(self.config_path, backup_path) + + with open(self.config_path, "wb") as f: + f.write(config_content) + result.imported_files["config"] = 1 + except Exception as e: + result.add_warning(f"导入配置文件失败: {e}") + + if progress_callback: + await progress_callback("config", 100, 100, "配置文件导入完成") + + # 5. 导入附件文件 + if progress_callback: + await progress_callback("attachments", 0, 100, "正在导入附件...") + + attachment_count = await self._import_attachments( + zf, main_data.get("attachments", []) + ) + result.imported_files["attachments"] = attachment_count + + if progress_callback: + await progress_callback("attachments", 100, 100, "附件导入完成") + + # 6. 导入插件和其他目录 + if progress_callback: + await progress_callback( + "directories", 0, 100, "正在导入插件和数据目录..." + ) + + dir_stats = await self._import_directories(zf, manifest, result) + result.imported_directories = dir_stats + + if progress_callback: + await progress_callback("directories", 100, 100, "目录导入完成") + + logger.info(f"备份导入完成: {result.to_dict()}") + return result + + except zipfile.BadZipFile: + result.add_error("无效的 ZIP 文件") + return result + except Exception as e: + result.add_error(f"导入失败: {e}") + return result + + def _validate_version(self, manifest: dict) -> None: + """验证版本兼容性 - 仅允许相同主版本导入 + + 注意:此方法仅在 import_all 中调用,用于双重校验。 + 前端应先调用 pre_check 获取详细的版本信息并让用户确认。 + """ + backup_version = manifest.get("astrbot_version") + if not backup_version: + raise ValueError("备份文件缺少版本信息") + + # 使用新的版本兼容性检查 + version_check = self._check_version_compatibility(backup_version) + + if version_check["status"] == "major_diff": + raise ValueError(version_check["message"]) + + # minor_diff 和 match 都允许导入 + if version_check["status"] == "minor_diff": + logger.warning(f"版本差异警告: {version_check['message']}") + + async def _clear_main_db(self) -> None: + """清空主数据库所有表""" + async with self.main_db.get_db() as session: + async with session.begin(): + for table_name, model_class in MAIN_DB_MODELS.items(): + try: + await session.execute(delete(model_class)) + logger.debug(f"已清空表 {table_name}") + except Exception as e: + logger.warning(f"清空表 {table_name} 失败: {e}") + + async def _clear_kb_data(self) -> None: + """清空知识库数据""" + if not self.kb_manager: + return + + # 清空知识库元数据表 + async with self.kb_manager.kb_db.get_db() as session: + async with session.begin(): + for table_name, model_class in KB_METADATA_MODELS.items(): + try: + await session.execute(delete(model_class)) + logger.debug(f"已清空知识库表 {table_name}") + except Exception as e: + logger.warning(f"清空知识库表 {table_name} 失败: {e}") + + # 删除知识库文件目录 + for kb_id in list(self.kb_manager.kb_insts.keys()): + try: + kb_helper = self.kb_manager.kb_insts[kb_id] + await kb_helper.terminate() + if kb_helper.kb_dir.exists(): + shutil.rmtree(kb_helper.kb_dir) + except Exception as e: + logger.warning(f"清理知识库 {kb_id} 失败: {e}") + + self.kb_manager.kb_insts.clear() + + async def _import_main_database( + self, data: dict[str, list[dict]] + ) -> dict[str, int]: + """导入主数据库数据""" + imported: dict[str, int] = {} + + async with self.main_db.get_db() as session: + async with session.begin(): + for table_name, rows in data.items(): + model_class = MAIN_DB_MODELS.get(table_name) + if not model_class: + logger.warning(f"未知的表: {table_name}") + continue + + count = 0 + for row in rows: + try: + # 转换 datetime 字符串为 datetime 对象 + row = self._convert_datetime_fields(row, model_class) + obj = model_class(**row) + session.add(obj) + count += 1 + except Exception as e: + logger.warning(f"导入记录到 {table_name} 失败: {e}") + + imported[table_name] = count + logger.debug(f"导入表 {table_name}: {count} 条记录") + + return imported + + async def _import_knowledge_bases( + self, + zf: zipfile.ZipFile, + kb_meta_data: dict[str, list[dict]], + result: ImportResult, + ) -> None: + """导入知识库数据""" + if not self.kb_manager: + return + + # 1. 导入知识库元数据 + async with self.kb_manager.kb_db.get_db() as session: + async with session.begin(): + for table_name, rows in kb_meta_data.items(): + model_class = KB_METADATA_MODELS.get(table_name) + if not model_class: + continue + + count = 0 + for row in rows: + try: + row = self._convert_datetime_fields(row, model_class) + obj = model_class(**row) + session.add(obj) + count += 1 + except Exception as e: + logger.warning(f"导入知识库记录到 {table_name} 失败: {e}") + + result.imported_tables[f"kb_{table_name}"] = count + + # 2. 导入每个知识库的文档和文件 + for kb_data in kb_meta_data.get("knowledge_bases", []): + kb_id = kb_data.get("kb_id") + if not kb_id: + continue + + # 创建知识库目录 + kb_dir = Path(self.kb_root_dir) / kb_id + kb_dir.mkdir(parents=True, exist_ok=True) + + # 导入文档数据 + doc_path = f"databases/kb_{kb_id}/documents.json" + if doc_path in zf.namelist(): + try: + doc_content = zf.read(doc_path) + doc_data = json.loads(doc_content) + + # 导入到文档存储数据库 + await self._import_kb_documents(kb_id, doc_data) + except Exception as e: + result.add_warning(f"导入知识库 {kb_id} 的文档失败: {e}") + + # 导入 FAISS 索引 + faiss_path = f"databases/kb_{kb_id}/index.faiss" + if faiss_path in zf.namelist(): + try: + target_path = kb_dir / "index.faiss" + with zf.open(faiss_path) as src, open(target_path, "wb") as dst: + dst.write(src.read()) + except Exception as e: + result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}") + + # 导入媒体文件 + media_prefix = f"files/kb_media/{kb_id}/" + for name in zf.namelist(): + if name.startswith(media_prefix): + try: + rel_path = name[len(media_prefix) :] + target_path = kb_dir / rel_path + target_path.parent.mkdir(parents=True, exist_ok=True) + with zf.open(name) as src, open(target_path, "wb") as dst: + dst.write(src.read()) + except Exception as e: + result.add_warning(f"导入媒体文件 {name} 失败: {e}") + + # 3. 重新加载知识库实例 + await self.kb_manager.load_kbs() + + async def _import_kb_documents(self, kb_id: str, doc_data: dict) -> None: + """导入知识库文档到向量数据库""" + from astrbot.core.db.vec_db.faiss_impl.document_storage import DocumentStorage + + kb_dir = Path(self.kb_root_dir) / kb_id + doc_db_path = kb_dir / "doc.db" + + # 初始化文档存储 + doc_storage = DocumentStorage(str(doc_db_path)) + await doc_storage.initialize() + + try: + documents = doc_data.get("documents", []) + for doc in documents: + try: + await doc_storage.insert_document( + doc_id=doc.get("doc_id", ""), + text=doc.get("text", ""), + metadata=json.loads(doc.get("metadata", "{}")), + ) + except Exception as e: + logger.warning(f"导入文档块失败: {e}") + finally: + await doc_storage.close() + + async def _import_attachments( + self, + zf: zipfile.ZipFile, + attachments: list[dict], + ) -> int: + """导入附件文件""" + count = 0 + + # 确保附件目录存在 + Path(self.attachments_dir).mkdir(parents=True, exist_ok=True) + + attachment_prefix = "files/attachments/" + for name in zf.namelist(): + if name.startswith(attachment_prefix) and name != attachment_prefix: + try: + # 从附件记录中找到原始路径 + attachment_id = os.path.splitext(os.path.basename(name))[0] + original_path = None + for att in attachments: + if att.get("attachment_id") == attachment_id: + original_path = att.get("path") + break + + if original_path: + target_path = Path(original_path) + else: + target_path = Path(self.attachments_dir) / os.path.basename( + name + ) + + target_path.parent.mkdir(parents=True, exist_ok=True) + with zf.open(name) as src, open(target_path, "wb") as dst: + dst.write(src.read()) + count += 1 + except Exception as e: + logger.warning(f"导入附件 {name} 失败: {e}") + + return count + + async def _import_directories( + self, + zf: zipfile.ZipFile, + manifest: dict, + result: ImportResult, + ) -> dict[str, int]: + """导入插件和其他数据目录 + + Args: + zf: ZIP 文件对象 + manifest: 备份清单 + result: 导入结果对象 + + Returns: + dict: 每个目录导入的文件数量 + """ + dir_stats: dict[str, int] = {} + + # 检查备份版本是否支持目录备份(需要版本 >= 1.1) + backup_version = manifest.get("version", "1.0") + if compare_versions(backup_version, "1.1") < 0: + logger.info("备份版本不支持目录备份,跳过目录导入") + return dir_stats + + backed_up_dirs = manifest.get("directories", []) + backup_directories = get_backup_directories() + + for dir_name in backed_up_dirs: + if dir_name not in backup_directories: + result.add_warning(f"未知的目录类型: {dir_name}") + continue + + target_dir = Path(backup_directories[dir_name]) + archive_prefix = f"directories/{dir_name}/" + + file_count = 0 + + try: + # 获取该目录下的所有文件 + dir_files = [ + name + for name in zf.namelist() + if name.startswith(archive_prefix) and name != archive_prefix + ] + + if not dir_files: + continue + + # 备份现有目录(如果存在) + if target_dir.exists(): + backup_path = Path(f"{target_dir}.bak") + if backup_path.exists(): + shutil.rmtree(backup_path) + shutil.move(str(target_dir), str(backup_path)) + logger.debug(f"已备份现有目录 {target_dir} 到 {backup_path}") + + # 创建目标目录 + target_dir.mkdir(parents=True, exist_ok=True) + + # 解压文件 + for name in dir_files: + try: + # 计算相对路径 + rel_path = name[len(archive_prefix) :] + if not rel_path: # 跳过目录条目 + continue + + target_path = target_dir / rel_path + target_path.parent.mkdir(parents=True, exist_ok=True) + + with zf.open(name) as src, open(target_path, "wb") as dst: + dst.write(src.read()) + file_count += 1 + except Exception as e: + result.add_warning(f"导入文件 {name} 失败: {e}") + + dir_stats[dir_name] = file_count + logger.debug(f"导入目录 {dir_name}: {file_count} 个文件") + + except Exception as e: + result.add_warning(f"导入目录 {dir_name} 失败: {e}") + dir_stats[dir_name] = 0 + + return dir_stats + + def _convert_datetime_fields(self, row: dict, model_class: type) -> dict: + """转换 datetime 字符串字段为 datetime 对象""" + result = row.copy() + + # 获取模型的 datetime 字段 + from sqlalchemy import inspect as sa_inspect + + try: + mapper = sa_inspect(model_class) + for column in mapper.columns: + if column.name in result and result[column.name] is not None: + # 检查是否是 datetime 类型的列 + from sqlalchemy import DateTime + + if isinstance(column.type, DateTime): + value = result[column.name] + if isinstance(value, str): + # 解析 ISO 格式的日期时间字符串 + result[column.name] = datetime.fromisoformat(value) + except Exception: + pass + + return result diff --git a/astrbot/core/utils/astrbot_path.py b/astrbot/core/utils/astrbot_path.py index e13379b92..f0fa52bbb 100644 --- a/astrbot/core/utils/astrbot_path.py +++ b/astrbot/core/utils/astrbot_path.py @@ -5,6 +5,10 @@ 数据目录路径:固定为根目录下的 data 目录 配置文件路径:固定为数据目录下的 config 目录 插件目录路径:固定为数据目录下的 plugins 目录 +插件数据目录路径:固定为数据目录下的 plugin_data 目录 +T2I 模板目录路径:固定为数据目录下的 t2i_templates 目录 +WebChat 数据目录路径:固定为数据目录下的 webchat 目录 +临时文件目录路径:固定为数据目录下的 temp 目录 """ import os @@ -37,3 +41,23 @@ def get_astrbot_config_path() -> str: def get_astrbot_plugin_path() -> str: """获取Astrbot插件目录路径""" return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins")) + + +def get_astrbot_plugin_data_path() -> str: + """获取Astrbot插件数据目录路径""" + return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugin_data")) + + +def get_astrbot_t2i_templates_path() -> str: + """获取Astrbot T2I 模板目录路径""" + return os.path.realpath(os.path.join(get_astrbot_data_path(), "t2i_templates")) + + +def get_astrbot_webchat_path() -> str: + """获取Astrbot WebChat 数据目录路径""" + return os.path.realpath(os.path.join(get_astrbot_data_path(), "webchat")) + + +def get_astrbot_temp_path() -> str: + """获取Astrbot临时文件目录路径""" + return os.path.realpath(os.path.join(get_astrbot_data_path(), "temp")) diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index 951db956c..bca1a2268 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -1,4 +1,5 @@ from .auth import AuthRoute +from .backup import BackupRoute from .chat import ChatRoute from .command import CommandRoute from .config import ConfigRoute @@ -17,6 +18,7 @@ __all__ = [ "AuthRoute", + "BackupRoute", "ChatRoute", "CommandRoute", "ConfigRoute", diff --git a/astrbot/dashboard/routes/backup.py b/astrbot/dashboard/routes/backup.py new file mode 100644 index 000000000..0445810b7 --- /dev/null +++ b/astrbot/dashboard/routes/backup.py @@ -0,0 +1,590 @@ +"""备份管理 API 路由""" + +import asyncio +import os +import re +import traceback +import uuid +from datetime import datetime +from pathlib import Path + +from quart import request, send_file + +from astrbot.core import logger +from astrbot.core.backup.exporter import AstrBotExporter +from astrbot.core.backup.importer import AstrBotImporter +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.db import BaseDatabase + +from .route import Response, Route, RouteContext + + +def secure_filename(filename: str) -> str: + """清洗文件名,移除路径遍历字符和危险字符 + + Args: + filename: 原始文件名 + + Returns: + 安全的文件名 + """ + # 跨平台处理:先将反斜杠替换为正斜杠,再取文件名 + filename = filename.replace("\\", "/") + # 仅保留文件名部分,移除路径 + filename = os.path.basename(filename) + + # 替换路径遍历字符 + filename = filename.replace("..", "_") + + # 仅保留字母、数字、下划线、连字符、点 + filename = re.sub(r"[^\w\-.]", "_", filename) + + # 移除前导点(隐藏文件)和尾部点 + filename = filename.strip(".") + + # 如果文件名为空或只包含下划线,生成一个默认名称 + if not filename or filename.replace("_", "") == "": + filename = "backup" + + return filename + + +def generate_unique_filename(original_filename: str) -> str: + """生成唯一的文件名,添加时间戳前缀 + + Args: + original_filename: 原始文件名(已清洗) + + Returns: + 唯一的文件名 + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + name, ext = os.path.splitext(original_filename) + return f"uploaded_{timestamp}_{name}{ext}" + + +class BackupRoute(Route): + """备份管理路由 + + 提供备份导出、导入、列表等 API 接口 + """ + + def __init__( + self, + context: RouteContext, + db: BaseDatabase, + core_lifecycle: AstrBotCoreLifecycle, + ) -> None: + super().__init__(context) + self.db = db + self.core_lifecycle = core_lifecycle + self.backup_dir = "data/backups" + + # 任务状态跟踪 + self.backup_tasks: dict[str, dict] = {} + self.backup_progress: dict[str, dict] = {} + + # 注册路由 + self.routes = { + "/backup/list": ("GET", self.list_backups), + "/backup/export": ("POST", self.export_backup), + "/backup/upload": ("POST", self.upload_backup), # 上传文件 + "/backup/check": ("POST", self.check_backup), # 预检查 + "/backup/import": ("POST", self.import_backup), # 确认导入 + "/backup/progress": ("GET", self.get_progress), + "/backup/download": ("GET", self.download_backup), + "/backup/delete": ("POST", self.delete_backup), + } + self.register_routes() + + def _init_task(self, task_id: str, task_type: str, status: str = "pending") -> None: + """初始化任务状态""" + self.backup_tasks[task_id] = { + "type": task_type, + "status": status, + "result": None, + "error": None, + } + self.backup_progress[task_id] = { + "status": status, + "stage": "waiting", + "current": 0, + "total": 100, + "message": "", + } + + def _set_task_result( + self, + task_id: str, + status: str, + result: dict | None = None, + error: str | None = None, + ) -> None: + """设置任务结果""" + if task_id in self.backup_tasks: + self.backup_tasks[task_id]["status"] = status + self.backup_tasks[task_id]["result"] = result + self.backup_tasks[task_id]["error"] = error + if task_id in self.backup_progress: + self.backup_progress[task_id]["status"] = status + + def _update_progress( + self, + task_id: str, + *, + status: str | None = None, + stage: str | None = None, + current: int | None = None, + total: int | None = None, + message: str | None = None, + ) -> None: + """更新任务进度""" + if task_id not in self.backup_progress: + return + p = self.backup_progress[task_id] + if status is not None: + p["status"] = status + if stage is not None: + p["stage"] = stage + if current is not None: + p["current"] = current + if total is not None: + p["total"] = total + if message is not None: + p["message"] = message + + def _make_progress_callback(self, task_id: str): + """创建进度回调函数""" + + async def _callback(stage: str, current: int, total: int, message: str = ""): + self._update_progress( + task_id, + status="processing", + stage=stage, + current=current, + total=total, + message=message, + ) + + return _callback + + async def list_backups(self): + """获取备份列表 + + Query 参数: + - page: 页码 (默认 1) + - page_size: 每页数量 (默认 20) + """ + try: + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 20, type=int) + + # 确保备份目录存在 + Path(self.backup_dir).mkdir(parents=True, exist_ok=True) + + # 获取所有备份文件 + backup_files = [] + for filename in os.listdir(self.backup_dir): + if filename.endswith(".zip") and filename.startswith("astrbot_backup_"): + file_path = os.path.join(self.backup_dir, filename) + stat = os.stat(file_path) + backup_files.append( + { + "filename": filename, + "size": stat.st_size, + "created_at": stat.st_mtime, + } + ) + + # 按创建时间倒序排序 + backup_files.sort(key=lambda x: x["created_at"], reverse=True) + + # 分页 + start = (page - 1) * page_size + end = start + page_size + items = backup_files[start:end] + + return ( + Response() + .ok( + { + "items": items, + "total": len(backup_files), + "page": page, + "page_size": page_size, + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"获取备份列表失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"获取备份列表失败: {e!s}").__dict__ + + async def export_backup(self): + """创建备份 + + 返回: + - task_id: 任务ID,用于查询导出进度 + """ + try: + # 生成任务ID + task_id = str(uuid.uuid4()) + + # 初始化任务状态 + self._init_task(task_id, "export", "pending") + + # 启动后台导出任务 + asyncio.create_task(self._background_export_task(task_id)) + + return ( + Response() + .ok( + { + "task_id": task_id, + "message": "export task created, processing in background", + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"创建备份失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"创建备份失败: {e!s}").__dict__ + + async def _background_export_task(self, task_id: str): + """后台导出任务""" + try: + self._update_progress(task_id, status="processing", message="正在初始化...") + + # 获取知识库管理器 + kb_manager = getattr(self.core_lifecycle, "kb_manager", None) + + exporter = AstrBotExporter( + main_db=self.db, + kb_manager=kb_manager, + config_path="data/cmd_config.json", + attachments_dir="data/attachments", + data_root="data", + ) + + # 创建进度回调 + progress_callback = self._make_progress_callback(task_id) + + # 执行导出 + zip_path = await exporter.export_all( + output_dir=self.backup_dir, + progress_callback=progress_callback, + ) + + # 设置成功结果 + self._set_task_result( + task_id, + "completed", + result={ + "filename": os.path.basename(zip_path), + "path": zip_path, + "size": os.path.getsize(zip_path), + }, + ) + except Exception as e: + logger.error(f"后台导出任务 {task_id} 失败: {e}") + logger.error(traceback.format_exc()) + self._set_task_result(task_id, "failed", error=str(e)) + + async def upload_backup(self): + """上传备份文件 + + 将备份文件上传到服务器,返回保存的文件名。 + 上传后应调用 check_backup 进行预检查。 + + Form Data: + - file: 备份文件 (.zip) + + 返回: + - filename: 保存的文件名 + """ + try: + files = await request.files + if "file" not in files: + return Response().error("缺少备份文件").__dict__ + + file = files["file"] + if not file.filename or not file.filename.endswith(".zip"): + return Response().error("请上传 ZIP 格式的备份文件").__dict__ + + # 清洗文件名并生成唯一名称,防止路径遍历和覆盖 + safe_filename = secure_filename(file.filename) + unique_filename = generate_unique_filename(safe_filename) + + # 保存上传的文件 + Path(self.backup_dir).mkdir(parents=True, exist_ok=True) + zip_path = os.path.join(self.backup_dir, unique_filename) + await file.save(zip_path) + + logger.info( + f"上传的备份文件已保存: {unique_filename} (原始名称: {file.filename})" + ) + + return ( + Response() + .ok( + { + "filename": unique_filename, + "original_filename": file.filename, + "size": os.path.getsize(zip_path), + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"上传备份文件失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"上传备份文件失败: {e!s}").__dict__ + + async def check_backup(self): + """预检查备份文件 + + 检查备份文件的版本兼容性,返回确认信息。 + 用户确认后调用 import_backup 执行导入。 + + JSON Body: + - filename: 已上传的备份文件名 + + 返回: + - ImportPreCheckResult: 预检查结果 + """ + try: + data = await request.json + filename = data.get("filename") + if not filename: + return Response().error("缺少 filename 参数").__dict__ + + # 安全检查 - 防止路径遍历 + if ".." in filename or "/" in filename or "\\" in filename: + return Response().error("无效的文件名").__dict__ + + zip_path = os.path.join(self.backup_dir, filename) + if not os.path.exists(zip_path): + return Response().error(f"备份文件不存在: {filename}").__dict__ + + # 获取知识库管理器(用于构造 importer) + kb_manager = getattr(self.core_lifecycle, "kb_manager", None) + + importer = AstrBotImporter( + main_db=self.db, + kb_manager=kb_manager, + config_path="data/cmd_config.json", + attachments_dir="data/attachments", + data_root="data", + ) + + # 执行预检查 + check_result = importer.pre_check(zip_path) + + return Response().ok(check_result.to_dict()).__dict__ + except Exception as e: + logger.error(f"预检查备份文件失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"预检查备份文件失败: {e!s}").__dict__ + + async def import_backup(self): + """执行备份导入 + + 在用户确认后执行实际的导入操作。 + 需要先调用 upload_backup 上传文件,再调用 check_backup 预检查。 + + JSON Body: + - filename: 已上传的备份文件名(必填) + - confirmed: 用户已确认(必填,必须为 true) + + 返回: + - task_id: 任务ID,用于查询导入进度 + """ + try: + data = await request.json + filename = data.get("filename") + confirmed = data.get("confirmed", False) + + if not filename: + return Response().error("缺少 filename 参数").__dict__ + + if not confirmed: + return ( + Response() + .error("请先确认导入。导入将会清空并覆盖现有数据,此操作不可撤销。") + .__dict__ + ) + + # 安全检查 - 防止路径遍历 + if ".." in filename or "/" in filename or "\\" in filename: + return Response().error("无效的文件名").__dict__ + + zip_path = os.path.join(self.backup_dir, filename) + if not os.path.exists(zip_path): + return Response().error(f"备份文件不存在: {filename}").__dict__ + + # 生成任务ID + task_id = str(uuid.uuid4()) + + # 初始化任务状态 + self._init_task(task_id, "import", "pending") + + # 启动后台导入任务 + asyncio.create_task(self._background_import_task(task_id, zip_path)) + + return ( + Response() + .ok( + { + "task_id": task_id, + "message": "import task created, processing in background", + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"导入备份失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"导入备份失败: {e!s}").__dict__ + + async def _background_import_task(self, task_id: str, zip_path: str): + """后台导入任务""" + try: + self._update_progress(task_id, status="processing", message="正在初始化...") + + # 获取知识库管理器 + kb_manager = getattr(self.core_lifecycle, "kb_manager", None) + + importer = AstrBotImporter( + main_db=self.db, + kb_manager=kb_manager, + config_path="data/cmd_config.json", + attachments_dir="data/attachments", + data_root="data", + ) + + # 创建进度回调 + progress_callback = self._make_progress_callback(task_id) + + # 执行导入 + result = await importer.import_all( + zip_path=zip_path, + mode="replace", + progress_callback=progress_callback, + ) + + # 设置结果 + if result.success: + self._set_task_result( + task_id, + "completed", + result=result.to_dict(), + ) + else: + self._set_task_result( + task_id, + "failed", + error="; ".join(result.errors), + ) + except Exception as e: + logger.error(f"后台导入任务 {task_id} 失败: {e}") + logger.error(traceback.format_exc()) + self._set_task_result(task_id, "failed", error=str(e)) + + async def get_progress(self): + """获取任务进度 + + Query 参数: + - task_id: 任务 ID (必填) + """ + try: + task_id = request.args.get("task_id") + if not task_id: + return Response().error("缺少参数 task_id").__dict__ + + if task_id not in self.backup_tasks: + return Response().error("找不到该任务").__dict__ + + task_info = self.backup_tasks[task_id] + status = task_info["status"] + + response_data = { + "task_id": task_id, + "type": task_info["type"], + "status": status, + } + + # 如果任务正在处理,返回进度信息 + if status == "processing" and task_id in self.backup_progress: + response_data["progress"] = self.backup_progress[task_id] + + # 如果任务完成,返回结果 + if status == "completed": + response_data["result"] = task_info["result"] + + # 如果任务失败,返回错误信息 + if status == "failed": + response_data["error"] = task_info["error"] + + return Response().ok(response_data).__dict__ + except Exception as e: + logger.error(f"获取任务进度失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"获取任务进度失败: {e!s}").__dict__ + + async def download_backup(self): + """下载备份文件 + + Query 参数: + - filename: 备份文件名 (必填) + """ + try: + filename = request.args.get("filename") + if not filename: + return Response().error("缺少参数 filename").__dict__ + + # 安全检查 - 防止路径遍历 + if ".." in filename or "/" in filename or "\\" in filename: + return Response().error("无效的文件名").__dict__ + + file_path = os.path.join(self.backup_dir, filename) + if not os.path.exists(file_path): + return Response().error("备份文件不存在").__dict__ + + return await send_file( + file_path, + as_attachment=True, + attachment_filename=filename, + ) + except Exception as e: + logger.error(f"下载备份失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"下载备份失败: {e!s}").__dict__ + + async def delete_backup(self): + """删除备份文件 + + Body: + - filename: 备份文件名 (必填) + """ + try: + data = await request.json + filename = data.get("filename") + if not filename: + return Response().error("缺少参数 filename").__dict__ + + # 安全检查 - 防止路径遍历 + if ".." in filename or "/" in filename or "\\" in filename: + return Response().error("无效的文件名").__dict__ + + file_path = os.path.join(self.backup_dir, filename) + if not os.path.exists(file_path): + return Response().error("备份文件不存在").__dict__ + + os.remove(file_path) + return Response().ok(message="删除备份成功").__dict__ + except Exception as e: + logger.error(f"删除备份失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"删除备份失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 6d6530c90..ad258b824 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -19,6 +19,7 @@ from astrbot.core.utils.io import get_local_ip_addresses from .routes import * +from .routes.backup import BackupRoute from .routes.platform import PlatformRoute from .routes.route import Response, RouteContext from .routes.session_management import SessionManagementRoute @@ -85,6 +86,7 @@ def __init__( self.t2i_route = T2iRoute(self.context, core_lifecycle) self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle) self.platform_route = PlatformRoute(self.context, core_lifecycle) + self.backup_route = BackupRoute(self.context, db, core_lifecycle) self.app.add_url_rule( "/api/plug/", diff --git a/dashboard/src/components/shared/BackupDialog.vue b/dashboard/src/components/shared/BackupDialog.vue new file mode 100644 index 000000000..629e4e559 --- /dev/null +++ b/dashboard/src/components/shared/BackupDialog.vue @@ -0,0 +1,673 @@ + + + + + \ No newline at end of file diff --git a/dashboard/src/i18n/locales/en-US/features/settings.json b/dashboard/src/i18n/locales/en-US/features/settings.json index 0a494ca3e..3715bb35a 100644 --- a/dashboard/src/i18n/locales/en-US/features/settings.json +++ b/dashboard/src/i18n/locales/en-US/features/settings.json @@ -18,6 +18,11 @@ "title": "Data Migration to v4.0.0", "subtitle": "If you encounter data compatibility issues, you can manually start the database migration assistant", "button": "Start Migration Assistant" + }, + "backup": { + "title": "Backup & Restore", + "subtitle": "Export or import all AstrBot data for easy migration to a new server", + "button": "Backup Manager" } }, "sidebar": { @@ -29,5 +34,66 @@ "mainItems": "Main Modules", "moreItems": "More Features" } + }, + "backup": { + "dialog": { + "title": "Backup Manager" + }, + "tabs": { + "export": "Export Backup", + "import": "Import Backup", + "list": "Backup List" + }, + "export": { + "title": "Create Backup", + "description": "Export all data as a ZIP backup file, including database, knowledge base, config and attachments.", + "includes": "Backup includes: Main database, Knowledge bases (metadata + vector index + documents), Config files, Attachment files", + "button": "Start Export", + "processing": "Exporting...", + "wait": "Please wait, packaging data...", + "completed": "Export Completed!", + "download": "Download Backup", + "another": "Create New Backup", + "failed": "Export Failed", + "retry": "Retry" + }, + "import": { + "title": "Import Backup", + "warning": "⚠️ Import will clear and overwrite existing data! Please make sure you have backed up your current data.", + "selectFile": "Select backup file (.zip)", + "uploadAndCheck": "Upload & Check", + "uploading": "Uploading...", + "uploadWait": "Please wait, uploading backup file...", + "invalidBackup": "Invalid backup file", + "backupContents": "Backup Contents", + "tables": "tables", + "knowledgeBases": "Knowledge Bases", + "configFiles": "Config Files", + "confirmImport": "Confirm Import", + "button": "Start Import", + "processing": "Importing...", + "wait": "Please wait, restoring data...", + "completed": "Import Completed!", + "restartRequired": "Data has been successfully imported. It is recommended to restart AstrBot immediately for all changes to take effect.", + "restartNow": "Restart Now", + "failed": "Import Failed", + "retry": "Retry", + "version": { + "backupVersion": "Backup Version", + "currentVersion": "Current Version", + "backupTime": "Backup Time", + "matchTitle": "✅ Version Match", + "matchMessage": "Import will clear and overwrite all existing data, including:\n• Main database (conversations, settings, etc.)\n• Knowledge bases\n• Plugins and plugin data\n• Configuration files\n\nThis action cannot be undone! Do you want to continue?", + "minorDiffTitle": "⚠️ Version Difference Warning", + "minorDiffMessage": "Minor version differences are usually compatible, but there may be some data structure changes.\nImport will clear and overwrite all existing data!\n\nDo you want to continue?", + "majorDiffTitle": "⛔ Cannot Import", + "majorDiffMessage": "Major version numbers are different. Cross-major-version import may cause data corruption.\nPlease use the same major version of AstrBot for import." + } + }, + "list": { + "empty": "No backup files", + "refresh": "Refresh List", + "confirmDelete": "Are you sure you want to delete this backup file? This action cannot be undone." + } } -} \ No newline at end of file +} \ No newline at end of file diff --git a/dashboard/src/i18n/locales/zh-CN/features/settings.json b/dashboard/src/i18n/locales/zh-CN/features/settings.json index bb6700f60..3778125aa 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/settings.json +++ b/dashboard/src/i18n/locales/zh-CN/features/settings.json @@ -18,6 +18,11 @@ "title": "数据迁移到 v4.0.0 格式", "subtitle": "如果您遇到数据兼容性问题,可以手动启动数据库迁移助手", "button": "启动迁移助手" + }, + "backup": { + "title": "数据备份与恢复", + "subtitle": "导出或导入 AstrBot 的所有数据,方便迁移到新服务器", + "button": "备份管理" } }, "sidebar": { @@ -29,5 +34,66 @@ "mainItems": "主要模块", "moreItems": "更多功能" } + }, + "backup": { + "dialog": { + "title": "备份管理" + }, + "tabs": { + "export": "导出备份", + "import": "导入备份", + "list": "备份列表" + }, + "export": { + "title": "创建备份", + "description": "将所有数据导出为 ZIP 备份文件,包括数据库、知识库、配置和附件。", + "includes": "备份包含:主数据库、知识库(元数据+向量索引+文档)、配置文件、附件文件", + "button": "开始导出", + "processing": "正在导出...", + "wait": "请稍候,正在打包数据...", + "completed": "导出完成!", + "download": "下载备份", + "another": "创建新备份", + "failed": "导出失败", + "retry": "重试" + }, + "import": { + "title": "导入备份", + "warning": "⚠️ 导入将会清空并覆盖现有数据!请确保已备份当前数据。", + "selectFile": "选择备份文件 (.zip)", + "uploadAndCheck": "上传并检查", + "uploading": "正在上传...", + "uploadWait": "请稍候,正在上传备份文件...", + "invalidBackup": "无效的备份文件", + "backupContents": "备份内容", + "tables": "个数据表", + "knowledgeBases": "知识库", + "configFiles": "配置文件", + "confirmImport": "确认导入", + "button": "开始导入", + "processing": "正在导入...", + "wait": "请稍候,正在恢复数据...", + "completed": "导入完成!", + "restartRequired": "数据已成功导入。建议立即重启 AstrBot 以使所有更改生效。", + "restartNow": "立即重启", + "failed": "导入失败", + "retry": "重试", + "version": { + "backupVersion": "备份版本", + "currentVersion": "当前版本", + "backupTime": "备份时间", + "matchTitle": "✅ 版本匹配", + "matchMessage": "导入将会清空并覆盖现有的所有数据,包括:\n• 主数据库(对话记录、配置等)\n• 知识库数据\n• 插件及插件数据\n• 配置文件\n\n此操作不可撤销!是否确认继续?", + "minorDiffTitle": "⚠️ 版本差异警告", + "minorDiffMessage": "小版本差异通常是兼容的,但可能存在少量数据结构变化。\n导入将会清空并覆盖现有的所有数据!\n\n是否确认继续导入?", + "majorDiffTitle": "⛔ 无法导入", + "majorDiffMessage": "主版本号不同,跨主版本导入可能导致数据损坏。\n请使用相同主版本的 AstrBot 进行导入。" + } + }, + "list": { + "empty": "暂无备份文件", + "refresh": "刷新列表", + "confirmDelete": "确定要删除这个备份文件吗?此操作不可撤销。" + } } -} \ No newline at end of file +} \ No newline at end of file diff --git a/dashboard/src/views/Settings.vue b/dashboard/src/views/Settings.vue index 338d0394d..1c56119ab 100644 --- a/dashboard/src/views/Settings.vue +++ b/dashboard/src/views/Settings.vue @@ -17,6 +17,13 @@ {{ tm('system.title') }} + + + mdi-backup-restore + {{ tm('system.backup.button') }} + + + {{ tm('system.restart.button') }} @@ -30,6 +37,7 @@ + @@ -40,12 +48,14 @@ import WaitingForRestart from '@/components/shared/WaitingForRestart.vue'; import ProxySelector from '@/components/shared/ProxySelector.vue'; import MigrationDialog from '@/components/shared/MigrationDialog.vue'; import SidebarCustomizer from '@/components/shared/SidebarCustomizer.vue'; +import BackupDialog from '@/components/shared/BackupDialog.vue'; import { useModuleI18n } from '@/i18n/composables'; const { tm } = useModuleI18n('features/settings'); const wfr = ref(null); const migrationDialog = ref(null); +const backupDialog = ref(null); const restartAstrBot = () => { axios.post('/api/stat/restart-core').then(() => { @@ -65,4 +75,10 @@ const startMigration = async () => { } } } + +const openBackupDialog = () => { + if (backupDialog.value) { + backupDialog.value.open(); + } +} \ No newline at end of file diff --git a/tests/test_backup.py b/tests/test_backup.py new file mode 100644 index 000000000..2341c6aed --- /dev/null +++ b/tests/test_backup.py @@ -0,0 +1,749 @@ +"""备份功能单元测试""" + +import json +import os +import re +import zipfile +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from astrbot.core.backup import ( + BACKUP_MANIFEST_VERSION, + KB_METADATA_MODELS, + MAIN_DB_MODELS, + ImportPreCheckResult, +) +from astrbot.core.backup.exporter import AstrBotExporter +from astrbot.core.backup.importer import ( + AstrBotImporter, + ImportResult, + compare_versions, + parse_version, +) +from astrbot.core.config.default import VERSION +from astrbot.core.db.po import ( + ConversationV2, +) +from astrbot.dashboard.routes.backup import ( + generate_unique_filename, + secure_filename, +) + + +@pytest.fixture +def temp_backup_dir(tmp_path): + """创建临时备份目录""" + backup_dir = tmp_path / "backups" + backup_dir.mkdir() + return backup_dir + + +@pytest.fixture +def temp_data_dir(tmp_path): + """创建临时数据目录""" + data_dir = tmp_path / "data" + data_dir.mkdir() + + # 创建配置文件 + config_path = data_dir / "cmd_config.json" + config_path.write_text(json.dumps({"test": "config"})) + + # 创建附件目录 + attachments_dir = data_dir / "attachments" + attachments_dir.mkdir() + + return data_dir + + +@pytest.fixture +def mock_main_db(): + """创建模拟的主数据库""" + db = MagicMock() + + # 模拟异步上下文管理器 + session = AsyncMock() + db.get_db = MagicMock( + return_value=AsyncMock(__aenter__=AsyncMock(return_value=session)) + ) + + return db + + +@pytest.fixture +def mock_kb_manager(): + """创建模拟的知识库管理器""" + kb_manager = MagicMock() + kb_manager.kb_insts = {} + + # 模拟 kb_db + kb_db = MagicMock() + session = AsyncMock() + kb_db.get_db = MagicMock( + return_value=AsyncMock(__aenter__=AsyncMock(return_value=session)) + ) + kb_manager.kb_db = kb_db + + return kb_manager + + +class TestImportResult: + """ImportResult 类测试""" + + def test_init(self): + """测试初始化""" + result = ImportResult() + assert result.success is True + assert result.imported_tables == {} + assert result.imported_files == {} + assert result.warnings == [] + assert result.errors == [] + + def test_add_warning(self): + """测试添加警告""" + result = ImportResult() + result.add_warning("test warning") + assert "test warning" in result.warnings + assert result.success is True # 警告不影响成功状态 + + def test_add_error(self): + """测试添加错误""" + result = ImportResult() + result.add_error("test error") + assert "test error" in result.errors + assert result.success is False # 错误会导致失败 + + def test_to_dict(self): + """测试转换为字典""" + result = ImportResult() + result.imported_tables = {"test_table": 10} + result.add_warning("warning") + + d = result.to_dict() + assert d["success"] is True + assert d["imported_tables"] == {"test_table": 10} + assert "warning" in d["warnings"] + + +class TestAstrBotExporter: + """AstrBotExporter 类测试""" + + def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir): + """测试初始化""" + exporter = AstrBotExporter( + main_db=mock_main_db, + kb_manager=mock_kb_manager, + config_path=str(temp_data_dir / "cmd_config.json"), + attachments_dir=str(temp_data_dir / "attachments"), + ) + assert exporter.main_db is mock_main_db + assert exporter.kb_manager is mock_kb_manager + + def test_model_to_dict_with_model_dump(self): + """测试 _model_to_dict 使用 model_dump 方法""" + exporter = AstrBotExporter(main_db=MagicMock()) + + # 创建一个有 model_dump 方法的模拟对象 + mock_record = MagicMock() + mock_record.model_dump.return_value = {"id": 1, "name": "test"} + + result = exporter._model_to_dict(mock_record) + assert result == {"id": 1, "name": "test"} + + def test_model_to_dict_with_datetime(self): + """测试 _model_to_dict 处理 datetime 字段""" + exporter = AstrBotExporter(main_db=MagicMock()) + + now = datetime.now() + mock_record = MagicMock() + mock_record.model_dump.return_value = {"id": 1, "created_at": now} + + result = exporter._model_to_dict(mock_record) + assert result["created_at"] == now.isoformat() + + def test_add_checksum(self): + """测试添加校验和""" + exporter = AstrBotExporter(main_db=MagicMock()) + + exporter._add_checksum("test.json", '{"test": "data"}') + + assert "test.json" in exporter._checksums + assert exporter._checksums["test.json"].startswith("sha256:") + + def test_generate_manifest(self, mock_main_db, mock_kb_manager): + """测试生成清单""" + exporter = AstrBotExporter( + main_db=mock_main_db, + kb_manager=mock_kb_manager, + ) + + main_data = { + "platform_stats": [{"id": 1}], + "conversations": [], + "attachments": [], + } + kb_meta_data = { + "knowledge_bases": [], + "kb_documents": [], + } + dir_stats = { + "plugins": {"files": 10, "size": 1024}, + "plugin_data": {"files": 5, "size": 512}, + } + + manifest = exporter._generate_manifest(main_data, kb_meta_data, dir_stats) + + assert manifest["version"] == BACKUP_MANIFEST_VERSION + assert manifest["astrbot_version"] == VERSION + assert "exported_at" in manifest + assert "tables" in manifest + assert "statistics" in manifest + assert "directories" in manifest + assert manifest["statistics"]["main_db"]["platform_stats"] == 1 + assert manifest["statistics"]["directories"] == dir_stats + + @pytest.mark.asyncio + async def test_export_all_creates_zip( + self, mock_main_db, temp_backup_dir, temp_data_dir + ): + """测试导出创建 ZIP 文件""" + # 设置模拟数据库返回空数据 + session = AsyncMock() + result = MagicMock() + result.scalars.return_value.all.return_value = [] + session.execute = AsyncMock(return_value=result) + + mock_main_db.get_db.return_value = AsyncMock( + __aenter__=AsyncMock(return_value=session), + __aexit__=AsyncMock(return_value=None), + ) + + exporter = AstrBotExporter( + main_db=mock_main_db, + kb_manager=None, + config_path=str(temp_data_dir / "cmd_config.json"), + attachments_dir=str(temp_data_dir / "attachments"), + ) + + zip_path = await exporter.export_all(output_dir=str(temp_backup_dir)) + + assert os.path.exists(zip_path) + assert zip_path.endswith(".zip") + assert "astrbot_backup_" in zip_path + + # 验证 ZIP 文件内容 + with zipfile.ZipFile(zip_path, "r") as zf: + namelist = zf.namelist() + assert "manifest.json" in namelist + assert "databases/main_db.json" in namelist + assert "config/cmd_config.json" in namelist + + +class TestAstrBotImporter: + """AstrBotImporter 类测试""" + + def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir): + """测试初始化""" + importer = AstrBotImporter( + main_db=mock_main_db, + kb_manager=mock_kb_manager, + config_path=str(temp_data_dir / "cmd_config.json"), + attachments_dir=str(temp_data_dir / "attachments"), + ) + assert importer.main_db is mock_main_db + assert importer.kb_manager is mock_kb_manager + + def test_validate_version_match(self): + """测试版本匹配验证""" + importer = AstrBotImporter(main_db=MagicMock()) + + manifest = {"astrbot_version": VERSION} + # 不应该抛出异常 + importer._validate_version(manifest) + + def test_validate_version_major_diff_rejected(self): + """测试主版本不同被拒绝""" + importer = AstrBotImporter(main_db=MagicMock()) + + # 使用一个明显不同的主版本 + manifest = {"astrbot_version": "0.0.1"} + with pytest.raises(ValueError, match="主版本不兼容"): + importer._validate_version(manifest) + + def test_validate_version_minor_diff_allowed(self): + """测试小版本不同被允许(仅警告)""" + importer = AstrBotImporter(main_db=MagicMock()) + + # 解析当前版本 + current_parts = parse_version(VERSION) + if len(current_parts) >= 2: + # 构造一个同主版本但小版本不同的版本 + minor_diff_version = f"{current_parts[0]}.{current_parts[1]}.999" + manifest = {"astrbot_version": minor_diff_version} + # 不应该抛出异常 + importer._validate_version(manifest) + + def test_validate_version_missing(self): + """测试缺少版本信息""" + importer = AstrBotImporter(main_db=MagicMock()) + + manifest = {} + with pytest.raises(ValueError, match="缺少版本信息"): + importer._validate_version(manifest) + + def test_convert_datetime_fields(self): + """测试 datetime 字段转换""" + importer = AstrBotImporter(main_db=MagicMock()) + + # 使用 ConversationV2 作为测试模型(它有 created_at 和 updated_at 字段) + row = { + "conversation_id": "test-123", + "platform_id": "test", + "user_id": "user1", + "created_at": "2024-01-01T12:00:00", + "updated_at": "2024-01-01T12:00:00", + } + + result = importer._convert_datetime_fields(row, ConversationV2) + + # created_at 应该被转换为 datetime 对象 + assert isinstance(result["created_at"], datetime) + assert isinstance(result["updated_at"], datetime) + + @pytest.mark.asyncio + async def test_import_file_not_exists(self, mock_main_db, tmp_path): + """测试导入不存在的文件""" + importer = AstrBotImporter(main_db=mock_main_db) + + result = await importer.import_all(str(tmp_path / "nonexistent.zip")) + + assert result.success is False + assert any("不存在" in err for err in result.errors) + + @pytest.mark.asyncio + async def test_import_invalid_zip(self, mock_main_db, tmp_path): + """测试导入无效的 ZIP 文件""" + # 创建一个无效的文件 + invalid_zip = tmp_path / "invalid.zip" + invalid_zip.write_text("not a zip file") + + importer = AstrBotImporter(main_db=mock_main_db) + result = await importer.import_all(str(invalid_zip)) + + assert result.success is False + assert any("无效" in err or "ZIP" in err for err in result.errors) + + @pytest.mark.asyncio + async def test_import_missing_manifest(self, mock_main_db, tmp_path): + """测试导入缺少 manifest 的 ZIP 文件""" + # 创建一个没有 manifest 的 ZIP 文件 + zip_path = tmp_path / "no_manifest.zip" + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("test.txt", "test content") + + importer = AstrBotImporter(main_db=mock_main_db) + result = await importer.import_all(str(zip_path)) + + assert result.success is False + assert any("manifest" in err.lower() for err in result.errors) + + @pytest.mark.asyncio + async def test_import_major_version_mismatch(self, mock_main_db, tmp_path): + """测试导入主版本不匹配的备份""" + # 创建一个主版本不匹配的备份 + zip_path = tmp_path / "old_version.zip" + manifest = { + "version": "1.0", + "astrbot_version": "0.0.1", # 主版本不同 + "tables": {"main_db": []}, + } + + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("manifest.json", json.dumps(manifest)) + + importer = AstrBotImporter(main_db=mock_main_db) + result = await importer.import_all(str(zip_path)) + + assert result.success is False + assert any("主版本不兼容" in err for err in result.errors) + + +class TestSecureFilename: + """安全文件名函数测试""" + + def test_secure_filename_normal(self): + """测试正常文件名""" + assert secure_filename("backup.zip") == "backup.zip" + assert secure_filename("my_backup_2024.zip") == "my_backup_2024.zip" + + def test_secure_filename_path_traversal(self): + """测试路径遍历攻击""" + assert ".." not in secure_filename("../../../etc/passwd") + assert "/" not in secure_filename("/etc/passwd") + assert "\\" not in secure_filename("..\\..\\windows\\system32") + + def test_secure_filename_with_path(self): + """测试带路径的文件名""" + result = secure_filename("/path/to/backup.zip") + assert result == "backup.zip" + + result = secure_filename("C:\\Users\\test\\backup.zip") + assert result == "backup.zip" + + def test_secure_filename_special_chars(self): + """测试特殊字符""" + result = secure_filename('backup<>:"|?*.zip') + # 特殊字符应被替换为下划线 + assert "<" not in result + assert ">" not in result + assert ":" not in result + assert '"' not in result + assert "|" not in result + assert "?" not in result + assert "*" not in result + + def test_secure_filename_hidden_file(self): + """测试隐藏文件(前导点)""" + result = secure_filename(".hidden_backup.zip") + assert not result.startswith(".") + + def test_secure_filename_empty(self): + """测试空文件名""" + assert secure_filename("") == "backup" + assert secure_filename("...") == "backup" + + def test_generate_unique_filename(self): + """测试生成唯一文件名""" + result = generate_unique_filename("backup.zip") + # 应包含 uploaded_ 前缀和时间戳 + assert result.startswith("uploaded_") + assert result.endswith("_backup.zip") + # 应包含时间戳格式 YYYYMMDD_HHMMSS + assert re.search(r"uploaded_\d{8}_\d{6}_backup\.zip", result) + + +class TestVersionComparison: + """版本比较函数测试""" + + def test_parse_version_simple(self): + """测试解析简单版本号""" + assert parse_version("1.0") == (1, 0) + assert parse_version("2.1") == (2, 1) + + def test_parse_version_multi_digit(self): + """测试解析多位数版本号""" + assert parse_version("1.10") == (1, 10) + assert parse_version("1.10.2") == (1, 10, 2) + assert parse_version("10.20.30") == (10, 20, 30) + + def test_parse_version_invalid(self): + """测试解析无效版本号""" + assert parse_version("invalid") == (0,) + assert parse_version("") == (0,) + assert parse_version("1.x.2") == (0,) + + def test_compare_versions_equal(self): + """测试版本相等""" + assert compare_versions("1.0", "1.0") == 0 + assert compare_versions("1.0.0", "1.0") == 0 + assert compare_versions("2.10", "2.10") == 0 + + def test_compare_versions_less_than(self): + """测试版本小于""" + assert compare_versions("1.0", "1.1") == -1 + assert compare_versions("1.9", "1.10") == -1 # 关键测试:多位数版本比较 + assert compare_versions("1.2", "1.10") == -1 + assert compare_versions("1.0", "2.0") == -1 + + def test_compare_versions_greater_than(self): + """测试版本大于""" + assert compare_versions("1.1", "1.0") == 1 + assert compare_versions("1.10", "1.9") == 1 # 关键测试:多位数版本比较 + assert compare_versions("1.10", "1.2") == 1 + assert compare_versions("2.0", "1.0") == 1 + + def test_compare_versions_different_lengths(self): + """测试不同长度版本比较""" + assert compare_versions("1.0", "1.0.0") == 0 + assert compare_versions("1.0", "1.0.1") == -1 + assert compare_versions("1.0.1", "1.0") == 1 + + +class TestImportPreCheckResult: + """ImportPreCheckResult 类测试""" + + def test_init_default_values(self): + """测试默认值初始化""" + result = ImportPreCheckResult() + assert result.valid is False + assert result.can_import is False + assert result.version_status == "" + assert result.backup_version == "" + assert result.current_version == VERSION + assert result.confirm_message == "" + assert result.warnings == [] + assert result.error == "" + assert result.backup_summary == {} + + def test_to_dict(self): + """测试转换为字典""" + result = ImportPreCheckResult( + valid=True, + can_import=True, + version_status="match", + backup_version="4.9.0", + confirm_message="确认导入?", + warnings=["警告1"], + backup_summary={"tables": ["table1"]}, + ) + + d = result.to_dict() + assert d["valid"] is True + assert d["can_import"] is True + assert d["version_status"] == "match" + assert d["backup_version"] == "4.9.0" + assert d["confirm_message"] == "确认导入?" + assert "警告1" in d["warnings"] + assert d["backup_summary"]["tables"] == ["table1"] + + +class TestPreCheck: + """预检查功能测试""" + + def test_pre_check_file_not_exists(self, mock_main_db): + """测试预检查不存在的文件""" + importer = AstrBotImporter(main_db=mock_main_db) + result = importer.pre_check("/nonexistent/file.zip") + + assert result.valid is False + assert "不存在" in result.error + + def test_pre_check_invalid_zip(self, mock_main_db, tmp_path): + """测试预检查无效的 ZIP 文件""" + invalid_zip = tmp_path / "invalid.zip" + invalid_zip.write_text("not a zip file") + + importer = AstrBotImporter(main_db=mock_main_db) + result = importer.pre_check(str(invalid_zip)) + + assert result.valid is False + assert "ZIP" in result.error or "无效" in result.error + + def test_pre_check_missing_manifest(self, mock_main_db, tmp_path): + """测试预检查缺少 manifest 的 ZIP 文件""" + zip_path = tmp_path / "no_manifest.zip" + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("test.txt", "test content") + + importer = AstrBotImporter(main_db=mock_main_db) + result = importer.pre_check(str(zip_path)) + + assert result.valid is False + assert "manifest" in result.error.lower() + + def test_pre_check_version_match(self, mock_main_db, tmp_path): + """测试预检查版本匹配""" + zip_path = tmp_path / "backup.zip" + manifest = { + "version": "1.1", + "astrbot_version": VERSION, + "created_at": "2024-01-01T12:00:00", + "tables": {"platform_stats": 1}, + "has_knowledge_bases": True, + "has_config": True, + "directories": ["plugins"], + } + + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("manifest.json", json.dumps(manifest)) + + importer = AstrBotImporter(main_db=mock_main_db) + result = importer.pre_check(str(zip_path)) + + assert result.valid is True + assert result.can_import is True + assert result.version_status == "match" + assert result.backup_version == VERSION + # confirm_message 现在由前端生成,后端不再生成 + assert result.backup_summary["has_knowledge_bases"] is True + + def test_pre_check_minor_version_diff(self, mock_main_db, tmp_path): + """测试预检查小版本差异""" + # 构造一个同主版本但小版本不同的版本 + current_parts = parse_version(VERSION) + # VERSION 应该至少有两个部分(如 4.9) + assert len(current_parts) >= 2, f"VERSION {VERSION} 应该有至少两个部分" + minor_diff_version = f"{current_parts[0]}.{current_parts[1]}.999" + + zip_path = tmp_path / "backup.zip" + manifest = { + "version": "1.1", + "astrbot_version": minor_diff_version, + "created_at": "2024-01-01T12:00:00", + "tables": {}, + } + + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("manifest.json", json.dumps(manifest)) + + importer = AstrBotImporter(main_db=mock_main_db) + result = importer.pre_check(str(zip_path)) + + assert result.valid is True + assert result.can_import is True + assert result.version_status == "minor_diff" + # 版本消息由前端 i18n 生成,后端 warnings 列表不再包含版本相关消息 + # warnings 列表保留用于其他非版本相关的警告 + + def test_pre_check_major_version_diff(self, mock_main_db, tmp_path): + """测试预检查主版本差异""" + zip_path = tmp_path / "backup.zip" + manifest = { + "version": "1.1", + "astrbot_version": "0.0.1", # 主版本不同 + "created_at": "2024-01-01T12:00:00", + "tables": {}, + } + + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("manifest.json", json.dumps(manifest)) + + importer = AstrBotImporter(main_db=mock_main_db) + result = importer.pre_check(str(zip_path)) + + assert result.valid is True # 文件有效 + assert result.can_import is False # 但不能导入 + assert result.version_status == "major_diff" + # 版本消息由前端 i18n 生成,后端 warnings 列表不再包含版本相关消息 + + +class TestVersionCompatibility: + """版本兼容性检查测试""" + + def test_check_version_compatibility_match(self, mock_main_db): + """测试版本完全匹配""" + importer = AstrBotImporter(main_db=mock_main_db) + result = importer._check_version_compatibility(VERSION) + + assert result["status"] == "match" + assert result["can_import"] is True + + def test_check_version_compatibility_minor_diff(self, mock_main_db): + """测试小版本差异""" + current_parts = parse_version(VERSION) + # VERSION 应该至少有两个部分(如 4.9) + assert len(current_parts) >= 2, f"VERSION {VERSION} 应该有至少两个部分" + minor_diff_version = f"{current_parts[0]}.{current_parts[1]}.999" + + importer = AstrBotImporter(main_db=mock_main_db) + result = importer._check_version_compatibility(minor_diff_version) + + assert result["status"] == "minor_diff" + assert result["can_import"] is True + + def test_check_version_compatibility_major_diff(self, mock_main_db): + """测试主版本差异""" + importer = AstrBotImporter(main_db=mock_main_db) + result = importer._check_version_compatibility("0.0.1") + + assert result["status"] == "major_diff" + assert result["can_import"] is False + + def test_check_version_compatibility_empty_version(self, mock_main_db): + """测试空版本号""" + importer = AstrBotImporter(main_db=mock_main_db) + result = importer._check_version_compatibility("") + + assert result["status"] == "major_diff" + assert result["can_import"] is False + + +class TestModelMappings: + """测试模型映射配置""" + + def test_main_db_models_not_empty(self): + """测试主数据库模型映射非空""" + assert len(MAIN_DB_MODELS) > 0 + + def test_main_db_models_contain_expected_tables(self): + """测试主数据库模型映射包含预期的表""" + expected_tables = [ + "platform_stats", + "conversations", + "personas", + "preferences", + "attachments", + ] + for table in expected_tables: + assert table in MAIN_DB_MODELS, f"Missing table: {table}" + + def test_kb_metadata_models_not_empty(self): + """测试知识库元数据模型映射非空""" + assert len(KB_METADATA_MODELS) > 0 + + def test_kb_metadata_models_contain_expected_tables(self): + """测试知识库元数据模型映射包含预期的表""" + expected_tables = [ + "knowledge_bases", + "kb_documents", + "kb_media", + ] + for table in expected_tables: + assert table in KB_METADATA_MODELS, f"Missing table: {table}" + + +class TestBackupIntegration: + """备份集成测试""" + + @pytest.mark.asyncio + async def test_export_import_roundtrip(self, tmp_path): + """测试导出-导入往返""" + backup_dir = tmp_path / "backups" + backup_dir.mkdir() + + data_dir = tmp_path / "data" + data_dir.mkdir() + + config_path = data_dir / "cmd_config.json" + config_path.write_text(json.dumps({"setting": "value"})) + + attachments_dir = data_dir / "attachments" + attachments_dir.mkdir() + + # 创建模拟数据库 + mock_db = MagicMock() + session = AsyncMock() + result = MagicMock() + result.scalars.return_value.all.return_value = [] + session.execute = AsyncMock(return_value=result) + + mock_db.get_db.return_value = AsyncMock( + __aenter__=AsyncMock(return_value=session), + __aexit__=AsyncMock(return_value=None), + ) + + # 导出 + exporter = AstrBotExporter( + main_db=mock_db, + kb_manager=None, + config_path=str(config_path), + attachments_dir=str(attachments_dir), + ) + + zip_path = await exporter.export_all(output_dir=str(backup_dir)) + assert os.path.exists(zip_path) + + # 验证 ZIP 内容 + with zipfile.ZipFile(zip_path, "r") as zf: + # 读取 manifest + manifest = json.loads(zf.read("manifest.json")) + assert manifest["astrbot_version"] == VERSION + + # 读取配置 + config = json.loads(zf.read("config/cmd_config.json")) + assert config["setting"] == "value" + + # 读取主数据库 + main_db = json.loads(zf.read("databases/main_db.json")) + assert "platform_stats" in main_db