diff --git a/vdb_benchmark/.gitignore b/vdb_benchmark/.gitignore new file mode 100644 index 0000000..95b3f05 --- /dev/null +++ b/vdb_benchmark/.gitignore @@ -0,0 +1,180 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +tests/tests/__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ +tests/.benchmarks/ +tests/.coverage +tests/tests/coverage_html/ +tests/tests/test_results.* +tests/tests/test_report.* + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc diff --git a/vdb_benchmark/LICENSE b/vdb_benchmark/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/vdb_benchmark/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vdb_benchmark/README.md b/vdb_benchmark/README.md new file mode 100644 index 0000000..e8ea20e --- /dev/null +++ b/vdb_benchmark/README.md @@ -0,0 +1,125 @@ +# Vector Database Benchmark Tool +This tool allows you to benchmark and compare the performance of vector databases with current support for Milvus and others planned. + +## Installation + +### Using Docker (recommended) +1. Clone the repository: +``` bash +git clone -b TF_VDBBench https://github.com/mlcommons/storage.git +cd storage/vdb_benchmark +``` +2. Build and run the Docker container: +```bash +docker compose up -d # with docker-compose-v2. v1 uses docker-compose up +``` + +### Manual Installation +1. Clone the repository: +```bash +git clone -b TF_VDBBench https://github.com/mlcommons/storage.git +cd storage/vdb_benchmark +``` + +2. Install the package: +```bash +pip3 install ./ +``` + +## Deploying a Standalone Milvus Instance +The docker-compose.yml file will configure a 3-container instance of Milvus database. + - Milvus Database + - Minio Object Storage + - etcd + +The docker-compose.yml file uses ```/mnt/vdb``` as the root directory for the required docker volumes. You can modify the compose file for your environment or ensure that your target storage is mounted at this location. + +For testing more than one storage solution, there are two methods: +1. Create a set of containers for each storage solution with modified docker-compose.yml files pointing to different root directories. Each set of containers will also need a different port to listen on. You may need to limit how many instances you can run depending on the available memory in your system +2. Bring down the containers, copy the /mnt/vdb data to another location, change the mount point to point to the new location. Bring the containers back up. This is simpler as the database connection isn't changing but you need to manually reconfigure the storage to change the system under test. + +### Deployment +```bash +cd storage/vdb_benchmark +docker compose up -d # with docker-compose-v2. v1 uses docker-compose up +``` + +```-d``` option is required to detach from the containers after starting them. Without this option you will be attached to the log output of the set of containers and ```ctrl+c``` will stop the containers. + +*If you have connection problems with a proxy I recommend this link: https://medium.com/@SrvZ/docker-proxy-and-my-struggles-a4fd6de21861* + +## Running the Benchmark +The benchmark process consists of three main steps: +1. Loading vectors into the database +2. Monitoring and compacting the database +3. Running the benchmark queries + +### Step 1: Load Vectors into the Database +Use the load_vdb.py script to generate and load 10 million vectors into your vector database: (this process can take up to 8 hours) +```bash +python vdbbench/load_vdb.py --config vdbbench/configs/10m_diskann.yaml +``` + + +For testing, I recommend using a smaller data by passing the num_vectors option: +```bash +python vdbbench/load_vdb.py --config vdbbench/configs/10m_diskann.yaml --collection-name mlps_500k_10shards_1536dim_uniform_diskann --num-vectors 500000 +``` + +Key parameters: +* --collection-name: Name of the collection to create +* --dimension: Vector dimension +* --num-vectors: Number of vectors to generate +* --chunk-size: Number of vectors to generate in each chunk (for memory management) +* --distribution: Distribution for vector generation (uniform, normal) +* --batch-size: Batch size for insertion + +Example configuration file (vdbbench/configs/10m_diskann.yaml): +```yaml +database: + host: 127.0.0.1 + port: 19530 + database: milvus + max_receive_message_length: 514_983_574 + max_send_message_length: 514_983_574 + +dataset: + collection_name: mlps_10m_10shards_1536dim_uniform_diskann + num_vectors: 10_000_000 + dimension: 1536 + distribution: uniform + batch_size: 1000 + num_shards: 10 + vector_dtype: FLOAT_VECTOR + +index: + index_type: DISKANN + metric_type: COSINE + #index_params + max_degree: 64 + search_list_size: 200 + +workflow: + compact: True +``` + +### Step 2: Monitor and Compact the Database +The compact_and_watch.py script monitors the database and performs compaction. You should only need this if the load process exits out while waiting. The load script will do compaction and will wait for it to complete. +```bash +python vdbbench/compact_and_watch.py --config vdbbench/configs/10m_diskann.yaml --interval 5 +``` +This step is automatically performed at the end of the loading process if you set compact: true in your configuration. + +### Step 3: Run the Benchmark +Finally, run the benchmark using the simple_bench.py script: +```bash +python vdbbench/simple_bench.py --host 127.0.0.1 --collection --processes --batch-size --runtime +``` + +For comparison with HNSW indexing, use ```vdbbench/configs/10m_hnsw.yaml``` and update collection_name accordingly. + +## Supported Databases +Milvus with DiskANN & HNSW indexing (currently implemented) + +# Contributing +Contributions are welcome! Please feel free to submit a Pull Request. diff --git a/vdb_benchmark/docker-compose.yml b/vdb_benchmark/docker-compose.yml new file mode 100644 index 0000000..4c69af2 --- /dev/null +++ b/vdb_benchmark/docker-compose.yml @@ -0,0 +1,68 @@ +version: '3.5' + +services: + etcd: + container_name: milvus-etcd + image: quay.io/coreos/etcd:v3.5.18 + environment: + - ETCD_AUTO_COMPACTION_MODE=revision + - ETCD_AUTO_COMPACTION_RETENTION=1000 + - ETCD_QUOTA_BACKEND_BYTES=4294967296 + - ETCD_SNAPSHOT_COUNT=50000 + volumes: + - /mnt/vdb/etcd:/etcd + command: etcd -advertise-client-urls=http://etcd:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd + ports: + - "2379:2379" + healthcheck: + test: ["CMD", "etcdctl", "endpoint", "health"] + interval: 30s + timeout: 20s + retries: 3 + + minio: + container_name: milvus-minio + image: minio/minio:RELEASE.2023-03-20T20-16-18Z + environment: + MINIO_ACCESS_KEY: minioadmin + MINIO_SECRET_KEY: minioadmin + ports: + - "9001:9001" + - "9000:9000" + volumes: + - /mnt/vdb/minio:/minio_data + command: minio server /minio_data --console-address ":9001" + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] + interval: 30s + timeout: 20s + retries: 3 + + standalone: + container_name: milvus-standalone + image: milvusdb/milvus:v2.5.10 + command: ["milvus", "run", "standalone"] + security_opt: + - seccomp:unconfined + environment: + MINIO_REGION: us-east-1 + ETCD_ENDPOINTS: etcd:2379 + MINIO_ADDRESS: minio:9000 + volumes: + - /mnt/vdb/milvus:/var/lib/milvus + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"] + interval: 30s + start_period: 90s + timeout: 20s + retries: 3 + ports: + - "19530:19530" + - "9091:9091" + depends_on: + - "etcd" + - "minio" + +networks: + default: + name: milvus diff --git a/vdb_benchmark/list_collections.py b/vdb_benchmark/list_collections.py new file mode 100644 index 0000000..a83b2f8 --- /dev/null +++ b/vdb_benchmark/list_collections.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +""" +Milvus Collection Lister + +This script connects to a local Milvus database and lists all collections +along with the number of vectors in each collection. +""" + +import argparse +import sys +from typing import Dict, List, Tuple + +try: + from pymilvus import connections, utility + from pymilvus.exceptions import MilvusException +except ImportError: + print("Error: pymilvus package not found. Please install it with 'pip install pymilvus'") + sys.exit(1) + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments""" + parser = argparse.ArgumentParser(description="List Milvus collections and their vector counts") + parser.add_argument("--host", type=str, default="127.0.0.1", + help="Milvus server host (default: 127.0.0.1)") + parser.add_argument("--port", type=str, default="19530", + help="Milvus server port (default: 19530)") + parser.add_argument("--verbose", "-v", action="store_true", + help="Show detailed collection information") + return parser.parse_args() + + +def connect_to_milvus(host: str, port: str) -> bool: + """Establish connection to Milvus server""" + try: + connections.connect( + alias="default", + host=host, + port=port, + max_receive_message_length=514983574, + max_send_message_length=514983574 + ) + return True + except Exception as e: + print(f"Failed to connect to Milvus: {e}") + return False + + +def get_collections_info() -> List[Dict]: + """Get information about all collections""" + try: + collection_names = utility.list_collections() + collections_info = [] + + for name in collection_names: + from pymilvus import Collection + collection = Collection(name) + + # Get collection statistics - using num_entities instead of get_stats() + row_count = collection.num_entities + + # Get collection schema + schema = collection.schema + description = schema.description if schema.description else "No description" + + # Get vector field dimension + vector_field = None + vector_dim = None + for field in schema.fields: + if field.dtype == 100: # DataType.FLOAT_VECTOR + vector_field = field.name + vector_dim = field.params.get("dim") + break + + # Get index information + index_info = [] + try: + for field_name in collection.schema.fields: + if collection.has_index(field_name.name): + index = collection.index(field_name.name) + index_info.append({ + "field": field_name.name, + "index_type": index.params.get("index_type"), + "metric_type": index.params.get("metric_type"), + "params": index.params.get("params", {}) + }) + except Exception as e: + index_info = [{"error": str(e)}] + + collections_info.append({ + "name": name, + "row_count": row_count, + "description": description, + "vector_field": vector_field, + "vector_dim": vector_dim, + "index_info": index_info + }) + + return collections_info + except MilvusException as e: + print(f"Error retrieving collection information: {e}") + return [] + + +def main() -> int: + """Main function""" + args = parse_args() + + # Connect to Milvus + if not connect_to_milvus(args.host, args.port): + return 1 + + print(f"Connected to Milvus server at {args.host}:{args.port}") + + # Get collections information + collections_info = get_collections_info() + + if not collections_info: + print("No collections found.") + return 0 + + # Display collections information + print(f"\nFound {len(collections_info)} collections:") + print("-" * 80) + + for info in collections_info: + print(f"Collection: {info['name']}") + print(f" Vectors: {info['row_count']:,}") + print(f" Vector Field: {info['vector_field']} (dim: {info['vector_dim']})") + + if args.verbose: + print(f" Description: {info['description']}") + + if info['index_info']: + print(" Indexes:") + for idx in info['index_info']: + if "error" in idx: + print(f" Error retrieving index info: {idx['error']}") + else: + print(f" Field: {idx['field']}") + print(f" Type: {idx['index_type']}") + print(f" Metric: {idx['metric_type']}") + print(f" Params: {idx['params']}") + else: + print(" Indexes: None") + + print("-" * 80) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file diff --git a/vdb_benchmark/pyproject.toml b/vdb_benchmark/pyproject.toml new file mode 100644 index 0000000..f4d56d8 --- /dev/null +++ b/vdb_benchmark/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "vdbbench" +version = "0.1.0" +description = "Vector Database Benchmarking Tool" +readme = "README.md" +authors = [ + {name = "Vector DB Storage WG TF"} +] +license = {text = "MIT"} +requires-python = ">=3.8" +dependencies = [ + "numpy", + "pandas", + "pymilvus", + "pyyaml", + "tabulate" +] + +[project.urls] +"Homepage" = "https://github.com/mlcommons/storage/tree/TF_VDBBench/vdb_benchmark" +"Bug Tracker" = "https://github.com/mlcommons/storage/issues" + +[project.scripts] +compact-and-watch = "vdbbench.compact_and_watch:main" +load-vdb = "vdbbench.load_vdb:main" +vdbbench = "vdbbench.simple_bench:main" + +[tool.setuptools] +packages = {find = {}} + +[tool.setuptools.package-data] +vdbbench = ["*.py"] diff --git a/vdb_benchmark/tests/Makefile b/vdb_benchmark/tests/Makefile new file mode 100755 index 0000000..742886c --- /dev/null +++ b/vdb_benchmark/tests/Makefile @@ -0,0 +1,165 @@ +# Makefile for VDB-Bench Test Suite + +.PHONY: help install test test-all test-config test-connection test-loading \ + test-benchmark test-index test-monitoring test-performance \ + test-integration coverage coverage-html clean lint format \ + test-verbose test-failed test-parallel + +# Default target +help: + @echo "VDB-Bench Test Suite Makefile" + @echo "==============================" + @echo "" + @echo "Available targets:" + @echo " make install - Install test dependencies" + @echo " make test - Run all tests" + @echo " make test-verbose - Run tests with verbose output" + @echo " make test-parallel - Run tests in parallel" + @echo " make test-failed - Re-run only failed tests" + @echo "" + @echo "Test categories:" + @echo " make test-config - Run configuration tests" + @echo " make test-connection - Run connection tests" + @echo " make test-loading - Run loading tests" + @echo " make test-benchmark - Run benchmark tests" + @echo " make test-index - Run index management tests" + @echo " make test-monitoring - Run monitoring tests" + @echo "" + @echo "Special test suites:" + @echo " make test-performance - Run performance tests" + @echo " make test-integration - Run integration tests" + @echo "" + @echo "Coverage and reports:" + @echo " make coverage - Run tests with coverage" + @echo " make coverage-html - Generate HTML coverage report" + @echo "" + @echo "Code quality:" + @echo " make lint - Run code linting" + @echo " make format - Format code with black" + @echo "" + @echo "Maintenance:" + @echo " make clean - Clean test artifacts" + +# Installation +install: + pip install -r tests/requirements-test.txt + pip install -e . + +# Basic test execution +test: + python tests/run_tests.py + +test-all: test + +test-verbose: + python tests/run_tests.py --verbose + +test-parallel: + pytest tests/ -n auto --dist loadscope + +test-failed: + pytest tests/ --lf + +# Test categories +test-config: + python tests/run_tests.py --category config + +test-connection: + python tests/run_tests.py --category connection + +test-loading: + python tests/run_tests.py --category loading + +test-benchmark: + python tests/run_tests.py --category benchmark + +test-index: + python tests/run_tests.py --category index + +test-monitoring: + python tests/run_tests.py --category monitoring + +# Special test suites +test-performance: + python tests/run_tests.py --performance + +test-integration: + python tests/run_tests.py --integration + +# Coverage +coverage: + pytest tests/ --cov=vdbbench --cov-report=term --cov-report=html + +coverage-html: coverage + @echo "Opening coverage report in browser..." + @python -m webbrowser tests/htmlcov/index.html + +# Code quality +lint: + @echo "Running flake8..." + flake8 tests/ --max-line-length=100 --ignore=E203,W503 + @echo "Running pylint..." + pylint tests/ --max-line-length=100 --disable=C0111,R0903,R0913 + @echo "Running mypy..." + mypy tests/ --ignore-missing-imports + +format: + black tests/ --line-length=100 + isort tests/ --profile black --line-length=100 + +# Clean up +clean: + @echo "Cleaning test artifacts..." + rm -rf tests/__pycache__ + rm -rf tests/utils/__pycache__ + rm -rf tests/.pytest_cache + rm -rf tests/htmlcov + rm -rf tests/coverage_html + rm -f tests/.coverage + rm -f tests/test_results.xml + rm -f tests/test_results.json + rm -f tests/test_report.html + rm -f tests/*.pyc + rm -rf tests/**/*.pyc + find tests/ -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true + @echo "Clean complete!" + +# Watch mode (requires pytest-watch) +watch: + ptw tests/ -- --verbose + +# Run specific test file +test-file: + @read -p "Enter test file name (without .py): " file; \ + pytest tests/$$file.py -v + +# Run tests matching pattern +test-match: + @read -p "Enter test pattern: " pattern; \ + pytest tests/ -k "$$pattern" -v + +# Generate test report +report: + pytest tests/ --html=tests/test_report.html --self-contained-html + @echo "Test report generated at tests/test_report.html" + +# Check test coverage for specific module +coverage-module: + @read -p "Enter module name: " module; \ + pytest tests/ --cov=vdbbench.$$module --cov-report=term + +# Quick test (fast subset of tests) +test-quick: + pytest tests/ -m "not slow" --maxfail=1 -x + +# Full test suite with all checks +test-full: clean lint test-parallel coverage report + @echo "Full test suite complete!" + +# Continuous Integration target +ci: install lint test-parallel coverage + @echo "CI test suite complete!" + +# Development target (format, lint, and test) +dev: format lint test-verbose + @echo "Development test cycle complete!" diff --git a/vdb_benchmark/tests/README.md b/vdb_benchmark/tests/README.md new file mode 100755 index 0000000..f40c101 --- /dev/null +++ b/vdb_benchmark/tests/README.md @@ -0,0 +1,404 @@ +# VDB-Bench Test Suite + +Comprehensive unit test suite for the vdb-bench vector database benchmarking tool. + +## Overview + +This test suite provides extensive coverage for all components of vdb-bench, including: + +- Configuration management +- Database connections +- Vector generation and loading +- Index management +- Benchmarking operations +- Compaction and monitoring +- Performance metrics + +## Directory Structure + +``` +tests/ +├── __init__.py # Test suite package initialization +├── conftest.py # Pytest configuration and shared fixtures +├── run_tests.py # Main test runner script +├── requirements-test.txt # Testing dependencies +│ +├── test_config.py # Configuration management tests +├── test_database_connection.py # Database connection tests +├── test_load_vdb.py # Vector loading tests +├── test_vector_generation.py # Vector generation tests +├── test_index_management.py # Index management tests +├── test_simple_bench.py # Benchmarking functionality tests +├── test_compact_and_watch.py # Compaction and monitoring tests +│ +├── utils/ # Test utilities +│ ├── __init__.py +│ ├── test_helpers.py # Helper functions and utilities +│ └── mock_data.py # Mock data generators +│ +└── fixtures/ # Test fixtures + └── test_config.yaml # Sample configuration file +``` + +## Installation + +1. Install test dependencies: + +```bash +pip install -r tests/requirements-test.txt +``` + +2. Install vdb-bench in development mode: + +```bash +pip install -e . +``` + +## Running Tests + +### Run All Tests + +```bash +# Using pytest directly +pytest tests/ + +# Using the test runner +python tests/run_tests.py + +# With coverage +python tests/run_tests.py --verbose +``` + +### Run Specific Test Categories + +```bash +# Configuration tests +python tests/run_tests.py --category config + +# Connection tests +python tests/run_tests.py --category connection + +# Loading tests +python tests/run_tests.py --category loading + +# Benchmark tests +python tests/run_tests.py --category benchmark + +# Index management tests +python tests/run_tests.py --category index + +# Monitoring tests +python tests/run_tests.py --category monitoring +``` + +### Run Specific Test Modules + +```bash +# Run specific test files +python tests/run_tests.py --modules test_config test_load_vdb + +# Or using pytest +pytest tests/test_config.py tests/test_load_vdb.py +``` + +### Run Performance Tests + +```bash +# Run only performance-related tests +python tests/run_tests.py --performance + +# Or using pytest markers +pytest tests/ -k "performance or benchmark" +``` + +### Run with Verbose Output + +```bash +python tests/run_tests.py --verbose + +# Or with pytest +pytest tests/ -v +``` + +## Test Coverage + +### Generate Coverage Report + +```bash +# Run tests with coverage +pytest tests/ --cov=vdbbench --cov-report=html + +# Or using the test runner +python tests/run_tests.py # Coverage is enabled by default +``` + +### View Coverage Report + +After running tests with coverage, open the HTML report: + +```bash +# Open coverage report in browser +open tests/coverage_html/index.html +``` + +## Test Configuration + +### Environment Variables + +Set these environment variables to configure test behavior: + +```bash +# Database connection +export VDB_BENCH_TEST_HOST=localhost +export VDB_BENCH_TEST_PORT=19530 + +# Test data size +export VDB_BENCH_TEST_VECTORS=1000 +export VDB_BENCH_TEST_DIMENSION=128 + +# Performance test settings +export VDB_BENCH_TEST_TIMEOUT=60 +``` + +### Custom Test Configuration + +Create a custom test configuration file: + +```yaml +# tests/custom_config.yaml +test_settings: + use_mock_database: true + vector_count: 5000 + dimension: 256 + test_timeout: 30 +``` + +## Writing New Tests + +### Test Structure + +Follow this template for new test files: + +```python +""" +Unit tests for [component name] +""" +import pytest +from unittest.mock import Mock, patch +import numpy as np + +class TestComponentName: + """Test [component] functionality.""" + + def test_basic_operation(self): + """Test basic [operation].""" + # Test implementation + assert result == expected + + @pytest.mark.parametrize("input,expected", [ + (1, 2), + (2, 4), + (3, 6), + ]) + def test_parametrized(self, input, expected): + """Test with multiple inputs.""" + result = function_under_test(input) + assert result == expected + + @pytest.mark.skipif(condition, reason="Reason for skipping") + def test_conditional(self): + """Test that runs conditionally.""" + pass +``` + +### Using Fixtures + +Common fixtures are available in `conftest.py`: + +```python +def test_with_fixtures(mock_collection, sample_vectors, temp_config_file): + """Test using provided fixtures.""" + # mock_collection: Mock Milvus collection + # sample_vectors: Pre-generated test vectors + # temp_config_file: Temporary config file path + + result = process_vectors(mock_collection, sample_vectors) + assert result is not None +``` + +### Adding Mock Data + +Use mock data generators from `utils/mock_data.py`: + +```python +from tests.utils.mock_data import MockDataGenerator + +def test_with_mock_data(): + """Test using mock data generators.""" + generator = MockDataGenerator(seed=42) + + # Generate SIFT-like vectors + vectors = generator.generate_sift_like_vectors(1000, 128) + + # Generate deep learning embeddings + embeddings = generator.generate_deep_learning_embeddings( + 500, 768, model_type="bert" + ) +``` + +## Test Reports + +### HTML Report + +Tests automatically generate an HTML report: + +```bash +# View test report +open tests/test_report.html +``` + +### JUnit XML Report + +JUnit XML format for CI/CD integration: + +```bash +# Located at +tests/test_results.xml +``` + +### JSON Results + +Detailed test results in JSON format: + +```bash +# Located at +tests/test_results.json +``` + +## Continuous Integration + +### GitHub Actions Example + +```yaml +name: Tests + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + + - name: Install dependencies + run: | + pip install -r tests/requirements-test.txt + pip install -e . + + - name: Run tests + run: python tests/run_tests.py --verbose + + - name: Upload coverage + uses: codecov/codecov-action@v2 +``` + +## Debugging Tests + +### Run Tests in Debug Mode + +```bash +# Run with pytest debugging +pytest tests/ --pdb + +# Run specific test with debugging +pytest tests/test_config.py::TestConfigurationLoader::test_load_valid_config --pdb +``` + +### Increase Verbosity + +```bash +# Maximum verbosity +pytest tests/ -vvv + +# Show print statements +pytest tests/ -s +``` + +### Run Failed Tests Only + +```bash +# Re-run only failed tests from last run +pytest tests/ --lf + +# Run failed tests first, then others +pytest tests/ --ff +``` + +## Performance Testing + +### Run Benchmark Tests + +```bash +# Run with benchmark plugin +pytest tests/ --benchmark-only + +# Save benchmark results +pytest tests/ --benchmark-save=results + +# Compare benchmark results +pytest tests/ --benchmark-compare=results +``` + +### Memory Profiling + +```bash +# Profile memory usage +python -m memory_profiler tests/test_load_vdb.py +``` + +## Best Practices + +1. **Isolation**: Each test should be independent +2. **Mocking**: Mock external dependencies (database, file I/O) +3. **Fixtures**: Use fixtures for common setup +4. **Parametrization**: Test multiple inputs with parametrize +5. **Assertions**: Use clear, specific assertions +6. **Documentation**: Document complex test logic +7. **Performance**: Keep tests fast (< 1 second each) +8. **Coverage**: Aim for >80% code coverage + +## Troubleshooting + +### Common Issues + +1. **Import Errors**: Ensure vdb-bench is installed in development mode +2. **Mock Failures**: Check that pymilvus mocks are properly configured +3. **Timeout Issues**: Increase timeout for slow tests +4. **Resource Issues**: Some tests may require more memory/CPU + +### Getting Help + +For issues or questions: +1. Check test logs in `tests/test_results.json` +2. Review HTML report at `tests/test_report.html` +3. Enable verbose mode for detailed output +4. Check fixture definitions in `conftest.py` + +## Contributing + +When contributing new features, please: +1. Add corresponding unit tests +2. Ensure all tests pass +3. Maintain or improve code coverage +4. Follow the existing test structure +5. Update this README if needed + +## License + +Same as vdb-bench main project. diff --git a/vdb_benchmark/tests/fixtures/test_config.yaml b/vdb_benchmark/tests/fixtures/test_config.yaml new file mode 100755 index 0000000..360f34f --- /dev/null +++ b/vdb_benchmark/tests/fixtures/test_config.yaml @@ -0,0 +1,54 @@ +# Test configuration for vdb-bench unit tests +database: + host: 127.0.0.1 + port: 19530 + database: test_milvus + timeout: 30 + max_receive_message_length: 514983574 + max_send_message_length: 514983574 + +dataset: + collection_name: test_collection_sample + num_vectors: 10000 + dimension: 128 + distribution: uniform + batch_size: 500 + chunk_size: 1000 + num_shards: 2 + vector_dtype: FLOAT_VECTOR + +index: + index_type: HNSW + metric_type: L2 + params: + M: 16 + efConstruction: 200 + ef: 64 + +benchmark: + num_queries: 1000 + top_k: 10 + batch_size: 100 + num_processes: 4 + runtime: 60 + warmup_queries: 100 + +monitoring: + enabled: true + interval: 5 + metrics: + - qps + - latency + - recall + - memory_usage + +workflow: + compact: true + compact_threshold: 0.2 + flush_interval: 10000 + auto_index: true + +logging: + level: INFO + file: test_benchmark.log + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" diff --git a/vdb_benchmark/tests/requirements.txt b/vdb_benchmark/tests/requirements.txt new file mode 100755 index 0000000..32f8b91 --- /dev/null +++ b/vdb_benchmark/tests/requirements.txt @@ -0,0 +1,66 @@ +# Testing Dependencies for vdb-bench + +# Core testing frameworks +pytest>=7.4.0 +pytest-cov>=4.1.0 +pytest-html>=3.2.0 +pytest-xdist>=3.3.1 # For parallel test execution +pytest-timeout>=2.1.0 +pytest-mock>=3.11.1 + +# Coverage tools +coverage>=7.2.7 +coverage-badge>=1.1.0 + +# Mocking and fixtures +mock>=5.1.0 +faker>=19.2.0 +factory-boy>=3.3.0 + +# Data generation and manipulation +numpy>=1.24.3 +pandas>=2.0.3 +scipy>=1.11.1 + +# File handling +pyyaml>=6.0 +h5py>=3.9.0 + +# System monitoring (for testing monitoring features) +psutil>=5.9.5 + +# HTTP mocking (if needed for API tests) +responses>=0.23.1 +requests-mock>=1.11.0 + +# Async testing support +pytest-asyncio>=0.21.1 +aiofiles>=23.1.0 + +# Performance testing +pytest-benchmark>=4.0.0 +memory-profiler>=0.61.0 + +# Code quality +black>=23.7.0 +flake8>=6.0.0 +mypy>=1.4.1 +pylint>=2.17.4 + +# Documentation +sphinx>=7.0.1 +sphinx-rtd-theme>=1.2.2 + +# Milvus client (for integration tests) +pymilvus>=2.3.0 + +# Additional utilities +python-dotenv>=1.0.0 +click>=8.1.6 +colorama>=0.4.6 +tabulate>=0.9.0 +tqdm>=4.65.0 + +# Optional: for generating test reports +junitparser>=3.1.0 +allure-pytest>=2.13.2 diff --git a/vdb_benchmark/tests/tests/__init__.py b/vdb_benchmark/tests/tests/__init__.py new file mode 100755 index 0000000..241de82 --- /dev/null +++ b/vdb_benchmark/tests/tests/__init__.py @@ -0,0 +1,17 @@ +""" +VDB-Bench Test Suite + +Comprehensive unit tests for the vdb-bench vector database benchmarking tool. +""" + +__version__ = "1.0.0" + +# Test categories +TEST_CATEGORIES = [ + "configuration", + "database_connection", + "vector_loading", + "benchmarking", + "compaction", + "monitoring" +] diff --git a/vdb_benchmark/tests/tests/conftest.py b/vdb_benchmark/tests/tests/conftest.py new file mode 100755 index 0000000..48a0354 --- /dev/null +++ b/vdb_benchmark/tests/tests/conftest.py @@ -0,0 +1,180 @@ +""" +Pytest configuration and fixtures for vdb-bench tests +""" +import pytest +import yaml +import tempfile +import shutil +from pathlib import Path +from unittest.mock import Mock, MagicMock, patch +import numpy as np +from typing import Dict, Any, Generator +import os + +# Mock pymilvus if not installed +try: + from pymilvus import connections, Collection, utility +except ImportError: + connections = MagicMock() + Collection = MagicMock() + utility = MagicMock() + + +@pytest.fixture(scope="session") +def test_data_dir() -> Path: + """Create a temporary directory for test data that persists for the session.""" + temp_dir = Path(tempfile.mkdtemp(prefix="vdb_bench_test_")) + yield temp_dir + shutil.rmtree(temp_dir) + + +@pytest.fixture(scope="function") +def temp_config_file(test_data_dir) -> Generator[Path, None, None]: + """Create a temporary configuration file for testing.""" + config_path = test_data_dir / "test_config.yaml" + config_data = { + "database": { + "host": "127.0.0.1", + "port": 19530, + "database": "milvus_test", + "max_receive_message_length": 514983574, + "max_send_message_length": 514983574 + }, + "dataset": { + "collection_name": "test_collection", + "num_vectors": 1000, + "dimension": 128, + "distribution": "uniform", + "batch_size": 100, + "num_shards": 2, + "vector_dtype": "FLOAT_VECTOR" + }, + "index": { + "index_type": "DISKANN", + "metric_type": "COSINE", + "max_degree": 64, + "search_list_size": 200 + }, + "workflow": { + "compact": True + } + } + + with open(config_path, 'w') as f: + yaml.dump(config_data, f) + + yield config_path + + if config_path.exists(): + config_path.unlink() + + +@pytest.fixture +def mock_milvus_connection(): + """Mock Milvus connection for testing.""" + with patch('pymilvus.connections.connect') as mock_connect: + mock_connect.return_value = Mock() + yield mock_connect + + +@pytest.fixture +def mock_collection(): + """Mock Milvus collection for testing.""" + mock_coll = Mock(spec=Collection) + mock_coll.name = "test_collection" + mock_coll.schema = Mock() + mock_coll.num_entities = 1000 + mock_coll.insert = Mock(return_value=Mock(primary_keys=[1, 2, 3])) + mock_coll.create_index = Mock() + mock_coll.load = Mock() + mock_coll.release = Mock() + mock_coll.flush = Mock() + mock_coll.compact = Mock() + return mock_coll + + +@pytest.fixture +def sample_vectors() -> np.ndarray: + """Generate sample vectors for testing.""" + np.random.seed(42) + return np.random.randn(100, 128).astype(np.float32) + + +@pytest.fixture +def sample_config() -> Dict[str, Any]: + """Provide a sample configuration dictionary.""" + return { + "database": { + "host": "localhost", + "port": 19530, + "database": "default" + }, + "dataset": { + "collection_name": "test_vectors", + "num_vectors": 10000, + "dimension": 1536, + "distribution": "uniform", + "batch_size": 1000 + }, + "index": { + "index_type": "DISKANN", + "metric_type": "COSINE" + } + } + + +@pytest.fixture +def mock_time(): + """Mock time module for testing time-based operations.""" + with patch('time.time') as mock_time_func: + mock_time_func.side_effect = [0, 1, 2, 3, 4, 5] # Incremental time + yield mock_time_func + + +@pytest.fixture +def mock_multiprocessing(): + """Mock multiprocessing for testing parallel operations.""" + with patch('multiprocessing.Pool') as mock_pool: + mock_pool_instance = Mock() + mock_pool_instance.map = Mock(side_effect=lambda func, args: [func(arg) for arg in args]) + mock_pool_instance.close = Mock() + mock_pool_instance.join = Mock() + mock_pool.return_value.__enter__ = Mock(return_value=mock_pool_instance) + mock_pool.return_value.__exit__ = Mock(return_value=None) + yield mock_pool + + +@pytest.fixture +def benchmark_results(): + """Sample benchmark results for testing.""" + return { + "qps": 1250.5, + "latency_p50": 0.8, + "latency_p95": 1.2, + "latency_p99": 1.5, + "total_queries": 10000, + "runtime": 8.0, + "errors": 0 + } + + +@pytest.fixture(autouse=True) +def reset_milvus_connections(): + """Reset Milvus connections before each test.""" + connections.disconnect("default") + yield + connections.disconnect("default") + + +@pytest.fixture +def env_vars(): + """Set up environment variables for testing.""" + original_env = os.environ.copy() + + os.environ['VDB_BENCH_HOST'] = 'test_host' + os.environ['VDB_BENCH_PORT'] = '19530' + + yield os.environ + + os.environ.clear() + os.environ.update(original_env) diff --git a/vdb_benchmark/tests/tests/run_tests.py b/vdb_benchmark/tests/tests/run_tests.py new file mode 100755 index 0000000..a09766b --- /dev/null +++ b/vdb_benchmark/tests/tests/run_tests.py @@ -0,0 +1,346 @@ +#!/usr/bin/env python3 +""" +Comprehensive test runner for vdb-bench test suite +""" +import sys +import os +import argparse +import pytest +import coverage +from pathlib import Path +from typing import List, Optional +import json +import time +from datetime import datetime + + +class TestRunner: + """Main test runner for vdb-bench test suite.""" + + def __init__(self, test_dir: Path = None): + """Initialize test runner.""" + self.test_dir = test_dir or Path(__file__).parent + self.results = { + "start_time": None, + "end_time": None, + "duration": 0, + "total_tests": 0, + "passed": 0, + "failed": 0, + "skipped": 0, + "errors": 0, + "coverage": None + } + + def run_all_tests(self, verbose: bool = False, + coverage_enabled: bool = True) -> int: + """Run all tests with optional coverage.""" + print("=" * 60) + print("VDB-Bench Test Suite Runner") + print("=" * 60) + + self.results["start_time"] = datetime.now().isoformat() + start = time.time() + + # Setup coverage if enabled + cov = None + if coverage_enabled: + cov = coverage.Coverage() + cov.start() + print("Coverage tracking enabled") + + # Prepare pytest arguments + pytest_args = [ + str(self.test_dir), + "-v" if verbose else "-q", + "--tb=short", + "--color=yes", + f"--junitxml={self.test_dir}/test_results.xml", + f"--html={self.test_dir}/test_report.html", + "--self-contained-html" + ] + + # Run pytest + print(f"\nRunning tests from: {self.test_dir}") + print("-" * 60) + + exit_code = pytest.main(pytest_args) + + # Stop coverage and generate report + if cov: + cov.stop() + cov.save() + + # Generate coverage report + print("\n" + "=" * 60) + print("Coverage Report") + print("-" * 60) + + cov.report() + + # Save HTML coverage report + html_dir = self.test_dir / "coverage_html" + cov.html_report(directory=str(html_dir)) + print(f"\nHTML coverage report saved to: {html_dir}") + + # Get coverage percentage + self.results["coverage"] = cov.report(show_missing=False) + + # Update results + self.results["end_time"] = datetime.now().isoformat() + self.results["duration"] = time.time() - start + + # Parse test results + self._parse_test_results(exit_code) + + # Save results to JSON + self._save_results() + + # Print summary + self._print_summary() + + return exit_code + + def run_specific_tests(self, test_modules: List[str], + verbose: bool = False) -> int: + """Run specific test modules.""" + print("=" * 60) + print(f"Running specific tests: {', '.join(test_modules)}") + print("=" * 60) + + pytest_args = [] + for module in test_modules: + test_path = self.test_dir / f"{module}.py" + if test_path.exists(): + pytest_args.append(str(test_path)) + else: + print(f"Warning: Test module not found: {test_path}") + + if not pytest_args: + print("No valid test modules found!") + return 1 + + if verbose: + pytest_args.append("-v") + else: + pytest_args.append("-q") + + pytest_args.extend(["--tb=short", "--color=yes"]) + + return pytest.main(pytest_args) + + def run_by_category(self, category: str, verbose: bool = False) -> int: + """Run tests by category.""" + category_map = { + "config": ["test_config"], + "connection": ["test_database_connection"], + "loading": ["test_load_vdb", "test_vector_generation"], + "benchmark": ["test_simple_bench"], + "index": ["test_index_management"], + "monitoring": ["test_compact_and_watch"], + "all": None # Run all tests + } + + if category not in category_map: + print(f"Unknown category: {category}") + print(f"Available categories: {', '.join(category_map.keys())}") + return 1 + + if category == "all": + return self.run_all_tests(verbose=verbose) + + test_modules = category_map[category] + return self.run_specific_tests(test_modules, verbose=verbose) + + def run_performance_tests(self, verbose: bool = False) -> int: + """Run performance-related tests.""" + print("=" * 60) + print("Running Performance Tests") + print("=" * 60) + + pytest_args = [ + str(self.test_dir), + "-v" if verbose else "-q", + "-k", "performance or benchmark or throughput", + "--tb=short", + "--color=yes" + ] + + return pytest.main(pytest_args) + + def run_integration_tests(self, verbose: bool = False) -> int: + """Run integration tests.""" + print("=" * 60) + print("Running Integration Tests") + print("=" * 60) + + pytest_args = [ + str(self.test_dir), + "-v" if verbose else "-q", + "-m", "integration", + "--tb=short", + "--color=yes" + ] + + return pytest.main(pytest_args) + + def _parse_test_results(self, exit_code: int) -> None: + """Parse test results from pytest exit code.""" + # Basic result parsing based on exit code + if exit_code == 0: + self.results["status"] = "SUCCESS" + elif exit_code == 1: + self.results["status"] = "TESTS_FAILED" + elif exit_code == 2: + self.results["status"] = "INTERRUPTED" + elif exit_code == 3: + self.results["status"] = "INTERNAL_ERROR" + elif exit_code == 4: + self.results["status"] = "USAGE_ERROR" + elif exit_code == 5: + self.results["status"] = "NO_TESTS" + else: + self.results["status"] = "UNKNOWN_ERROR" + + # Try to parse XML results if available + xml_path = self.test_dir / "test_results.xml" + if xml_path.exists(): + try: + import xml.etree.ElementTree as ET + tree = ET.parse(xml_path) + root = tree.getroot() + + testsuite = root.find("testsuite") or root + self.results["total_tests"] = int(testsuite.get("tests", 0)) + self.results["failed"] = int(testsuite.get("failures", 0)) + self.results["errors"] = int(testsuite.get("errors", 0)) + self.results["skipped"] = int(testsuite.get("skipped", 0)) + self.results["passed"] = ( + self.results["total_tests"] - + self.results["failed"] - + self.results["errors"] - + self.results["skipped"] + ) + except Exception as e: + print(f"Warning: Could not parse XML results: {e}") + + def _save_results(self) -> None: + """Save test results to JSON file.""" + results_path = self.test_dir / "test_results.json" + + with open(results_path, 'w') as f: + json.dump(self.results, f, indent=2) + + print(f"\nTest results saved to: {results_path}") + + def _print_summary(self) -> None: + """Print test execution summary.""" + print("\n" + "=" * 60) + print("Test Execution Summary") + print("=" * 60) + + print(f"Status: {self.results.get('status', 'UNKNOWN')}") + print(f"Duration: {self.results['duration']:.2f} seconds") + print(f"Total Tests: {self.results['total_tests']}") + print(f"Passed: {self.results['passed']}") + print(f"Failed: {self.results['failed']}") + print(f"Errors: {self.results['errors']}") + print(f"Skipped: {self.results['skipped']}") + + if self.results.get("coverage"): + print(f"Code Coverage: {self.results['coverage']:.1f}%") + + print("=" * 60) + + # Print pass rate + if self.results['total_tests'] > 0: + pass_rate = (self.results['passed'] / self.results['total_tests']) * 100 + print(f"Pass Rate: {pass_rate:.1f}%") + + if pass_rate == 100: + print("✅ All tests passed!") + elif pass_rate >= 90: + print("⚠️ Most tests passed, but some failures detected.") + else: + print("❌ Significant test failures detected.") + + print("=" * 60) + + +def main(): + """Main entry point for test runner.""" + parser = argparse.ArgumentParser( + description="VDB-Bench Test Suite Runner", + formatter_class=argparse.RawDescriptionHelpFormatter + ) + + parser.add_argument( + "--category", "-c", + choices=["all", "config", "connection", "loading", + "benchmark", "index", "monitoring"], + default="all", + help="Test category to run" + ) + + parser.add_argument( + "--modules", "-m", + nargs="+", + help="Specific test modules to run" + ) + + parser.add_argument( + "--performance", "-p", + action="store_true", + help="Run performance tests only" + ) + + parser.add_argument( + "--integration", "-i", + action="store_true", + help="Run integration tests only" + ) + + parser.add_argument( + "--verbose", "-v", + action="store_true", + help="Verbose output" + ) + + parser.add_argument( + "--no-coverage", + action="store_true", + help="Disable coverage tracking" + ) + + parser.add_argument( + "--test-dir", + type=Path, + default=Path(__file__).parent, + help="Test directory path" + ) + + args = parser.parse_args() + + # Create test runner + runner = TestRunner(test_dir=args.test_dir) + + # Determine which tests to run + if args.modules: + exit_code = runner.run_specific_tests(args.modules, verbose=args.verbose) + elif args.performance: + exit_code = runner.run_performance_tests(verbose=args.verbose) + elif args.integration: + exit_code = runner.run_integration_tests(verbose=args.verbose) + elif args.category != "all": + exit_code = runner.run_by_category(args.category, verbose=args.verbose) + else: + exit_code = runner.run_all_tests( + verbose=args.verbose, + coverage_enabled=not args.no_coverage + ) + + sys.exit(exit_code) + + +if __name__ == "__main__": + main() diff --git a/vdb_benchmark/tests/tests/test_compact_and_watch.py b/vdb_benchmark/tests/tests/test_compact_and_watch.py new file mode 100755 index 0000000..fbc886f --- /dev/null +++ b/vdb_benchmark/tests/tests/test_compact_and_watch.py @@ -0,0 +1,701 @@ +""" +Unit tests for compaction and monitoring functionality in vdb-bench +""" +import pytest +import time +from unittest.mock import Mock, MagicMock, patch, call +import threading +from typing import Dict, Any, List +import json +from datetime import datetime, timedelta + + +class TestCompactionOperations: + """Test database compaction operations.""" + + def test_manual_compaction_trigger(self, mock_collection): + """Test manually triggering compaction.""" + mock_collection.compact.return_value = 1234 # Compaction ID + + def trigger_compaction(collection): + """Trigger manual compaction.""" + try: + compaction_id = collection.compact() + return { + "success": True, + "compaction_id": compaction_id, + "timestamp": time.time() + } + except Exception as e: + return { + "success": False, + "error": str(e) + } + + result = trigger_compaction(mock_collection) + + assert result["success"] is True + assert result["compaction_id"] == 1234 + assert "timestamp" in result + mock_collection.compact.assert_called_once() + + def test_compaction_state_monitoring(self, mock_collection): + """Test monitoring compaction state.""" + # Mock compaction state progression + states = ["Executing", "Executing", "Completed"] + state_iter = iter(states) + + def get_compaction_state(compaction_id): + try: + return next(state_iter) + except StopIteration: + return "Completed" + + mock_collection.get_compaction_state = Mock(side_effect=get_compaction_state) + + def monitor_compaction(collection, compaction_id, timeout=60): + """Monitor compaction until completion.""" + start_time = time.time() + states = [] + + while time.time() - start_time < timeout: + state = collection.get_compaction_state(compaction_id) + states.append({ + "state": state, + "timestamp": time.time() - start_time + }) + + if state == "Completed": + return { + "success": True, + "duration": time.time() - start_time, + "states": states + } + elif state == "Failed": + return { + "success": False, + "error": "Compaction failed", + "states": states + } + + time.sleep(0.1) # Check interval + + return { + "success": False, + "error": "Compaction timeout", + "states": states + } + + with patch('time.sleep'): # Speed up test + result = monitor_compaction(mock_collection, 1234) + + assert result["success"] is True + assert len(result["states"]) == 3 + assert result["states"][-1]["state"] == "Completed" + + def test_automatic_compaction_scheduling(self): + """Test automatic compaction scheduling based on conditions.""" + class CompactionScheduler: + def __init__(self, collection): + self.collection = collection + self.last_compaction = None + self.compaction_history = [] + + def should_compact(self, num_segments, deleted_ratio, time_since_last): + """Determine if compaction should be triggered.""" + # Compact if: + # - More than 10 segments + # - Deleted ratio > 20% + # - More than 1 hour since last compaction + + if num_segments > 10: + return True, "Too many segments" + + if deleted_ratio > 0.2: + return True, "High deletion ratio" + + if self.last_compaction and time_since_last > 3600: + return True, "Time-based compaction" + + return False, None + + def check_and_compact(self): + """Check conditions and trigger compaction if needed.""" + # Get collection stats (mocked here) + stats = { + "num_segments": 12, + "deleted_ratio": 0.15, + "last_compaction": self.last_compaction + } + + time_since_last = ( + time.time() - self.last_compaction + if self.last_compaction else float('inf') + ) + + should_compact, reason = self.should_compact( + stats["num_segments"], + stats["deleted_ratio"], + time_since_last + ) + + if should_compact: + compaction_id = self.collection.compact() + self.last_compaction = time.time() + self.compaction_history.append({ + "id": compaction_id, + "reason": reason, + "timestamp": self.last_compaction + }) + return True, reason + + return False, None + + mock_collection = Mock() + mock_collection.compact.return_value = 5678 + + scheduler = CompactionScheduler(mock_collection) + + # Should trigger compaction (too many segments) + compacted, reason = scheduler.check_and_compact() + + assert compacted is True + assert reason == "Too many segments" + assert len(scheduler.compaction_history) == 1 + mock_collection.compact.assert_called_once() + + def test_compaction_with_resource_monitoring(self): + """Test compaction with system resource monitoring.""" + import psutil + + class ResourceAwareCompaction: + def __init__(self, collection): + self.collection = collection + self.resource_thresholds = { + "cpu_percent": 80, + "memory_percent": 85, + "disk_io_rate": 100 # MB/s + } + + def check_resources(self): + """Check if system resources allow compaction.""" + cpu_percent = psutil.cpu_percent(interval=1) + memory_percent = psutil.virtual_memory().percent + + # Mock disk I/O rate + disk_io_rate = 50 # MB/s + + return { + "cpu_ok": cpu_percent < self.resource_thresholds["cpu_percent"], + "memory_ok": memory_percent < self.resource_thresholds["memory_percent"], + "disk_ok": disk_io_rate < self.resource_thresholds["disk_io_rate"], + "cpu_percent": cpu_percent, + "memory_percent": memory_percent, + "disk_io_rate": disk_io_rate + } + + def compact_with_resource_check(self): + """Perform compaction only if resources are available.""" + resource_status = self.check_resources() + + if all([resource_status["cpu_ok"], + resource_status["memory_ok"], + resource_status["disk_ok"]]): + + compaction_id = self.collection.compact() + return { + "success": True, + "compaction_id": compaction_id, + "resource_status": resource_status + } + else: + return { + "success": False, + "reason": "Resource constraints", + "resource_status": resource_status + } + + with patch('psutil.cpu_percent', return_value=50): + with patch('psutil.virtual_memory') as mock_memory: + mock_memory.return_value = Mock(percent=60) + + mock_collection = Mock() + mock_collection.compact.return_value = 9999 + + compactor = ResourceAwareCompaction(mock_collection) + result = compactor.compact_with_resource_check() + + assert result["success"] is True + assert result["compaction_id"] == 9999 + assert result["resource_status"]["cpu_ok"] is True + + +class TestMonitoring: + """Test monitoring functionality.""" + + def test_collection_stats_monitoring(self, mock_collection): + """Test monitoring collection statistics.""" + mock_collection.num_entities = 1000000 + + # Mock getting collection stats + def get_stats(): + return { + "num_entities": mock_collection.num_entities, + "num_segments": 10, + "index_building_progress": 95 + } + + mock_collection.get_stats = get_stats + + class StatsMonitor: + def __init__(self, collection): + self.collection = collection + self.stats_history = [] + + def collect_stats(self): + """Collect current statistics.""" + stats = self.collection.get_stats() + stats["timestamp"] = time.time() + self.stats_history.append(stats) + return stats + + def get_trends(self, window_size=10): + """Calculate trends from recent stats.""" + if len(self.stats_history) < 2: + return None + + recent = self.stats_history[-window_size:] + + # Calculate entity growth rate + if len(recent) >= 2: + time_diff = recent[-1]["timestamp"] - recent[0]["timestamp"] + entity_diff = recent[-1]["num_entities"] - recent[0]["num_entities"] + + growth_rate = entity_diff / time_diff if time_diff > 0 else 0 + + return { + "entity_growth_rate": growth_rate, + "avg_segments": sum(s["num_segments"] for s in recent) / len(recent), + "current_entities": recent[-1]["num_entities"] + } + + return None + + monitor = StatsMonitor(mock_collection) + + # Collect stats over time + for i in range(5): + mock_collection.num_entities += 10000 + stats = monitor.collect_stats() + time.sleep(0.01) # Small delay + + trends = monitor.get_trends() + + assert trends is not None + assert trends["current_entities"] == 1050000 # 1000000 + (5 * 10000) + assert len(monitor.stats_history) == 5 + + def test_periodic_monitoring(self): + """Test periodic monitoring with configurable intervals.""" + class PeriodicMonitor: + def __init__(self, collection, interval=5): + self.collection = collection + self.interval = interval + self.running = False + self.thread = None + self.data = [] + + def monitor_function(self): + """Function to run periodically.""" + stats = { + "timestamp": time.time(), + "num_entities": self.collection.num_entities, + "status": "healthy" + } + self.data.append(stats) + return stats + + def start(self): + """Start periodic monitoring.""" + self.running = True + + def run(): + while self.running: + self.monitor_function() + time.sleep(self.interval) + + self.thread = threading.Thread(target=run) + self.thread.daemon = True + self.thread.start() + + def stop(self): + """Stop periodic monitoring.""" + self.running = False + if self.thread: + self.thread.join(timeout=1) + + def get_latest(self, n=5): + """Get latest n monitoring results.""" + return self.data[-n:] if self.data else [] + + mock_collection = Mock() + mock_collection.num_entities = 1000000 + + monitor = PeriodicMonitor(mock_collection, interval=0.01) # Fast interval for testing + + monitor.start() + time.sleep(0.05) # Let it collect some data + monitor.stop() + + latest = monitor.get_latest() + + assert len(latest) > 0 + assert all("timestamp" in item for item in latest) + + def test_alert_system(self): + """Test alert system for monitoring thresholds.""" + class AlertSystem: + def __init__(self): + self.alerts = [] + self.thresholds = { + "high_latency": 100, # ms + "low_qps": 50, + "high_error_rate": 0.05, + "segment_count": 20 + } + self.alert_callbacks = [] + + def check_metric(self, metric_name, value): + """Check if metric exceeds threshold.""" + if metric_name == "latency" and value > self.thresholds["high_latency"]: + self.trigger_alert("HIGH_LATENCY", f"Latency {value}ms exceeds threshold") + + elif metric_name == "qps" and value < self.thresholds["low_qps"]: + self.trigger_alert("LOW_QPS", f"QPS {value} below threshold") + + elif metric_name == "error_rate" and value > self.thresholds["high_error_rate"]: + self.trigger_alert("HIGH_ERROR_RATE", f"Error rate {value:.2%} exceeds threshold") + + elif metric_name == "segments" and value > self.thresholds["segment_count"]: + self.trigger_alert("TOO_MANY_SEGMENTS", f"Segment count {value} exceeds threshold") + + def trigger_alert(self, alert_type, message): + """Trigger an alert.""" + alert = { + "type": alert_type, + "message": message, + "timestamp": time.time(), + "resolved": False + } + + self.alerts.append(alert) + + # Call registered callbacks + for callback in self.alert_callbacks: + callback(alert) + + return alert + + def resolve_alert(self, alert_type): + """Mark alerts of given type as resolved.""" + for alert in self.alerts: + if alert["type"] == alert_type and not alert["resolved"]: + alert["resolved"] = True + alert["resolved_time"] = time.time() + + def register_callback(self, callback): + """Register callback for alerts.""" + self.alert_callbacks.append(callback) + + def get_active_alerts(self): + """Get list of active (unresolved) alerts.""" + return [a for a in self.alerts if not a["resolved"]] + + alert_system = AlertSystem() + + # Register a callback + received_alerts = [] + alert_system.register_callback(lambda alert: received_alerts.append(alert)) + + # Test various metrics + alert_system.check_metric("latency", 150) # Should trigger + alert_system.check_metric("qps", 100) # Should not trigger + alert_system.check_metric("error_rate", 0.1) # Should trigger + alert_system.check_metric("segments", 25) # Should trigger + + active = alert_system.get_active_alerts() + + assert len(active) == 3 + assert len(received_alerts) == 3 + assert any(a["type"] == "HIGH_LATENCY" for a in active) + + # Resolve an alert + alert_system.resolve_alert("HIGH_LATENCY") + active = alert_system.get_active_alerts() + + assert len(active) == 2 + + def test_monitoring_data_aggregation(self): + """Test aggregating monitoring data over time windows.""" + class DataAggregator: + def __init__(self): + self.raw_data = [] + + def add_data_point(self, timestamp, metrics): + """Add a data point.""" + self.raw_data.append({ + "timestamp": timestamp, + **metrics + }) + + def aggregate_window(self, start_time, end_time, aggregation="avg"): + """Aggregate data within a time window.""" + window_data = [ + d for d in self.raw_data + if start_time <= d["timestamp"] <= end_time + ] + + if not window_data: + return None + + if aggregation == "avg": + return self._average_aggregation(window_data) + elif aggregation == "max": + return self._max_aggregation(window_data) + elif aggregation == "min": + return self._min_aggregation(window_data) + else: + return window_data + + def _average_aggregation(self, data): + """Calculate average of metrics.""" + result = {"count": len(data)} + + # Get all metric keys (excluding timestamp) + metric_keys = [k for k in data[0].keys() if k != "timestamp"] + + for key in metric_keys: + values = [d[key] for d in data if key in d] + result[f"{key}_avg"] = sum(values) / len(values) if values else 0 + + return result + + def _max_aggregation(self, data): + """Get maximum values of metrics.""" + result = {"count": len(data)} + + metric_keys = [k for k in data[0].keys() if k != "timestamp"] + + for key in metric_keys: + values = [d[key] for d in data if key in d] + result[f"{key}_max"] = max(values) if values else 0 + + return result + + def _min_aggregation(self, data): + """Get minimum values of metrics.""" + result = {"count": len(data)} + + metric_keys = [k for k in data[0].keys() if k != "timestamp"] + + for key in metric_keys: + values = [d[key] for d in data if key in d] + result[f"{key}_min"] = min(values) if values else 0 + + return result + + def create_time_series(self, metric_name, interval=60): + """Create time series data for a specific metric.""" + if not self.raw_data: + return [] + + min_time = min(d["timestamp"] for d in self.raw_data) + max_time = max(d["timestamp"] for d in self.raw_data) + + time_series = [] + current_time = min_time + + while current_time <= max_time: + window_end = current_time + interval + window_data = [ + d for d in self.raw_data + if current_time <= d["timestamp"] < window_end + and metric_name in d + ] + + if window_data: + avg_value = sum(d[metric_name] for d in window_data) / len(window_data) + time_series.append({ + "timestamp": current_time, + "value": avg_value + }) + + current_time = window_end + + return time_series + + aggregator = DataAggregator() + + # Add sample data points + base_time = time.time() + for i in range(100): + aggregator.add_data_point( + base_time + i, + { + "qps": 100 + i % 20, + "latency": 10 + i % 5, + "error_count": i % 3 + } + ) + + # Test aggregation + avg_metrics = aggregator.aggregate_window(base_time, base_time + 50, "avg") + assert avg_metrics is not None + assert "qps_avg" in avg_metrics + assert avg_metrics["count"] == 51 + + # Test time series creation + time_series = aggregator.create_time_series("qps", interval=10) + assert len(time_series) > 0 + assert all("timestamp" in point and "value" in point for point in time_series) + + +class TestWatchOperations: + """Test watch operations for monitoring database state.""" + + def test_index_building_watch(self, mock_collection): + """Test watching index building progress.""" + progress_values = [0, 25, 50, 75, 100] + progress_iter = iter(progress_values) + + def get_index_progress(): + try: + return next(progress_iter) + except StopIteration: + return 100 + + mock_collection.index.get_build_progress = Mock(side_effect=get_index_progress) + + class IndexWatcher: + def __init__(self, collection): + self.collection = collection + self.progress_history = [] + + def watch_build(self, check_interval=1): + """Watch index building until completion.""" + while True: + progress = self.collection.index.get_build_progress() + self.progress_history.append({ + "progress": progress, + "timestamp": time.time() + }) + + if progress >= 100: + return { + "completed": True, + "final_progress": progress, + "history": self.progress_history + } + + time.sleep(check_interval) + + mock_collection.index = Mock() + mock_collection.index.get_build_progress = Mock(side_effect=get_index_progress) + + watcher = IndexWatcher(mock_collection) + + with patch('time.sleep'): # Speed up test + result = watcher.watch_build() + + assert result["completed"] is True + assert result["final_progress"] == 100 + assert len(result["history"]) == 5 + + def test_segment_merge_watch(self): + """Test watching segment merge operations.""" + class SegmentMergeWatcher: + def __init__(self): + self.merge_operations = [] + self.active_merges = {} + + def start_merge(self, segments): + """Start watching a segment merge.""" + merge_id = f"merge_{len(self.merge_operations)}" + + merge_op = { + "id": merge_id, + "segments": segments, + "start_time": time.time(), + "status": "running", + "progress": 0 + } + + self.merge_operations.append(merge_op) + self.active_merges[merge_id] = merge_op + + return merge_id + + def update_progress(self, merge_id, progress): + """Update merge progress.""" + if merge_id in self.active_merges: + self.active_merges[merge_id]["progress"] = progress + + if progress >= 100: + self.complete_merge(merge_id) + + def complete_merge(self, merge_id): + """Mark merge as completed.""" + if merge_id in self.active_merges: + merge_op = self.active_merges[merge_id] + merge_op["status"] = "completed" + merge_op["end_time"] = time.time() + merge_op["duration"] = merge_op["end_time"] - merge_op["start_time"] + + del self.active_merges[merge_id] + + return merge_op + + return None + + def get_active_merges(self): + """Get list of active merge operations.""" + return list(self.active_merges.values()) + + def get_merge_stats(self): + """Get statistics about merge operations.""" + completed = [m for m in self.merge_operations if m["status"] == "completed"] + + if not completed: + return None + + durations = [m["duration"] for m in completed] + + return { + "total_merges": len(self.merge_operations), + "completed_merges": len(completed), + "active_merges": len(self.active_merges), + "avg_duration": sum(durations) / len(durations) if durations else 0, + "min_duration": min(durations) if durations else 0, + "max_duration": max(durations) if durations else 0 + } + + watcher = SegmentMergeWatcher() + + # Start multiple merges + merge1 = watcher.start_merge(["seg1", "seg2"]) + merge2 = watcher.start_merge(["seg3", "seg4"]) + + assert len(watcher.get_active_merges()) == 2 + + # Update progress + watcher.update_progress(merge1, 50) + watcher.update_progress(merge2, 100) # Complete this one + + assert len(watcher.get_active_merges()) == 1 + + # Complete remaining merge + watcher.update_progress(merge1, 100) + + stats = watcher.get_merge_stats() + assert stats["completed_merges"] == 2 + assert stats["active_merges"] == 0 diff --git a/vdb_benchmark/tests/tests/test_config.py b/vdb_benchmark/tests/tests/test_config.py new file mode 100755 index 0000000..725976a --- /dev/null +++ b/vdb_benchmark/tests/tests/test_config.py @@ -0,0 +1,359 @@ +""" +Unit tests for configuration management in vdb-bench +""" +import pytest +import yaml +from pathlib import Path +from typing import Dict, Any +import os +from unittest.mock import patch, mock_open, MagicMock + + +class TestConfigurationLoader: + """Test configuration loading and validation.""" + + def test_load_valid_config(self, temp_config_file): + """Test loading a valid configuration file.""" + # Mock the config loading function + with open(temp_config_file, 'r') as f: + config = yaml.safe_load(f) + + assert config is not None + assert 'database' in config + assert 'dataset' in config + assert 'index' in config + assert config['database']['host'] == '127.0.0.1' + assert config['dataset']['num_vectors'] == 1000 + + def test_load_missing_config_file(self): + """Test handling of missing configuration file.""" + non_existent_file = Path("/tmp/non_existent_config.yaml") + + with pytest.raises(FileNotFoundError): + with open(non_existent_file, 'r') as f: + yaml.safe_load(f) + + def test_load_invalid_yaml(self, test_data_dir): + """Test handling of invalid YAML syntax.""" + invalid_yaml_path = test_data_dir / "invalid.yaml" + + with open(invalid_yaml_path, 'w') as f: + f.write("invalid: yaml: content: [") + + with pytest.raises(yaml.YAMLError): + with open(invalid_yaml_path, 'r') as f: + yaml.safe_load(f) + + def test_config_validation_missing_required_fields(self): + """Test validation when required configuration fields are missing.""" + incomplete_config = { + "database": { + "host": "localhost" + # Missing port and other required fields + } + } + + # Mock validation function + def validate_config(config): + required_fields = ['port', 'database'] + for field in required_fields: + if field not in config.get('database', {}): + raise ValueError(f"Missing required field: database.{field}") + + with pytest.raises(ValueError, match="Missing required field"): + validate_config(incomplete_config) + + def test_config_validation_invalid_values(self): + """Test validation of configuration values.""" + invalid_config = { + "database": { + "host": "localhost", + "port": -1, # Invalid port + "database": "milvus" + }, + "dataset": { + "num_vectors": -100, # Invalid negative value + "dimension": 0, # Invalid dimension + "batch_size": 0 # Invalid batch size + } + } + + def validate_config_values(config): + if config['database']['port'] < 1 or config['database']['port'] > 65535: + raise ValueError("Invalid port number") + if config['dataset']['num_vectors'] <= 0: + raise ValueError("Number of vectors must be positive") + if config['dataset']['dimension'] <= 0: + raise ValueError("Vector dimension must be positive") + if config['dataset']['batch_size'] <= 0: + raise ValueError("Batch size must be positive") + + with pytest.raises(ValueError): + validate_config_values(invalid_config) + + def test_config_merge_with_defaults(self): + """Test merging user configuration with defaults.""" + default_config = { + "database": { + "host": "localhost", + "port": 19530, + "timeout": 30 + }, + "dataset": { + "batch_size": 1000, + "distribution": "uniform" + } + } + + user_config = { + "database": { + "host": "remote-host", + "port": 8080 + }, + "dataset": { + "batch_size": 500 + } + } + + def merge_configs(default, user): + """Deep merge user config into default config.""" + merged = default.copy() + for key, value in user.items(): + if key in merged and isinstance(merged[key], dict) and isinstance(value, dict): + merged[key] = merge_configs(merged[key], value) + else: + merged[key] = value + return merged + + merged = merge_configs(default_config, user_config) + + assert merged['database']['host'] == 'remote-host' + assert merged['database']['port'] == 8080 + assert merged['database']['timeout'] == 30 # From default + assert merged['dataset']['batch_size'] == 500 + assert merged['dataset']['distribution'] == 'uniform' # From default + + def test_config_environment_variable_override(self, sample_config): + """Test overriding configuration with environment variables.""" + import copy + + os.environ['VDB_BENCH_DATABASE_HOST'] = 'env-host' + os.environ['VDB_BENCH_DATABASE_PORT'] = '9999' + os.environ['VDB_BENCH_DATASET_NUM_VECTORS'] = '5000' + + def apply_env_overrides(config): + """Apply environment variable overrides to configuration.""" + # Make a deep copy to avoid modifying original + result = copy.deepcopy(config) + env_prefix = 'VDB_BENCH_' + + for key, value in os.environ.items(): + if key.startswith(env_prefix): + # Parse the environment variable name + parts = key[len(env_prefix):].lower().split('_') + + # Special handling for num_vectors (DATASET_NUM_VECTORS) + if len(parts) >= 2 and parts[0] == 'dataset' and parts[1] == 'num' and len(parts) == 3 and parts[2] == 'vectors': + if 'dataset' not in result: + result['dataset'] = {} + result['dataset']['num_vectors'] = int(value) + else: + # Navigate to the config section for other keys + current = result + for part in parts[:-1]: + if part not in current: + current[part] = {} + current = current[part] + + # Set the value (with type conversion) + final_key = parts[-1] + if value.isdigit(): + current[final_key] = int(value) + else: + current[final_key] = value + + return result + + config = apply_env_overrides(sample_config) + + assert config['database']['host'] == 'env-host' + assert config['database']['port'] == 9999 + assert config['dataset']['num_vectors'] == 5000 + + # Clean up environment variables + del os.environ['VDB_BENCH_DATABASE_HOST'] + del os.environ['VDB_BENCH_DATABASE_PORT'] + del os.environ['VDB_BENCH_DATASET_NUM_VECTORS'] + + def test_config_save(self, test_data_dir): + """Test saving configuration to file.""" + config = { + "database": {"host": "localhost", "port": 19530}, + "dataset": {"collection_name": "test", "dimension": 128} + } + + save_path = test_data_dir / "saved_config.yaml" + + with open(save_path, 'w') as f: + yaml.dump(config, f) + + # Verify saved file + with open(save_path, 'r') as f: + loaded_config = yaml.safe_load(f) + + assert loaded_config == config + + def test_config_schema_validation(self): + """Test configuration schema validation.""" + schema = { + "database": { + "type": "dict", + "required": ["host", "port"], + "properties": { + "host": {"type": "string"}, + "port": {"type": "integer", "min": 1, "max": 65535} + } + }, + "dataset": { + "type": "dict", + "required": ["dimension"], + "properties": { + "dimension": {"type": "integer", "min": 1} + } + } + } + + def validate_against_schema(config, schema): + """Basic schema validation.""" + for key, rules in schema.items(): + if rules.get("type") == "dict": + if key not in config: + if "required" in rules: + raise ValueError(f"Missing required section: {key}") + continue + + if "required" in rules: + for req_field in rules["required"]: + if req_field not in config[key]: + raise ValueError(f"Missing required field: {key}.{req_field}") + + if "properties" in rules: + for prop, prop_rules in rules["properties"].items(): + if prop in config[key]: + value = config[key][prop] + if "type" in prop_rules: + if prop_rules["type"] == "integer" and not isinstance(value, int): + raise TypeError(f"{key}.{prop} must be an integer") + if prop_rules["type"] == "string" and not isinstance(value, str): + raise TypeError(f"{key}.{prop} must be a string") + + if "min" in prop_rules and value < prop_rules["min"]: + raise ValueError(f"{key}.{prop} must be >= {prop_rules['min']}") + if "max" in prop_rules and value > prop_rules["max"]: + raise ValueError(f"{key}.{prop} must be <= {prop_rules['max']}") + + # Valid config + valid_config = { + "database": {"host": "localhost", "port": 19530}, + "dataset": {"dimension": 128} + } + + validate_against_schema(valid_config, schema) # Should not raise + + # Invalid config (missing required field) + invalid_config = { + "database": {"host": "localhost"}, # Missing port + "dataset": {"dimension": 128} + } + + with pytest.raises(ValueError, match="Missing required field"): + validate_against_schema(invalid_config, schema) + + +class TestIndexConfiguration: + """Test index-specific configuration handling.""" + + def test_diskann_config_validation(self): + """Test DiskANN index configuration validation.""" + valid_diskann_config = { + "index_type": "DISKANN", + "metric_type": "COSINE", + "max_degree": 64, + "search_list_size": 200, + "pq_code_budget_gb": 0.1, + "build_algo": "IVF_PQ" + } + + def validate_diskann_config(config): + assert config["index_type"] == "DISKANN" + assert config["metric_type"] in ["L2", "IP", "COSINE"] + assert 1 <= config["max_degree"] <= 128 + assert 100 <= config["search_list_size"] <= 1000 + if "pq_code_budget_gb" in config: + assert config["pq_code_budget_gb"] > 0 + + validate_diskann_config(valid_diskann_config) + + # Invalid max_degree + invalid_config = valid_diskann_config.copy() + invalid_config["max_degree"] = 200 + + with pytest.raises(AssertionError): + validate_diskann_config(invalid_config) + + def test_hnsw_config_validation(self): + """Test HNSW index configuration validation.""" + valid_hnsw_config = { + "index_type": "HNSW", + "metric_type": "L2", + "M": 16, + "efConstruction": 200 + } + + def validate_hnsw_config(config): + assert config["index_type"] == "HNSW" + assert config["metric_type"] in ["L2", "IP", "COSINE"] + assert 4 <= config["M"] <= 64 + assert 8 <= config["efConstruction"] <= 512 + + validate_hnsw_config(valid_hnsw_config) + + # Invalid M value + invalid_config = valid_hnsw_config.copy() + invalid_config["M"] = 100 + + with pytest.raises(AssertionError): + validate_hnsw_config(invalid_config) + + def test_auto_index_config_selection(self): + """Test automatic index configuration based on dataset size.""" + def select_index_config(num_vectors, dimension): + if num_vectors < 100000: + return { + "index_type": "IVF_FLAT", + "nlist": 128 + } + elif num_vectors < 1000000: + return { + "index_type": "HNSW", + "M": 16, + "efConstruction": 200 + } + else: + return { + "index_type": "DISKANN", + "max_degree": 64, + "search_list_size": 200 + } + + # Small dataset + config = select_index_config(50000, 128) + assert config["index_type"] == "IVF_FLAT" + + # Medium dataset + config = select_index_config(500000, 256) + assert config["index_type"] == "HNSW" + + # Large dataset + config = select_index_config(10000000, 1536) + assert config["index_type"] == "DISKANN" diff --git a/vdb_benchmark/tests/tests/test_database_connection.py b/vdb_benchmark/tests/tests/test_database_connection.py new file mode 100755 index 0000000..538c588 --- /dev/null +++ b/vdb_benchmark/tests/tests/test_database_connection.py @@ -0,0 +1,538 @@ +""" +Unit tests for Milvus database connection management +""" +import pytest +from unittest.mock import Mock, MagicMock, patch, call +import time +from typing import Dict, Any + + +class TestDatabaseConnection: + """Test database connection management.""" + + @patch('pymilvus.connections.connect') + def test_successful_connection(self, mock_connect): + """Test successful connection to Milvus.""" + mock_connect.return_value = True + + def connect_to_milvus(host="localhost", port=19530, **kwargs): + from pymilvus import connections + return connections.connect( + alias="default", + host=host, + port=port, + **kwargs + ) + + result = connect_to_milvus("localhost", 19530) + assert result is True + mock_connect.assert_called_once_with( + alias="default", + host="localhost", + port=19530 + ) + + @patch('pymilvus.connections.connect') + def test_connection_with_timeout(self, mock_connect): + """Test connection with custom timeout.""" + mock_connect.return_value = True + + def connect_with_timeout(host, port, timeout=30): + from pymilvus import connections + return connections.connect( + alias="default", + host=host, + port=port, + timeout=timeout + ) + + connect_with_timeout("localhost", 19530, timeout=60) + mock_connect.assert_called_with( + alias="default", + host="localhost", + port=19530, + timeout=60 + ) + + @patch('pymilvus.connections.connect') + def test_connection_failure(self, mock_connect): + """Test handling of connection failures.""" + mock_connect.side_effect = Exception("Connection refused") + + def connect_to_milvus(host, port): + from pymilvus import connections + try: + return connections.connect(alias="default", host=host, port=port) + except Exception as e: + return f"Failed to connect: {e}" + + result = connect_to_milvus("localhost", 19530) + assert "Failed to connect" in result + assert "Connection refused" in result + + @patch('pymilvus.connections.connect') + def test_connection_retry_logic(self, mock_connect): + """Test connection retry mechanism.""" + # Fail twice, then succeed + mock_connect.side_effect = [ + Exception("Connection failed"), + Exception("Connection failed"), + True + ] + + def connect_with_retry(host, port, max_retries=3, retry_delay=1): + from pymilvus import connections + + for attempt in range(max_retries): + try: + return connections.connect( + alias="default", + host=host, + port=port + ) + except Exception as e: + if attempt == max_retries - 1: + raise + time.sleep(retry_delay) + + return False + + with patch('time.sleep'): # Mock sleep to speed up test + result = connect_with_retry("localhost", 19530) + assert result is True + assert mock_connect.call_count == 3 + + @patch('pymilvus.connections.list_connections') + def test_list_connections(self, mock_list): + """Test listing active connections.""" + mock_list.return_value = [ + ("default", {"host": "localhost", "port": 19530}), + ("secondary", {"host": "remote", "port": 8080}) + ] + + def get_active_connections(): + from pymilvus import connections + return connections.list_connections() + + connections_list = get_active_connections() + assert len(connections_list) == 2 + assert connections_list[0][0] == "default" + assert connections_list[1][1]["host"] == "remote" + + @patch('pymilvus.connections.disconnect') + def test_disconnect(self, mock_disconnect): + """Test disconnecting from Milvus.""" + mock_disconnect.return_value = None + + def disconnect_from_milvus(alias="default"): + from pymilvus import connections + connections.disconnect(alias) + return True + + result = disconnect_from_milvus() + assert result is True + mock_disconnect.assert_called_once_with("default") + + @patch('pymilvus.connections.connect') + def test_connection_pool(self, mock_connect): + """Test connection pooling behavior.""" + mock_connect.return_value = True + + class ConnectionPool: + def __init__(self, max_connections=5): + self.max_connections = max_connections + self.connections = [] + self.available = [] + + def get_connection(self): + if self.available: + return self.available.pop() + elif len(self.connections) < self.max_connections: + from pymilvus import connections + conn = connections.connect( + alias=f"conn_{len(self.connections)}", + host="localhost", + port=19530 + ) + self.connections.append(conn) + return conn + else: + raise Exception("Connection pool exhausted") + + def return_connection(self, conn): + self.available.append(conn) + + def close_all(self): + for conn in self.connections: + # In real code, would disconnect each connection + pass + self.connections.clear() + self.available.clear() + + pool = ConnectionPool(max_connections=3) + + # Get connections + conn1 = pool.get_connection() + conn2 = pool.get_connection() + conn3 = pool.get_connection() + + # Pool should be exhausted + with pytest.raises(Exception, match="Connection pool exhausted"): + pool.get_connection() + + # Return a connection + pool.return_connection(conn1) + + # Should be able to get a connection now + conn4 = pool.get_connection() + assert conn4 == conn1 # Should reuse the returned connection + + @patch('pymilvus.connections.connect') + def test_connection_with_authentication(self, mock_connect): + """Test connection with authentication credentials.""" + mock_connect.return_value = True + + def connect_with_auth(host, port, user, password): + from pymilvus import connections + return connections.connect( + alias="default", + host=host, + port=port, + user=user, + password=password + ) + + connect_with_auth("localhost", 19530, "admin", "password123") + + mock_connect.assert_called_with( + alias="default", + host="localhost", + port=19530, + user="admin", + password="password123" + ) + + @patch('pymilvus.connections.connect') + def test_connection_health_check(self, mock_connect): + """Test connection health check mechanism.""" + mock_connect.return_value = True + + class MilvusConnection: + def __init__(self, host, port): + self.host = host + self.port = port + self.connected = False + self.last_health_check = 0 + + def connect(self): + from pymilvus import connections + try: + connections.connect( + alias="health_check", + host=self.host, + port=self.port + ) + self.connected = True + return True + except: + self.connected = False + return False + + def health_check(self): + """Perform a health check on the connection.""" + current_time = time.time() + + # Only check every 30 seconds + if current_time - self.last_health_check < 30: + return self.connected + + self.last_health_check = current_time + + # Try a simple operation to verify connection + try: + # In real code, would perform a lightweight operation + # like checking server status + return self.connected + except: + self.connected = False + return False + + def ensure_connected(self): + """Ensure connection is active, reconnect if needed.""" + if not self.health_check(): + return self.connect() + return True + + conn = MilvusConnection("localhost", 19530) + assert conn.connect() is True + assert conn.health_check() is True + assert conn.ensure_connected() is True + + +class TestCollectionManagement: + """Test Milvus collection management operations.""" + + @patch('pymilvus.Collection') + def test_create_collection(self, mock_collection_class): + """Test creating a new collection.""" + mock_collection = Mock() + mock_collection_class.return_value = mock_collection + + def create_collection(name, dimension, metric_type="L2"): + from pymilvus import Collection, FieldSchema, CollectionSchema, DataType + + # Define schema + fields = [ + FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True), + FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dimension) + ] + schema = CollectionSchema(fields, description=f"Collection {name}") + + # Create collection + collection = Collection(name=name, schema=schema) + return collection + + coll = create_collection("test_collection", 128) + assert coll is not None + mock_collection_class.assert_called_once() + + @patch('pymilvus.utility.has_collection') + def test_check_collection_exists(self, mock_has_collection): + """Test checking if a collection exists.""" + mock_has_collection.return_value = True + + def collection_exists(collection_name): + from pymilvus import utility + return utility.has_collection(collection_name) + + exists = collection_exists("test_collection") + assert exists is True + mock_has_collection.assert_called_once_with("test_collection") + + @patch('pymilvus.Collection') + def test_drop_collection(self, mock_collection_class): + """Test dropping a collection.""" + mock_collection = Mock() + mock_collection.drop = Mock() + mock_collection_class.return_value = mock_collection + + def drop_collection(collection_name): + from pymilvus import Collection + collection = Collection(collection_name) + collection.drop() + return True + + result = drop_collection("test_collection") + assert result is True + mock_collection.drop.assert_called_once() + + @patch('pymilvus.utility.list_collections') + def test_list_collections(self, mock_list_collections): + """Test listing all collections.""" + mock_list_collections.return_value = [ + "collection1", + "collection2", + "collection3" + ] + + def get_all_collections(): + from pymilvus import utility + return utility.list_collections() + + collections = get_all_collections() + assert len(collections) == 3 + assert "collection1" in collections + + def test_collection_with_partitions(self, mock_collection): + """Test creating and managing collection partitions.""" + mock_collection.create_partition = Mock() + mock_collection.has_partition = Mock(return_value=False) + mock_collection.partitions = [] + + def create_partitions(collection, partition_names): + for name in partition_names: + if not collection.has_partition(name): + collection.create_partition(name) + collection.partitions.append(name) + return collection.partitions + + partitions = create_partitions(mock_collection, ["partition1", "partition2"]) + assert len(partitions) == 2 + assert mock_collection.create_partition.call_count == 2 + + def test_collection_properties(self, mock_collection): + """Test getting collection properties.""" + mock_collection.num_entities = 10000 + mock_collection.description = "Test collection" + mock_collection.name = "test_coll" + mock_collection.schema = Mock() + + def get_collection_info(collection): + return { + "name": collection.name, + "description": collection.description, + "num_entities": collection.num_entities, + "schema": collection.schema + } + + info = get_collection_info(mock_collection) + assert info["name"] == "test_coll" + assert info["num_entities"] == 10000 + assert info["description"] == "Test collection" + + +class TestConnectionResilience: + """Test connection resilience and error recovery.""" + + @patch('pymilvus.connections.connect') + def test_automatic_reconnection(self, mock_connect): + """Test automatic reconnection after connection loss.""" + # Simulate connection loss and recovery + mock_connect.side_effect = [ + True, # Initial connection + Exception("Connection lost"), # Connection drops + Exception("Still disconnected"), # First retry fails + True # Reconnection succeeds + ] + + class ResilientConnection: + def __init__(self): + self.connected = False + self.retry_count = 0 + self.max_retries = 3 + self.connection_attempts = 0 + + def execute_with_retry(self, operation): + """Execute operation with automatic retry on connection failure.""" + for attempt in range(self.max_retries): + try: + if not self.connected or attempt > 0: + self._connect() + + result = operation() + self.retry_count = 0 # Reset retry count on success + return result + + except Exception as e: + self.retry_count += 1 + self.connected = False + + if self.retry_count >= self.max_retries: + raise Exception(f"Max retries exceeded: {e}") + + time.sleep(2 ** attempt) # Exponential backoff + + def _connect(self): + from pymilvus import connections + self.connection_attempts += 1 + if self.connection_attempts <= 2: + # First two connection attempts fail + self.connected = False + if self.connection_attempts == 1: + raise Exception("Connection lost") + else: + raise Exception("Still disconnected") + else: + # Third attempt succeeds + connections.connect(alias="resilient", host="localhost", port=19530) + self.connected = True + + conn = ResilientConnection() + + # Mock operation that will fail initially + operation_calls = 0 + def test_operation(): + nonlocal operation_calls + operation_calls += 1 + if operation_calls < 3 and not conn.connected: + raise Exception("Operation failed") + return "Success" + + with patch('time.sleep'): # Mock sleep for faster testing + result = conn.execute_with_retry(test_operation) + + # Operation should eventually succeed + assert result == "Success" + + @patch('pymilvus.connections.connect') + def test_connection_timeout_handling(self, mock_connect): + """Test handling of connection timeouts.""" + import socket + mock_connect.side_effect = socket.timeout("Connection timed out") + + def connect_with_timeout_handling(host, port, timeout=10): + from pymilvus import connections + + try: + return connections.connect( + alias="timeout_test", + host=host, + port=port, + timeout=timeout + ) + except socket.timeout as e: + return f"Connection timeout: {e}" + except Exception as e: + return f"Connection error: {e}" + + result = connect_with_timeout_handling("localhost", 19530, timeout=5) + assert "Connection timeout" in result + + def test_connection_state_management(self): + """Test managing connection state across operations.""" + class ConnectionManager: + def __init__(self): + self.connections = {} + self.active_alias = None + + def add_connection(self, alias, host, port): + """Add a connection configuration.""" + self.connections[alias] = { + "host": host, + "port": port, + "connected": False + } + + def switch_connection(self, alias): + """Switch to a different connection.""" + if alias not in self.connections: + raise ValueError(f"Unknown connection alias: {alias}") + + # Disconnect from current if connected + if self.active_alias and self.connections[self.active_alias]["connected"]: + self.connections[self.active_alias]["connected"] = False + + self.active_alias = alias + self.connections[alias]["connected"] = True + return True + + def get_active_connection(self): + """Get the currently active connection.""" + if not self.active_alias: + return None + return self.connections.get(self.active_alias) + + def close_all(self): + """Close all connections.""" + for alias in self.connections: + self.connections[alias]["connected"] = False + self.active_alias = None + + manager = ConnectionManager() + manager.add_connection("primary", "localhost", 19530) + manager.add_connection("secondary", "remote", 8080) + + # Switch to primary + assert manager.switch_connection("primary") is True + active = manager.get_active_connection() + assert active["host"] == "localhost" + assert active["connected"] is True + + # Switch to secondary + manager.switch_connection("secondary") + assert manager.connections["primary"]["connected"] is False + assert manager.connections["secondary"]["connected"] is True + + # Close all + manager.close_all() + assert all(not conn["connected"] for conn in manager.connections.values()) diff --git a/vdb_benchmark/tests/tests/test_index_management.py b/vdb_benchmark/tests/tests/test_index_management.py new file mode 100755 index 0000000..7cf87f7 --- /dev/null +++ b/vdb_benchmark/tests/tests/test_index_management.py @@ -0,0 +1,825 @@ +""" +Unit tests for index management functionality in vdb-bench +""" +import pytest +import numpy as np +from unittest.mock import Mock, MagicMock, patch, call +import time +import json +from typing import Dict, Any, List +from concurrent.futures import ThreadPoolExecutor + + +class TestIndexCreation: + """Test index creation operations.""" + + def test_create_diskann_index(self, mock_collection): + """Test creating DiskANN index.""" + mock_collection.create_index.return_value = True + + def create_diskann_index(collection, field_name="embedding", params=None): + """Create DiskANN index on collection.""" + if params is None: + params = { + "metric_type": "L2", + "index_type": "DISKANN", + "params": { + "max_degree": 64, + "search_list_size": 200, + "pq_code_budget_gb": 0.1, + "build_algo": "IVF_PQ" + } + } + + try: + result = collection.create_index( + field_name=field_name, + index_params=params + ) + return { + "success": True, + "index_type": params["index_type"], + "field": field_name, + "params": params + } + except Exception as e: + return { + "success": False, + "error": str(e) + } + + result = create_diskann_index(mock_collection) + + assert result["success"] is True + assert result["index_type"] == "DISKANN" + mock_collection.create_index.assert_called_once() + + def test_create_hnsw_index(self, mock_collection): + """Test creating HNSW index.""" + mock_collection.create_index.return_value = True + + def create_hnsw_index(collection, field_name="embedding", params=None): + """Create HNSW index on collection.""" + if params is None: + params = { + "metric_type": "L2", + "index_type": "HNSW", + "params": { + "M": 16, + "efConstruction": 200 + } + } + + try: + result = collection.create_index( + field_name=field_name, + index_params=params + ) + return { + "success": True, + "index_type": params["index_type"], + "field": field_name, + "params": params + } + except Exception as e: + return { + "success": False, + "error": str(e) + } + + result = create_hnsw_index(mock_collection) + + assert result["success"] is True + assert result["index_type"] == "HNSW" + assert result["params"]["params"]["M"] == 16 + + def test_create_ivf_index(self, mock_collection): + """Test creating IVF index variants.""" + class IVFIndexBuilder: + def __init__(self, collection): + self.collection = collection + + def create_ivf_flat(self, field_name, nlist=128): + """Create IVF_FLAT index.""" + params = { + "metric_type": "L2", + "index_type": "IVF_FLAT", + "params": {"nlist": nlist} + } + return self._create_index(field_name, params) + + def create_ivf_sq8(self, field_name, nlist=128): + """Create IVF_SQ8 index.""" + params = { + "metric_type": "L2", + "index_type": "IVF_SQ8", + "params": {"nlist": nlist} + } + return self._create_index(field_name, params) + + def create_ivf_pq(self, field_name, nlist=128, m=8, nbits=8): + """Create IVF_PQ index.""" + params = { + "metric_type": "L2", + "index_type": "IVF_PQ", + "params": { + "nlist": nlist, + "m": m, + "nbits": nbits + } + } + return self._create_index(field_name, params) + + def _create_index(self, field_name, params): + """Internal method to create index.""" + try: + self.collection.create_index( + field_name=field_name, + index_params=params + ) + return {"success": True, "params": params} + except Exception as e: + return {"success": False, "error": str(e)} + + mock_collection.create_index.return_value = True + builder = IVFIndexBuilder(mock_collection) + + # Test IVF_FLAT + result = builder.create_ivf_flat("embedding", nlist=256) + assert result["success"] is True + assert result["params"]["index_type"] == "IVF_FLAT" + + # Test IVF_SQ8 + result = builder.create_ivf_sq8("embedding", nlist=512) + assert result["success"] is True + assert result["params"]["index_type"] == "IVF_SQ8" + + # Test IVF_PQ + result = builder.create_ivf_pq("embedding", nlist=256, m=16) + assert result["success"] is True + assert result["params"]["index_type"] == "IVF_PQ" + assert result["params"]["params"]["m"] == 16 + + def test_index_creation_with_retry(self, mock_collection): + """Test index creation with retry logic.""" + # Simulate failures then success + mock_collection.create_index.side_effect = [ + Exception("Index creation failed"), + Exception("Still failing"), + True + ] + + def create_index_with_retry(collection, params, max_retries=3, backoff=2): + """Create index with exponential backoff retry.""" + for attempt in range(max_retries): + try: + collection.create_index( + field_name="embedding", + index_params=params + ) + return { + "success": True, + "attempts": attempt + 1 + } + except Exception as e: + if attempt == max_retries - 1: + return { + "success": False, + "attempts": attempt + 1, + "error": str(e) + } + time.sleep(backoff ** attempt) + + return {"success": False, "attempts": max_retries} + + params = { + "metric_type": "L2", + "index_type": "DISKANN", + "params": {"max_degree": 64} + } + + with patch('time.sleep'): # Speed up test + result = create_index_with_retry(mock_collection, params) + + assert result["success"] is True + assert result["attempts"] == 3 + assert mock_collection.create_index.call_count == 3 + + +class TestIndexManagement: + """Test index management operations.""" + + def test_index_status_check(self, mock_collection): + """Test checking index status.""" + # Create a proper mock index object + mock_index = Mock() + mock_index.params = {"index_type": "DISKANN"} + mock_index.progress = 100 + mock_index.state = "Finished" + + # Set the index attribute on collection + mock_collection.index = mock_index + + class IndexManager: + def __init__(self, collection): + self.collection = collection + + def get_index_status(self): + """Get current index status.""" + try: + index = self.collection.index + return { + "exists": True, + "type": index.params.get("index_type"), + "progress": index.progress, + "state": index.state, + "params": index.params + } + except: + return { + "exists": False, + "type": None, + "progress": 0, + "state": "Not Created" + } + + def is_index_ready(self): + """Check if index is ready for use.""" + status = self.get_index_status() + return ( + status["exists"] and + status["state"] == "Finished" and + status["progress"] == 100 + ) + + def wait_for_index(self, timeout=300, check_interval=5): + """Wait for index to be ready.""" + start_time = time.time() + + while time.time() - start_time < timeout: + if self.is_index_ready(): + return True + time.sleep(check_interval) + + return False + + manager = IndexManager(mock_collection) + + status = manager.get_index_status() + assert status["exists"] is True + assert status["type"] == "DISKANN" + assert status["progress"] == 100 + + assert manager.is_index_ready() is True + + def test_drop_index(self, mock_collection): + """Test dropping an index.""" + mock_collection.drop_index.return_value = None + + def drop_index(collection, field_name="embedding"): + """Drop index from collection.""" + try: + collection.drop_index(field_name=field_name) + return { + "success": True, + "field": field_name, + "message": f"Index dropped for field {field_name}" + } + except Exception as e: + return { + "success": False, + "error": str(e) + } + + result = drop_index(mock_collection) + + assert result["success"] is True + assert result["field"] == "embedding" + mock_collection.drop_index.assert_called_once_with(field_name="embedding") + + def test_rebuild_index(self, mock_collection): + """Test rebuilding an index.""" + mock_collection.drop_index.return_value = None + mock_collection.create_index.return_value = True + + class IndexRebuilder: + def __init__(self, collection): + self.collection = collection + + def rebuild_index(self, field_name, new_params): + """Rebuild index with new parameters.""" + steps = [] + + try: + # Step 1: Drop existing index + self.collection.drop_index(field_name=field_name) + steps.append("Index dropped") + + # Step 2: Wait for drop to complete + time.sleep(1) + steps.append("Waited for drop completion") + + # Step 3: Create new index + self.collection.create_index( + field_name=field_name, + index_params=new_params + ) + steps.append("New index created") + + return { + "success": True, + "steps": steps, + "new_params": new_params + } + + except Exception as e: + return { + "success": False, + "steps": steps, + "error": str(e) + } + + rebuilder = IndexRebuilder(mock_collection) + + new_params = { + "metric_type": "COSINE", + "index_type": "HNSW", + "params": {"M": 32, "efConstruction": 400} + } + + with patch('time.sleep'): # Speed up test + result = rebuilder.rebuild_index("embedding", new_params) + + assert result["success"] is True + assert len(result["steps"]) == 3 + assert mock_collection.drop_index.called + assert mock_collection.create_index.called + + def test_index_comparison(self): + """Test comparing different index configurations.""" + class IndexComparator: + def __init__(self): + self.results = {} + + def add_result(self, index_type, metrics): + """Add benchmark result for an index type.""" + self.results[index_type] = metrics + + def compare(self): + """Compare all index results.""" + if len(self.results) < 2: + return None + + comparison = { + "indexes": [], + "best_qps": None, + "best_recall": None, + "best_build_time": None + } + + best_qps = 0 + best_recall = 0 + best_build_time = float('inf') + + for index_type, metrics in self.results.items(): + comparison["indexes"].append({ + "type": index_type, + "qps": metrics.get("qps", 0), + "recall": metrics.get("recall", 0), + "build_time": metrics.get("build_time", 0), + "memory_usage": metrics.get("memory_usage", 0) + }) + + if metrics.get("qps", 0) > best_qps: + best_qps = metrics["qps"] + comparison["best_qps"] = index_type + + if metrics.get("recall", 0) > best_recall: + best_recall = metrics["recall"] + comparison["best_recall"] = index_type + + if metrics.get("build_time", float('inf')) < best_build_time: + best_build_time = metrics["build_time"] + comparison["best_build_time"] = index_type + + return comparison + + def get_recommendation(self, requirements): + """Get index recommendation based on requirements.""" + if not self.results: + return None + + scores = {} + + for index_type, metrics in self.results.items(): + score = 0 + + # Weight different factors based on requirements + if requirements.get("prioritize_speed"): + score += metrics.get("qps", 0) * 2 + + if requirements.get("prioritize_accuracy"): + score += metrics.get("recall", 0) * 1000 + + if requirements.get("memory_constrained"): + # Penalize high memory usage + score -= metrics.get("memory_usage", 0) * 0.1 + + if requirements.get("fast_build"): + # Penalize slow build time + score -= metrics.get("build_time", 0) * 10 + + scores[index_type] = score + + best_index = max(scores, key=scores.get) + + return { + "recommended": best_index, + "score": scores[best_index], + "all_scores": scores + } + + comparator = IndexComparator() + + # Add sample results + comparator.add_result("DISKANN", { + "qps": 1500, + "recall": 0.95, + "build_time": 300, + "memory_usage": 2048 + }) + + comparator.add_result("HNSW", { + "qps": 1200, + "recall": 0.98, + "build_time": 150, + "memory_usage": 4096 + }) + + comparator.add_result("IVF_PQ", { + "qps": 2000, + "recall": 0.90, + "build_time": 100, + "memory_usage": 1024 + }) + + comparison = comparator.compare() + + assert comparison["best_qps"] == "IVF_PQ" + assert comparison["best_recall"] == "HNSW" + assert comparison["best_build_time"] == "IVF_PQ" + + # Test recommendation + requirements = { + "prioritize_accuracy": True, + "memory_constrained": False + } + + recommendation = comparator.get_recommendation(requirements) + assert recommendation["recommended"] == "HNSW" + + +class TestIndexOptimization: + """Test index optimization strategies.""" + + def test_parameter_tuning(self, mock_collection): + """Test automatic parameter tuning for indexes.""" + class ParameterTuner: + def __init__(self, collection): + self.collection = collection + self.test_results = [] + + def tune_diskann(self, test_vectors, ground_truth): + """Tune DiskANN parameters.""" + param_grid = [ + {"max_degree": 32, "search_list_size": 100}, + {"max_degree": 64, "search_list_size": 200}, + {"max_degree": 96, "search_list_size": 300} + ] + + best_params = None + best_score = 0 + + for params in param_grid: + score = self._test_params( + "DISKANN", + params, + test_vectors, + ground_truth + ) + + if score > best_score: + best_score = score + best_params = params + + self.test_results.append({ + "params": params, + "score": score + }) + + return best_params, best_score + + def tune_hnsw(self, test_vectors, ground_truth): + """Tune HNSW parameters.""" + param_grid = [ + {"M": 8, "efConstruction": 100}, + {"M": 16, "efConstruction": 200}, + {"M": 32, "efConstruction": 400} + ] + + best_params = None + best_score = 0 + + for params in param_grid: + score = self._test_params( + "HNSW", + params, + test_vectors, + ground_truth + ) + + if score > best_score: + best_score = score + best_params = params + + self.test_results.append({ + "params": params, + "score": score + }) + + return best_params, best_score + + def _test_params(self, index_type, params, test_vectors, ground_truth): + """Test specific parameters and return score.""" + # Simulated testing (in reality would rebuild index and test) + # Score based on parameter values (simplified) + + if index_type == "DISKANN": + score = params["max_degree"] * 0.5 + params["search_list_size"] * 0.2 + elif index_type == "HNSW": + score = params["M"] * 2 + params["efConstruction"] * 0.1 + else: + score = 0 + + # Add some randomness + score += np.random.random() * 10 + + return score + + tuner = ParameterTuner(mock_collection) + + # Create test data + test_vectors = np.random.randn(100, 128).astype(np.float32) + ground_truth = np.random.randint(0, 1000, (100, 10)) + + # Tune DiskANN + best_diskann, diskann_score = tuner.tune_diskann(test_vectors, ground_truth) + assert best_diskann is not None + assert diskann_score > 0 + + # Tune HNSW + best_hnsw, hnsw_score = tuner.tune_hnsw(test_vectors, ground_truth) + assert best_hnsw is not None + assert hnsw_score > 0 + + # Check that results were recorded + assert len(tuner.test_results) == 6 # 3 for each index type + + def test_adaptive_index_selection(self): + """Test adaptive index selection based on workload.""" + class AdaptiveIndexSelector: + def __init__(self): + self.workload_history = [] + self.current_index = None + + def analyze_workload(self, queries): + """Analyze query workload characteristics.""" + characteristics = { + "query_count": len(queries), + "dimension": queries.shape[1] if len(queries) > 0 else 0, + "distribution": self._analyze_distribution(queries), + "sparsity": self._calculate_sparsity(queries), + "clustering": self._analyze_clustering(queries) + } + + self.workload_history.append({ + "timestamp": time.time(), + "characteristics": characteristics + }) + + return characteristics + + def select_index(self, characteristics, dataset_size): + """Select best index for workload characteristics.""" + # Simple rule-based selection + + if dataset_size < 100000: + # Small dataset - use simple index + return "IVF_FLAT" + + elif dataset_size < 1000000: + # Medium dataset + if characteristics["clustering"] > 0.7: + # Highly clustered - IVF works well + return "IVF_PQ" + else: + # More uniform - HNSW + return "HNSW" + + else: + # Large dataset + if characteristics["sparsity"] > 0.5: + # Sparse vectors - specialized index + return "SPARSE_IVF" + elif characteristics["dimension"] > 1000: + # High dimension - DiskANN with PQ + return "DISKANN" + else: + # Default to HNSW for good all-around performance + return "HNSW" + + def _analyze_distribution(self, queries): + """Analyze query distribution.""" + if len(queries) == 0: + return "unknown" + + # Simple variance check + variance = np.var(queries) + if variance < 0.5: + return "concentrated" + elif variance < 2.0: + return "normal" + else: + return "scattered" + + def _calculate_sparsity(self, queries): + """Calculate sparsity of queries.""" + if len(queries) == 0: + return 0 + + zero_count = np.sum(queries == 0) + total_elements = queries.size + + return zero_count / total_elements if total_elements > 0 else 0 + + def _analyze_clustering(self, queries): + """Analyze clustering tendency.""" + # Simplified clustering score + if len(queries) < 10: + return 0 + + # Calculate pairwise distances for small sample + sample = queries[:min(100, len(queries))] + distances = [] + + for i in range(len(sample)): + for j in range(i + 1, len(sample)): + dist = np.linalg.norm(sample[i] - sample[j]) + distances.append(dist) + + if not distances: + return 0 + + # High variance in distances indicates clustering + distance_var = np.var(distances) + return min(distance_var / 10, 1.0) # Normalize to [0, 1] + + selector = AdaptiveIndexSelector() + + # Test with different workloads + + # Sparse workload + sparse_queries = np.random.randn(100, 2000).astype(np.float32) + sparse_queries[sparse_queries < 1] = 0 # Make sparse + + characteristics = selector.analyze_workload(sparse_queries) + selected_index = selector.select_index(characteristics, 5000000) + + assert characteristics["sparsity"] > 0.3 + + # Dense clustered workload + clustered_queries = [] + for _ in range(5): + center = np.random.randn(128) * 10 + cluster = center + np.random.randn(20, 128) * 0.1 + clustered_queries.append(cluster) + clustered_queries = np.vstack(clustered_queries).astype(np.float32) + + characteristics = selector.analyze_workload(clustered_queries) + selected_index = selector.select_index(characteristics, 500000) + + assert selected_index in ["IVF_PQ", "HNSW"] + + def test_index_warm_up(self, mock_collection): + """Test index warm-up procedures.""" + class IndexWarmUp: + def __init__(self, collection): + self.collection = collection + self.warm_up_stats = [] + + def warm_up(self, num_queries=100, batch_size=10): + """Warm up index with sample queries.""" + total_time = 0 + queries_executed = 0 + + for batch in range(0, num_queries, batch_size): + # Generate random queries + batch_queries = np.random.randn( + min(batch_size, num_queries - batch), + 128 + ).astype(np.float32) + + start = time.time() + + # Execute warm-up queries + self.collection.search( + data=batch_queries.tolist(), + anns_field="embedding", + param={"metric_type": "L2"}, + limit=10 + ) + + elapsed = time.time() - start + total_time += elapsed + queries_executed += len(batch_queries) + + self.warm_up_stats.append({ + "batch": batch // batch_size, + "queries": len(batch_queries), + "time": elapsed, + "qps": len(batch_queries) / elapsed if elapsed > 0 else 0 + }) + + return { + "total_queries": queries_executed, + "total_time": total_time, + "avg_qps": queries_executed / total_time if total_time > 0 else 0, + "stats": self.warm_up_stats + } + + def adaptive_warm_up(self, target_qps=100, max_queries=1000): + """Adaptive warm-up that stops when performance stabilizes.""" + stable_threshold = 0.1 # 10% variation + window_size = 5 + recent_qps = [] + + batch_size = 10 + total_queries = 0 + + while total_queries < max_queries: + queries = np.random.randn(batch_size, 128).astype(np.float32) + + start = time.time() + self.collection.search( + data=queries.tolist(), + anns_field="embedding", + param={"metric_type": "L2"}, + limit=10 + ) + elapsed = time.time() - start + + qps = batch_size / elapsed if elapsed > 0 else 0 + recent_qps.append(qps) + total_queries += batch_size + + # Check if performance is stable + if len(recent_qps) >= window_size: + recent = recent_qps[-window_size:] + avg = sum(recent) / len(recent) + variance = sum((q - avg) ** 2 for q in recent) / len(recent) + cv = (variance ** 0.5) / avg if avg > 0 else 1 + + if cv < stable_threshold and avg >= target_qps: + return { + "warmed_up": True, + "queries_used": total_queries, + "final_qps": avg, + "stabilized": True + } + + return { + "warmed_up": True, + "queries_used": total_queries, + "final_qps": recent_qps[-1] if recent_qps else 0, + "stabilized": False + } + + mock_collection.search.return_value = [[Mock(id=i, distance=0.1*i) for i in range(10)]] + + warmer = IndexWarmUp(mock_collection) + + # Test basic warm-up + with patch('time.time', side_effect=[0, 0.1, 0.2, 0.3, 0.4, 0.5] * 20): + result = warmer.warm_up(num_queries=50, batch_size=10) + + assert result["total_queries"] == 50 + assert len(warmer.warm_up_stats) == 5 + + # Test adaptive warm-up + warmer2 = IndexWarmUp(mock_collection) + + with patch('time.time', side_effect=[i * 0.01 for i in range(200)]): + result = warmer2.adaptive_warm_up(target_qps=100, max_queries=100) + + assert result["warmed_up"] is True + assert result["queries_used"] <= 100 diff --git a/vdb_benchmark/tests/tests/test_load_vdb.py b/vdb_benchmark/tests/tests/test_load_vdb.py new file mode 100755 index 0000000..772f2f9 --- /dev/null +++ b/vdb_benchmark/tests/tests/test_load_vdb.py @@ -0,0 +1,530 @@ +""" +Unit tests for vector loading functionality in vdb-bench +""" +import pytest +import numpy as np +from unittest.mock import Mock, MagicMock, patch, call +import time +from typing import List, Generator +import json + + +class TestVectorGeneration: + """Test vector generation utilities.""" + + def test_uniform_vector_generation(self): + """Test generating vectors with uniform distribution.""" + def generate_uniform_vectors(num_vectors, dimension, seed=None): + if seed is not None: + np.random.seed(seed) + return np.random.uniform(-1, 1, size=(num_vectors, dimension)).astype(np.float32) + + vectors = generate_uniform_vectors(100, 128, seed=42) + + assert vectors.shape == (100, 128) + assert vectors.dtype == np.float32 + assert vectors.min() >= -1 + assert vectors.max() <= 1 + + # Test reproducibility with seed + vectors2 = generate_uniform_vectors(100, 128, seed=42) + np.testing.assert_array_equal(vectors, vectors2) + + def test_normal_vector_generation(self): + """Test generating vectors with normal distribution.""" + def generate_normal_vectors(num_vectors, dimension, mean=0, std=1, seed=None): + if seed is not None: + np.random.seed(seed) + return np.random.normal(mean, std, size=(num_vectors, dimension)).astype(np.float32) + + vectors = generate_normal_vectors(1000, 256, seed=42) + + assert vectors.shape == (1000, 256) + assert vectors.dtype == np.float32 + + # Check distribution properties (should be close to normal) + assert -0.1 < vectors.mean() < 0.1 # Mean should be close to 0 + assert 0.9 < vectors.std() < 1.1 # Std should be close to 1 + + def test_normalized_vector_generation(self): + """Test generating L2-normalized vectors.""" + def generate_normalized_vectors(num_vectors, dimension, seed=None): + if seed is not None: + np.random.seed(seed) + + vectors = np.random.randn(num_vectors, dimension).astype(np.float32) + # L2 normalize each vector + norms = np.linalg.norm(vectors, axis=1, keepdims=True) + return vectors / norms + + vectors = generate_normalized_vectors(50, 64, seed=42) + + assert vectors.shape == (50, 64) + + # Check that all vectors are normalized + norms = np.linalg.norm(vectors, axis=1) + np.testing.assert_array_almost_equal(norms, np.ones(50), decimal=5) + + def test_chunked_vector_generation(self): + """Test generating vectors in chunks for memory efficiency.""" + def generate_vectors_chunked(total_vectors, dimension, chunk_size=1000): + """Generate vectors in chunks to manage memory.""" + num_chunks = (total_vectors + chunk_size - 1) // chunk_size + + for i in range(num_chunks): + start_idx = i * chunk_size + end_idx = min(start_idx + chunk_size, total_vectors) + chunk_vectors = end_idx - start_idx + + yield np.random.randn(chunk_vectors, dimension).astype(np.float32) + + # Generate 10000 vectors in chunks of 1000 + all_vectors = [] + for chunk in generate_vectors_chunked(10000, 128, chunk_size=1000): + all_vectors.append(chunk) + + assert len(all_vectors) == 10 + assert all_vectors[0].shape == (1000, 128) + + # Concatenate and verify total + concatenated = np.vstack(all_vectors) + assert concatenated.shape == (10000, 128) + + def test_vector_generation_with_ids(self): + """Test generating vectors with associated IDs.""" + def generate_vectors_with_ids(num_vectors, dimension, start_id=0): + vectors = np.random.randn(num_vectors, dimension).astype(np.float32) + ids = np.arange(start_id, start_id + num_vectors, dtype=np.int64) + return ids, vectors + + ids, vectors = generate_vectors_with_ids(100, 256, start_id=1000) + + assert len(ids) == 100 + assert ids[0] == 1000 + assert ids[-1] == 1099 + assert vectors.shape == (100, 256) + + def test_vector_generation_progress_tracking(self): + """Test tracking progress during vector generation.""" + def generate_with_progress(num_vectors, dimension, chunk_size=100): + total_generated = 0 + progress_updates = [] + + for chunk_num in range(0, num_vectors, chunk_size): + chunk_end = min(chunk_num + chunk_size, num_vectors) + chunk_size_actual = chunk_end - chunk_num + + vectors = np.random.randn(chunk_size_actual, dimension).astype(np.float32) + + total_generated += chunk_size_actual + progress = (total_generated / num_vectors) * 100 + progress_updates.append(progress) + + yield vectors, progress + + progress_list = [] + vector_list = [] + + for vectors, progress in generate_with_progress(1000, 128, chunk_size=200): + vector_list.append(vectors) + progress_list.append(progress) + + assert len(progress_list) == 5 + assert progress_list[-1] == 100.0 + assert all(p > 0 for p in progress_list) + + +class TestVectorLoading: + """Test vector loading into database.""" + + def test_batch_insertion(self, mock_collection): + """Test inserting vectors in batches.""" + inserted_data = [] + mock_collection.insert.side_effect = lambda data: inserted_data.append(data) + + def insert_vectors_batch(collection, vectors, batch_size=1000): + """Insert vectors in batches.""" + num_vectors = len(vectors) + total_inserted = 0 + + for i in range(0, num_vectors, batch_size): + batch = vectors[i:i + batch_size] + collection.insert([batch]) + total_inserted += len(batch) + + return total_inserted + + vectors = np.random.randn(5000, 128).astype(np.float32) + total = insert_vectors_batch(mock_collection, vectors, batch_size=1000) + + assert total == 5000 + assert mock_collection.insert.call_count == 5 + + def test_insertion_with_error_handling(self, mock_collection): + """Test vector insertion with error handling.""" + # Simulate occasional insertion failures + call_count = 0 + def insert_side_effect(data): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise Exception("Insert failed") + return Mock(primary_keys=list(range(len(data[0])))) + + mock_collection.insert.side_effect = insert_side_effect + + def insert_with_retry(collection, vectors, max_retries=3): + """Insert vectors with retry on failure.""" + for attempt in range(max_retries): + try: + result = collection.insert([vectors]) + return result + except Exception as e: + if attempt == max_retries - 1: + raise + time.sleep(1) + return None + + vectors = np.random.randn(100, 128).astype(np.float32) + + with patch('time.sleep'): + result = insert_with_retry(mock_collection, vectors) + + assert result is not None + assert mock_collection.insert.call_count == 2 # Failed once, succeeded on retry + + def test_parallel_insertion(self, mock_collection): + """Test parallel vector insertion using multiple threads/processes.""" + from concurrent.futures import ThreadPoolExecutor + + def insert_chunk(args): + collection, chunk_id, vectors = args + collection.insert([vectors]) + return chunk_id, len(vectors) + + def parallel_insert(collection, vectors, num_workers=4, chunk_size=1000): + """Insert vectors in parallel.""" + chunks = [] + for i in range(0, len(vectors), chunk_size): + chunk = vectors[i:i + chunk_size] + chunks.append((collection, i // chunk_size, chunk)) + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + results = list(executor.map(insert_chunk, chunks)) + + total_inserted = sum(count for _, count in results) + return total_inserted + + vectors = np.random.randn(4000, 128).astype(np.float32) + + # Mock the insert to track calls + inserted_chunks = [] + mock_collection.insert.side_effect = lambda data: inserted_chunks.append(len(data[0])) + + total = parallel_insert(mock_collection, vectors, num_workers=2, chunk_size=1000) + + assert total == 4000 + assert len(inserted_chunks) == 4 + + def test_insertion_with_metadata(self, mock_collection): + """Test inserting vectors with additional metadata.""" + def insert_vectors_with_metadata(collection, vectors, metadata): + """Insert vectors along with metadata.""" + data = [ + vectors, + metadata.get("ids", list(range(len(vectors)))), + metadata.get("tags", ["default"] * len(vectors)) + ] + + result = collection.insert(data) + return result + + vectors = np.random.randn(100, 128).astype(np.float32) + metadata = { + "ids": list(range(1000, 1100)), + "tags": [f"tag_{i % 10}" for i in range(100)] + } + + mock_collection.insert.return_value = Mock(primary_keys=metadata["ids"]) + + result = insert_vectors_with_metadata(mock_collection, vectors, metadata) + + assert result.primary_keys == metadata["ids"] + mock_collection.insert.assert_called_once() + + @patch('time.time') + def test_insertion_rate_monitoring(self, mock_time, mock_collection): + """Test monitoring insertion rate and throughput.""" + # Start at 1 instead of 0 to avoid issues with 0 being falsy + time_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0] + mock_time.side_effect = time_sequence + + class InsertionMonitor: + def __init__(self): + self.total_vectors = 0 + self.start_time = None + self.batch_times = [] + self.last_time = None + + def start(self): + self.start_time = time.time() + self.last_time = self.start_time + + def record_batch(self, batch_size): + current_time = time.time() + if self.start_time is not None: + # Calculate elapsed since last batch + elapsed = current_time - self.last_time + self.last_time = current_time + self.batch_times.append(current_time) + self.total_vectors += batch_size + + # Calculate throughput + total_elapsed = current_time - self.start_time + throughput = self.total_vectors / total_elapsed if total_elapsed > 0 else 0 + + return { + "batch_size": batch_size, + "batch_time": elapsed, + "total_vectors": self.total_vectors, + "throughput": throughput + } + return None + + def get_summary(self): + # Check if we have data to summarize + if self.start_time is None or len(self.batch_times) == 0: + return None + + # Calculate total time from start to last batch + total_time = self.batch_times[-1] - self.start_time + + # Return summary if we have valid data + if self.total_vectors > 0: + return { + "total_vectors": self.total_vectors, + "total_time": total_time, + "average_throughput": self.total_vectors / total_time if total_time > 0 else 0 + } + + return None + + monitor = InsertionMonitor() + monitor.start() # Uses time value 1.0 + + # Simulate inserting batches (uses time values 2.0-6.0) + stats = [] + for i in range(5): + stat = monitor.record_batch(1000) + if stat: + stats.append(stat) + + summary = monitor.get_summary() + + assert summary is not None + assert summary["total_vectors"] == 5000 + assert summary["total_time"] == 5.0 # From time 1.0 to time 6.0 + assert summary["average_throughput"] == 1000.0 # 5000 vectors / 5 seconds + + def test_load_checkpoint_resume(self, test_data_dir): + """Test checkpoint and resume functionality for large loads.""" + checkpoint_file = test_data_dir / "checkpoint.json" + + class LoadCheckpoint: + def __init__(self, checkpoint_path): + self.checkpoint_path = checkpoint_path + self.state = self.load_checkpoint() + + def load_checkpoint(self): + """Load checkpoint from file if exists.""" + if self.checkpoint_path.exists(): + with open(self.checkpoint_path, 'r') as f: + return json.load(f) + return {"last_batch": 0, "total_inserted": 0} + + def save_checkpoint(self, batch_num, total_inserted): + """Save current progress to checkpoint.""" + self.state = { + "last_batch": batch_num, + "total_inserted": total_inserted, + "timestamp": time.time() + } + with open(self.checkpoint_path, 'w') as f: + json.dump(self.state, f) + + def get_resume_point(self): + """Get the batch number to resume from.""" + return self.state["last_batch"] + + def clear(self): + """Clear checkpoint after successful completion.""" + if self.checkpoint_path.exists(): + self.checkpoint_path.unlink() + self.state = {"last_batch": 0, "total_inserted": 0} + + checkpoint = LoadCheckpoint(checkpoint_file) + + # Simulate partial load + checkpoint.save_checkpoint(5, 5000) + assert checkpoint.get_resume_point() == 5 + + # Simulate resume + checkpoint2 = LoadCheckpoint(checkpoint_file) + assert checkpoint2.get_resume_point() == 5 + assert checkpoint2.state["total_inserted"] == 5000 + + # Clear checkpoint + checkpoint2.clear() + assert not checkpoint_file.exists() + + +class TestLoadOptimization: + """Test load optimization strategies.""" + + def test_dynamic_batch_sizing(self): + """Test dynamic batch size adjustment based on performance.""" + class DynamicBatchSizer: + def __init__(self, initial_size=1000, min_size=100, max_size=10000): + self.current_size = initial_size + self.min_size = min_size + self.max_size = max_size + self.history = [] + + def adjust(self, insertion_time, batch_size): + """Adjust batch size based on insertion performance.""" + throughput = batch_size / insertion_time if insertion_time > 0 else 0 + self.history.append((batch_size, throughput)) + + if len(self.history) >= 3: + # Calculate trend + recent_throughputs = [tp for _, tp in self.history[-3:]] + avg_throughput = sum(recent_throughputs) / len(recent_throughputs) + + if throughput > avg_throughput * 1.1: + # Performance improving, increase batch size + self.current_size = min( + int(self.current_size * 1.2), + self.max_size + ) + elif throughput < avg_throughput * 0.9: + # Performance degrading, decrease batch size + self.current_size = max( + int(self.current_size * 0.8), + self.min_size + ) + + return self.current_size + + sizer = DynamicBatchSizer(initial_size=1000) + + # Simulate good performance - should increase batch size + new_size = sizer.adjust(1.0, 1000) # 1000 vectors/sec + new_size = sizer.adjust(0.9, 1000) # 1111 vectors/sec + new_size = sizer.adjust(0.8, 1000) # 1250 vectors/sec + new_size = sizer.adjust(0.7, new_size) # Improving performance + + assert new_size > 1000 # Should have increased + + # Simulate degrading performance - should decrease batch size + sizer2 = DynamicBatchSizer(initial_size=5000) + new_size = sizer2.adjust(1.0, 5000) # 5000 vectors/sec + new_size = sizer2.adjust(1.2, 5000) # 4166 vectors/sec + new_size = sizer2.adjust(1.5, 5000) # 3333 vectors/sec + new_size = sizer2.adjust(2.0, new_size) # Degrading performance + + assert new_size < 5000 # Should have decreased + + def test_memory_aware_loading(self): + """Test memory-aware vector loading.""" + import psutil + + class MemoryAwareLoader: + def __init__(self, memory_threshold=0.8): + self.memory_threshold = memory_threshold + self.base_batch_size = 1000 + + def get_memory_usage(self): + """Get current memory usage percentage.""" + return psutil.virtual_memory().percent / 100 + + def calculate_safe_batch_size(self, vector_dimension): + """Calculate safe batch size based on available memory.""" + memory_usage = self.get_memory_usage() + + if memory_usage > self.memory_threshold: + # Reduce batch size when memory is high + reduction_factor = 1.0 - (memory_usage - self.memory_threshold) + return max(100, int(self.base_batch_size * reduction_factor)) + + # Calculate based on vector size + bytes_per_vector = vector_dimension * 4 # float32 + available_memory = (1.0 - memory_usage) * psutil.virtual_memory().total + max_vectors = int(available_memory * 0.5 / bytes_per_vector) # Use 50% of available + + return min(max_vectors, self.base_batch_size) + + def should_gc(self): + """Determine if garbage collection should be triggered.""" + return self.get_memory_usage() > 0.7 + + with patch('psutil.virtual_memory') as mock_memory: + # Simulate different memory conditions + mock_memory.return_value = Mock(percent=60, total=16 * 1024**3) # 60% used, 16GB total + + loader = MemoryAwareLoader() + batch_size = loader.calculate_safe_batch_size(1536) + + assert batch_size > 0 + assert not loader.should_gc() + + # Simulate high memory usage + mock_memory.return_value = Mock(percent=85, total=16 * 1024**3) # 85% used + + batch_size = loader.calculate_safe_batch_size(1536) + assert batch_size < loader.base_batch_size # Should be reduced + assert loader.should_gc() + + def test_flush_optimization(self, mock_collection): + """Test optimizing flush operations during loading.""" + flush_count = 0 + + def mock_flush(): + nonlocal flush_count + flush_count += 1 + time.sleep(0.1) # Simulate flush time + + mock_collection.flush = mock_flush + + class FlushOptimizer: + def __init__(self, flush_interval=10000, time_interval=60): + self.flush_interval = flush_interval + self.time_interval = time_interval + self.vectors_since_flush = 0 + self.last_flush_time = time.time() + + def should_flush(self, vectors_inserted): + """Determine if flush should be triggered.""" + self.vectors_since_flush += vectors_inserted + current_time = time.time() + + # Flush based on vector count or time + if (self.vectors_since_flush >= self.flush_interval or + current_time - self.last_flush_time >= self.time_interval): + return True + return False + + def flush(self, collection): + """Perform flush and reset counters.""" + collection.flush() + self.vectors_since_flush = 0 + self.last_flush_time = time.time() + + optimizer = FlushOptimizer(flush_interval=5000) + + with patch('time.sleep'): # Speed up test + # Simulate loading vectors + for i in range(10): + if optimizer.should_flush(1000): + optimizer.flush(mock_collection) + + assert flush_count == 2 # Should have flushed twice (at 5000 and 10000) diff --git a/vdb_benchmark/tests/tests/test_simple_bench.py b/vdb_benchmark/tests/tests/test_simple_bench.py new file mode 100755 index 0000000..c322a3d --- /dev/null +++ b/vdb_benchmark/tests/tests/test_simple_bench.py @@ -0,0 +1,766 @@ +""" +Unit tests for benchmarking functionality in vdb-bench +""" +import pytest +import numpy as np +from unittest.mock import Mock, MagicMock, patch, call +import time +import multiprocessing as mp +from typing import List, Dict, Any +import statistics +import json +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor + + +class TestBenchmarkExecution: + """Test benchmark execution and query operations.""" + + def test_single_query_execution(self, mock_collection): + """Test executing a single query.""" + # Mock search result + mock_collection.search.return_value = [[ + Mock(id=1, distance=0.1), + Mock(id=2, distance=0.2), + Mock(id=3, distance=0.3) + ]] + + def execute_single_query(collection, query_vector, top_k=10): + """Execute a single vector search query.""" + start_time = time.time() + + results = collection.search( + data=[query_vector], + anns_field="embedding", + param={"metric_type": "L2", "params": {"nprobe": 10}}, + limit=top_k + ) + + end_time = time.time() + latency = end_time - start_time + + return { + "latency": latency, + "num_results": len(results[0]) if results else 0, + "top_result": results[0][0].id if results and results[0] else None + } + + query = np.random.randn(128).astype(np.float32) + result = execute_single_query(mock_collection, query) + + assert result["latency"] >= 0 + assert result["num_results"] == 3 + assert result["top_result"] == 1 + mock_collection.search.assert_called_once() + + def test_batch_query_execution(self, mock_collection): + """Test executing batch queries.""" + # Mock batch search results + mock_results = [ + [Mock(id=i, distance=0.1*i) for i in range(1, 6)] + for _ in range(10) + ] + mock_collection.search.return_value = mock_results + + def execute_batch_queries(collection, query_vectors, top_k=10): + """Execute batch vector search queries.""" + start_time = time.time() + + results = collection.search( + data=query_vectors, + anns_field="embedding", + param={"metric_type": "L2"}, + limit=top_k + ) + + end_time = time.time() + total_latency = end_time - start_time + + return { + "total_latency": total_latency, + "queries_per_second": len(query_vectors) / total_latency if total_latency > 0 else 0, + "num_queries": len(query_vectors), + "results_per_query": [len(r) for r in results] + } + + queries = np.random.randn(10, 128).astype(np.float32) + result = execute_batch_queries(mock_collection, queries) + + assert result["num_queries"] == 10 + assert len(result["results_per_query"]) == 10 + assert all(r == 5 for r in result["results_per_query"]) + + @patch('time.time') + def test_throughput_measurement(self, mock_time, mock_collection): + """Test measuring query throughput.""" + # Simulate time progression + time_counter = [0] + def time_side_effect(): + time_counter[0] += 0.001 # 1ms per call + return time_counter[0] + + mock_time.side_effect = time_side_effect + mock_collection.search.return_value = [[Mock(id=1, distance=0.1)]] + + class ThroughputBenchmark: + def __init__(self): + self.results = [] + + def run(self, collection, queries, duration=10): + """Run throughput benchmark for specified duration.""" + start_time = time.time() + end_time = start_time + duration + query_count = 0 + latencies = [] + + query_idx = 0 + while time.time() < end_time: + query_start = time.time() + + # Execute query + collection.search( + data=[queries[query_idx % len(queries)]], + anns_field="embedding", + param={"metric_type": "L2"}, + limit=10 + ) + + query_end = time.time() + latencies.append(query_end - query_start) + query_count += 1 + query_idx += 1 + + # Break if we've done enough queries for the test + if query_count >= 100: # Limit for testing + break + + actual_duration = time.time() - start_time + + return { + "total_queries": query_count, + "duration": actual_duration, + "qps": query_count / actual_duration if actual_duration > 0 else 0, + "avg_latency": statistics.mean(latencies) if latencies else 0, + "p50_latency": statistics.median(latencies) if latencies else 0, + "p95_latency": self._percentile(latencies, 95) if latencies else 0, + "p99_latency": self._percentile(latencies, 99) if latencies else 0 + } + + def _percentile(self, data, percentile): + """Calculate percentile of data.""" + size = len(data) + if size == 0: + return 0 + sorted_data = sorted(data) + index = int(size * percentile / 100) + return sorted_data[min(index, size - 1)] + + benchmark = ThroughputBenchmark() + queries = np.random.randn(10, 128).astype(np.float32) + + result = benchmark.run(mock_collection, queries, duration=1) + + assert result["total_queries"] > 0 + assert result["qps"] > 0 + assert result["avg_latency"] > 0 + + def test_concurrent_query_execution(self, mock_collection): + """Test concurrent query execution with multiple threads.""" + query_counter = {'count': 0} + + def mock_search(data, **kwargs): + query_counter['count'] += 1 + time.sleep(0.01) # Simulate query time + return [[Mock(id=i, distance=0.1*i) for i in range(5)]] + + mock_collection.search = mock_search + + class ConcurrentBenchmark: + def __init__(self, num_threads=4): + self.num_threads = num_threads + + def worker(self, args): + """Worker function for concurrent execution.""" + collection, queries, worker_id = args + results = [] + + for i, query in enumerate(queries): + start = time.time() + result = collection.search( + data=[query], + anns_field="embedding", + param={"metric_type": "L2"}, + limit=10 + ) + latency = time.time() - start + results.append({ + "worker_id": worker_id, + "query_id": i, + "latency": latency + }) + + return results + + def run(self, collection, queries): + """Run concurrent benchmark.""" + # Split queries among workers + queries_per_worker = len(queries) // self.num_threads + worker_args = [] + + for i in range(self.num_threads): + start_idx = i * queries_per_worker + end_idx = start_idx + queries_per_worker if i < self.num_threads - 1 else len(queries) + worker_queries = queries[start_idx:end_idx] + worker_args.append((collection, worker_queries, i)) + + start_time = time.time() + + with ThreadPoolExecutor(max_workers=self.num_threads) as executor: + results = list(executor.map(self.worker, worker_args)) + + end_time = time.time() + + # Flatten results + all_results = [] + for worker_results in results: + all_results.extend(worker_results) + + total_duration = end_time - start_time + latencies = [r["latency"] for r in all_results] + + return { + "num_threads": self.num_threads, + "total_queries": len(all_results), + "duration": total_duration, + "qps": len(all_results) / total_duration if total_duration > 0 else 0, + "avg_latency": statistics.mean(latencies) if latencies else 0, + "min_latency": min(latencies) if latencies else 0, + "max_latency": max(latencies) if latencies else 0 + } + + benchmark = ConcurrentBenchmark(num_threads=4) + queries = np.random.randn(100, 128).astype(np.float32) + + result = benchmark.run(mock_collection, queries) + + assert result["total_queries"] == 100 + assert result["num_threads"] == 4 + assert result["qps"] > 0 + assert query_counter['count'] == 100 + + +class TestBenchmarkMetrics: + """Test benchmark metric collection and analysis.""" + + def test_latency_distribution(self): + """Test calculating latency distribution metrics.""" + class LatencyAnalyzer: + def __init__(self): + self.latencies = [] + + def add_latency(self, latency): + """Add a latency measurement.""" + self.latencies.append(latency) + + def get_distribution(self): + """Calculate latency distribution statistics.""" + if not self.latencies: + return {} + + sorted_latencies = sorted(self.latencies) + + return { + "count": len(self.latencies), + "mean": statistics.mean(self.latencies), + "median": statistics.median(self.latencies), + "stdev": statistics.stdev(self.latencies) if len(self.latencies) > 1 else 0, + "min": min(self.latencies), + "max": max(self.latencies), + "p50": self._percentile(sorted_latencies, 50), + "p90": self._percentile(sorted_latencies, 90), + "p95": self._percentile(sorted_latencies, 95), + "p99": self._percentile(sorted_latencies, 99), + "p999": self._percentile(sorted_latencies, 99.9) + } + + def _percentile(self, sorted_data, percentile): + """Calculate percentile from sorted data.""" + index = len(sorted_data) * percentile / 100 + lower = int(index) + upper = lower + 1 + + if upper >= len(sorted_data): + return sorted_data[-1] + + weight = index - lower + return sorted_data[lower] * (1 - weight) + sorted_data[upper] * weight + + analyzer = LatencyAnalyzer() + + # Add sample latencies (in milliseconds) + np.random.seed(42) + latencies = np.random.exponential(10, 1000) # Exponential distribution + for latency in latencies: + analyzer.add_latency(latency) + + dist = analyzer.get_distribution() + + assert dist["count"] == 1000 + assert dist["p50"] < dist["p90"] + assert dist["p90"] < dist["p95"] + assert dist["p95"] < dist["p99"] + assert dist["min"] < dist["mean"] < dist["max"] + + def test_recall_metric(self): + """Test calculating recall metrics for search results.""" + class RecallCalculator: + def __init__(self, ground_truth): + self.ground_truth = ground_truth + + def calculate_recall(self, query_id, retrieved_ids, k): + """Calculate recall@k for a query.""" + if query_id not in self.ground_truth: + return None + + true_ids = set(self.ground_truth[query_id][:k]) + retrieved_ids_set = set(retrieved_ids[:k]) + + intersection = true_ids.intersection(retrieved_ids_set) + recall = len(intersection) / len(true_ids) if true_ids else 0 + + return recall + + def calculate_average_recall(self, results, k): + """Calculate average recall@k across multiple queries.""" + recalls = [] + + for query_id, retrieved_ids in results.items(): + recall = self.calculate_recall(query_id, retrieved_ids, k) + if recall is not None: + recalls.append(recall) + + return statistics.mean(recalls) if recalls else 0 + + # Mock ground truth data + ground_truth = { + 0: [1, 2, 3, 4, 5], + 1: [6, 7, 8, 9, 10], + 2: [11, 12, 13, 14, 15] + } + + calculator = RecallCalculator(ground_truth) + + # Test perfect recall + perfect_results = { + 0: [1, 2, 3, 4, 5], + 1: [6, 7, 8, 9, 10], + 2: [11, 12, 13, 14, 15] + } + + avg_recall = calculator.calculate_average_recall(perfect_results, k=5) + assert avg_recall == 1.0 + + # Test partial recall + partial_results = { + 0: [1, 2, 3, 16, 17], # 3/5 correct + 1: [6, 7, 18, 19, 20], # 2/5 correct + 2: [11, 12, 13, 14, 21] # 4/5 correct + } + + avg_recall = calculator.calculate_average_recall(partial_results, k=5) + assert 0.5 < avg_recall < 0.7 # Should be (3+2+4)/15 = 0.6 + + def test_benchmark_summary_generation(self): + """Test generating comprehensive benchmark summary.""" + class BenchmarkSummary: + def __init__(self): + self.metrics = { + "latencies": [], + "throughputs": [], + "errors": 0, + "total_queries": 0 + } + self.start_time = None + self.end_time = None + + def start(self): + """Start benchmark timing.""" + self.start_time = time.time() + + def end(self): + """End benchmark timing.""" + self.end_time = time.time() + + def add_query_result(self, latency, success=True): + """Add a query result.""" + self.metrics["total_queries"] += 1 + + if success: + self.metrics["latencies"].append(latency) + else: + self.metrics["errors"] += 1 + + def add_throughput_sample(self, qps): + """Add a throughput sample.""" + self.metrics["throughputs"].append(qps) + + def generate_summary(self): + """Generate comprehensive benchmark summary.""" + if not self.start_time or not self.end_time: + return None + + duration = self.end_time - self.start_time + latencies = self.metrics["latencies"] + + summary = { + "duration": duration, + "total_queries": self.metrics["total_queries"], + "successful_queries": len(latencies), + "failed_queries": self.metrics["errors"], + "error_rate": self.metrics["errors"] / self.metrics["total_queries"] + if self.metrics["total_queries"] > 0 else 0 + } + + if latencies: + summary.update({ + "latency_mean": statistics.mean(latencies), + "latency_median": statistics.median(latencies), + "latency_min": min(latencies), + "latency_max": max(latencies), + "latency_p95": sorted(latencies)[int(len(latencies) * 0.95)], + "latency_p99": sorted(latencies)[int(len(latencies) * 0.99)] + }) + + if self.metrics["throughputs"]: + summary.update({ + "throughput_mean": statistics.mean(self.metrics["throughputs"]), + "throughput_max": max(self.metrics["throughputs"]), + "throughput_min": min(self.metrics["throughputs"]) + }) + + # Overall QPS + summary["overall_qps"] = self.metrics["total_queries"] / duration if duration > 0 else 0 + + return summary + + summary = BenchmarkSummary() + summary.start() + + # Simulate query results + np.random.seed(42) + for i in range(1000): + latency = np.random.exponential(10) # 10ms average + success = np.random.random() > 0.01 # 99% success rate + summary.add_query_result(latency, success) + + # Add throughput samples + for i in range(10): + summary.add_throughput_sample(100 + np.random.normal(0, 10)) + + time.sleep(0.1) # Simulate benchmark duration + summary.end() + + result = summary.generate_summary() + + assert result["total_queries"] == 1000 + assert result["error_rate"] < 0.02 # Should be around 1% + assert result["latency_p99"] > result["latency_p95"] + assert result["latency_p95"] > result["latency_median"] + + +class TestBenchmarkConfiguration: + """Test benchmark configuration and parameter tuning.""" + + def test_search_parameter_tuning(self): + """Test tuning search parameters for optimal performance.""" + class SearchParameterTuner: + def __init__(self, collection): + self.collection = collection + self.results = [] + + def test_parameters(self, params, test_queries): + """Test a set of search parameters.""" + latencies = [] + + for query in test_queries: + start = time.time() + self.collection.search( + data=[query], + anns_field="embedding", + param=params, + limit=10 + ) + latencies.append(time.time() - start) + + return { + "params": params, + "avg_latency": statistics.mean(latencies), + "p95_latency": sorted(latencies)[int(len(latencies) * 0.95)] + } + + def tune(self, parameter_sets, test_queries): + """Find optimal parameters.""" + for params in parameter_sets: + result = self.test_parameters(params, test_queries) + self.results.append(result) + + # Find best parameters based on latency + best = min(self.results, key=lambda x: x["avg_latency"]) + return best + + mock_collection = Mock() + mock_collection.search.return_value = [[Mock(id=1, distance=0.1)]] + + tuner = SearchParameterTuner(mock_collection) + + # Define parameter sets to test + parameter_sets = [ + {"metric_type": "L2", "params": {"nprobe": 10}}, + {"metric_type": "L2", "params": {"nprobe": 20}}, + {"metric_type": "L2", "params": {"nprobe": 50}}, + ] + + test_queries = np.random.randn(10, 128).astype(np.float32) + + best_params = tuner.tune(parameter_sets, test_queries) + + assert best_params is not None + assert "params" in best_params + assert "avg_latency" in best_params + + def test_workload_generation(self): + """Test generating different query workloads.""" + class WorkloadGenerator: + def __init__(self, dimension, seed=None): + self.dimension = dimension + if seed: + np.random.seed(seed) + + def generate_uniform(self, num_queries): + """Generate uniformly distributed queries.""" + return np.random.uniform(-1, 1, (num_queries, self.dimension)).astype(np.float32) + + def generate_gaussian(self, num_queries, centers=1): + """Generate queries from Gaussian distributions.""" + if centers == 1: + return np.random.randn(num_queries, self.dimension).astype(np.float32) + + # Multiple centers + queries_per_center = num_queries // centers + remainder = num_queries % centers + queries = [] + + for i in range(centers): + center = np.random.randn(self.dimension) * 10 + # Add extra query to first clusters if there's a remainder + extra = 1 if i < remainder else 0 + cluster = np.random.randn(queries_per_center + extra, self.dimension) + center + queries.append(cluster) + + return np.vstack(queries).astype(np.float32) + + def generate_skewed(self, num_queries, hot_ratio=0.2): + """Generate skewed workload with hot and cold queries.""" + num_hot = int(num_queries * hot_ratio) + num_cold = num_queries - num_hot + + # Hot queries - concentrated around a few points + hot_queries = np.random.randn(num_hot, self.dimension) * 0.1 + + # Cold queries - widely distributed + cold_queries = np.random.randn(num_cold, self.dimension) * 10 + + # Mix them + all_queries = np.vstack([hot_queries, cold_queries]) + np.random.shuffle(all_queries) + + return all_queries.astype(np.float32) + + def generate_temporal(self, num_queries, drift_rate=0.01): + """Generate queries with temporal drift.""" + queries = [] + current_center = np.zeros(self.dimension) + + for i in range(num_queries): + # Drift the center + current_center += np.random.randn(self.dimension) * drift_rate + + # Generate query around current center + query = current_center + np.random.randn(self.dimension) + queries.append(query) + + return np.array(queries).astype(np.float32) + + generator = WorkloadGenerator(dimension=128, seed=42) + + # Test uniform workload + uniform = generator.generate_uniform(100) + assert uniform.shape == (100, 128) + assert uniform.min() >= -1.1 # Small tolerance + assert uniform.max() <= 1.1 + + # Test Gaussian workload + gaussian = generator.generate_gaussian(100, centers=3) + assert gaussian.shape == (100, 128) + + # Test skewed workload + skewed = generator.generate_skewed(100, hot_ratio=0.2) + assert skewed.shape == (100, 128) + + # Test temporal workload + temporal = generator.generate_temporal(100, drift_rate=0.01) + assert temporal.shape == (100, 128) + + +class TestBenchmarkOutput: + """Test benchmark result output and reporting.""" + + def test_json_output_format(self, test_data_dir): + """Test outputting benchmark results in JSON format.""" + results = { + "timestamp": "2024-01-01T12:00:00", + "configuration": { + "collection": "test_collection", + "dimension": 1536, + "index_type": "DISKANN", + "num_processes": 4, + "batch_size": 100 + }, + "metrics": { + "total_queries": 10000, + "duration": 60.5, + "qps": 165.29, + "latency_p50": 5.2, + "latency_p95": 12.8, + "latency_p99": 18.3, + "error_rate": 0.001 + }, + "system_info": { + "cpu_count": 8, + "memory_gb": 32, + "platform": "Linux" + } + } + + output_file = test_data_dir / "benchmark_results.json" + + # Save results + with open(output_file, 'w') as f: + json.dump(results, f, indent=2) + + # Verify saved file + with open(output_file, 'r') as f: + loaded = json.load(f) + + assert loaded["metrics"]["qps"] == 165.29 + assert loaded["configuration"]["index_type"] == "DISKANN" + + def test_csv_output_format(self, test_data_dir): + """Test outputting benchmark results in CSV format.""" + import csv + + results = [ + {"timestamp": "2024-01-01T12:00:00", "qps": 150.5, "latency_p95": 12.3}, + {"timestamp": "2024-01-01T12:01:00", "qps": 155.2, "latency_p95": 11.8}, + {"timestamp": "2024-01-01T12:02:00", "qps": 148.9, "latency_p95": 12.7} + ] + + output_file = test_data_dir / "benchmark_results.csv" + + # Save results + with open(output_file, 'w', newline='') as f: + writer = csv.DictWriter(f, fieldnames=["timestamp", "qps", "latency_p95"]) + writer.writeheader() + writer.writerows(results) + + # Verify saved file + with open(output_file, 'r') as f: + reader = csv.DictReader(f) + loaded = list(reader) + + assert len(loaded) == 3 + assert float(loaded[0]["qps"]) == 150.5 + + def test_comparison_report_generation(self): + """Test generating comparison reports between benchmarks.""" + class ComparisonReport: + def __init__(self): + self.benchmarks = {} + + def add_benchmark(self, name, results): + """Add benchmark results.""" + self.benchmarks[name] = results + + def generate_comparison(self): + """Generate comparison report.""" + if len(self.benchmarks) < 2: + return None + + comparison = { + "benchmarks": [], + "best_qps": None, + "best_latency": None + } + + best_qps = 0 + best_latency = float('inf') + + for name, results in self.benchmarks.items(): + benchmark_summary = { + "name": name, + "qps": results.get("qps", 0), + "latency_p95": results.get("latency_p95", 0), + "latency_p99": results.get("latency_p99", 0), + "error_rate": results.get("error_rate", 0) + } + + comparison["benchmarks"].append(benchmark_summary) + + if benchmark_summary["qps"] > best_qps: + best_qps = benchmark_summary["qps"] + comparison["best_qps"] = name + + if benchmark_summary["latency_p95"] < best_latency: + best_latency = benchmark_summary["latency_p95"] + comparison["best_latency"] = name + + # Calculate improvements + if len(self.benchmarks) == 2: + names = list(self.benchmarks.keys()) + baseline = self.benchmarks[names[0]] + comparison_bench = self.benchmarks[names[1]] + + comparison["qps_improvement"] = ( + (comparison_bench["qps"] - baseline["qps"]) / baseline["qps"] * 100 + if baseline.get("qps", 0) > 0 else 0 + ) + + comparison["latency_improvement"] = ( + (baseline["latency_p95"] - comparison_bench["latency_p95"]) / baseline["latency_p95"] * 100 + if baseline.get("latency_p95", 0) > 0 else 0 + ) + + return comparison + + report = ComparisonReport() + + # Add benchmark results + report.add_benchmark("DISKANN", { + "qps": 1500, + "latency_p95": 10.5, + "latency_p99": 15.2, + "error_rate": 0.001 + }) + + report.add_benchmark("HNSW", { + "qps": 1200, + "latency_p95": 8.3, + "latency_p99": 12.1, + "error_rate": 0.002 + }) + + comparison = report.generate_comparison() + + assert comparison["best_qps"] == "DISKANN" + assert comparison["best_latency"] == "HNSW" + assert len(comparison["benchmarks"]) == 2 + assert comparison["qps_improvement"] == -20.0 # HNSW is 20% slower diff --git a/vdb_benchmark/tests/tests/test_vector_generation.py b/vdb_benchmark/tests/tests/test_vector_generation.py new file mode 100755 index 0000000..22cf2be --- /dev/null +++ b/vdb_benchmark/tests/tests/test_vector_generation.py @@ -0,0 +1,369 @@ +""" +Unit tests for vector generation utilities +""" +import pytest +import numpy as np +from unittest.mock import Mock, patch +import h5py +import tempfile +from pathlib import Path + + +class TestVectorGenerationUtilities: + """Test vector generation utility functions.""" + + def test_vector_normalization(self): + """Test different vector normalization methods.""" + class VectorNormalizer: + @staticmethod + def l2_normalize(vectors): + """L2 normalization.""" + norms = np.linalg.norm(vectors, axis=1, keepdims=True) + return vectors / (norms + 1e-10) # Add epsilon to avoid division by zero + + @staticmethod + def l1_normalize(vectors): + """L1 normalization.""" + norms = np.sum(np.abs(vectors), axis=1, keepdims=True) + return vectors / (norms + 1e-10) + + @staticmethod + def max_normalize(vectors): + """Max normalization (scale by maximum absolute value).""" + max_vals = np.max(np.abs(vectors), axis=1, keepdims=True) + return vectors / (max_vals + 1e-10) + + @staticmethod + def standardize(vectors): + """Standardization (zero mean, unit variance).""" + mean = np.mean(vectors, axis=0, keepdims=True) + std = np.std(vectors, axis=0, keepdims=True) + return (vectors - mean) / (std + 1e-10) + + # Test data + vectors = np.random.randn(100, 128).astype(np.float32) + + # Test L2 normalization + l2_norm = VectorNormalizer.l2_normalize(vectors) + norms = np.linalg.norm(l2_norm, axis=1) + np.testing.assert_array_almost_equal(norms, np.ones(100), decimal=5) + + # Test L1 normalization + l1_norm = VectorNormalizer.l1_normalize(vectors) + l1_sums = np.sum(np.abs(l1_norm), axis=1) + np.testing.assert_array_almost_equal(l1_sums, np.ones(100), decimal=5) + + # Test max normalization + max_norm = VectorNormalizer.max_normalize(vectors) + max_vals = np.max(np.abs(max_norm), axis=1) + np.testing.assert_array_almost_equal(max_vals, np.ones(100), decimal=5) + + # Test standardization + standardized = VectorNormalizer.standardize(vectors) + assert abs(np.mean(standardized)) < 0.01 # Mean should be close to 0 + assert abs(np.std(standardized) - 1.0) < 0.1 # Std should be close to 1 + + def test_vector_quantization(self): + """Test vector quantization methods.""" + class VectorQuantizer: + @staticmethod + def scalar_quantize(vectors, bits=8): + """Scalar quantization to specified bit depth.""" + min_val = np.min(vectors) + max_val = np.max(vectors) + + # Scale to [0, 2^bits - 1] + scale = (2 ** bits - 1) / (max_val - min_val) + quantized = np.round((vectors - min_val) * scale).astype(np.uint8 if bits == 8 else np.uint16) + + return quantized, (min_val, max_val, scale) + + @staticmethod + def dequantize(quantized, params): + """Dequantize vectors.""" + min_val, max_val, scale = params + return quantized.astype(np.float32) / scale + min_val + + @staticmethod + def product_quantize(vectors, num_subvectors=8, codebook_size=256): + """Simple product quantization simulation.""" + dimension = vectors.shape[1] + subvector_dim = dimension // num_subvectors + + codes = [] + codebooks = [] + + for i in range(num_subvectors): + start = i * subvector_dim + end = start + subvector_dim + subvectors = vectors[:, start:end] + + # Simulate codebook (in reality would use k-means) + codebook = np.random.randn(codebook_size, subvector_dim).astype(np.float32) + codebooks.append(codebook) + + # Assign codes (find nearest codebook entry) + # Simplified - just random assignment for testing + subvector_codes = np.random.randint(0, codebook_size, len(vectors)) + codes.append(subvector_codes) + + return np.array(codes).T, codebooks + + vectors = np.random.randn(100, 128).astype(np.float32) + + # Test scalar quantization + quantizer = VectorQuantizer() + quantized, params = quantizer.scalar_quantize(vectors, bits=8) + + assert quantized.dtype == np.uint8 + assert quantized.shape == vectors.shape + + # Test reconstruction + reconstructed = quantizer.dequantize(quantized, params) + assert reconstructed.shape == vectors.shape + + # Test product quantization + pq_codes, codebooks = quantizer.product_quantize(vectors, num_subvectors=8) + + assert pq_codes.shape == (100, 8) # 100 vectors, 8 subvectors + assert len(codebooks) == 8 + + def test_synthetic_dataset_generation(self): + """Test generating synthetic datasets with specific properties.""" + class SyntheticDataGenerator: + @staticmethod + def generate_clustered(num_vectors, dimension, num_clusters=10, cluster_std=0.1): + """Generate clustered vectors.""" + vectors_per_cluster = num_vectors // num_clusters + vectors = [] + labels = [] + + # Generate cluster centers + centers = np.random.randn(num_clusters, dimension) * 10 + + for i in range(num_clusters): + # Generate vectors around center + cluster_vectors = centers[i] + np.random.randn(vectors_per_cluster, dimension) * cluster_std + vectors.append(cluster_vectors) + labels.extend([i] * vectors_per_cluster) + + # Handle remaining vectors + remaining = num_vectors - (vectors_per_cluster * num_clusters) + if remaining > 0: + cluster_idx = np.random.randint(0, num_clusters) + extra_vectors = centers[cluster_idx] + np.random.randn(remaining, dimension) * cluster_std + vectors.append(extra_vectors) + labels.extend([cluster_idx] * remaining) + + return np.vstack(vectors).astype(np.float32), np.array(labels) + + @staticmethod + def generate_sparse(num_vectors, dimension, sparsity=0.9): + """Generate sparse vectors.""" + vectors = np.random.randn(num_vectors, dimension).astype(np.float32) + + # Create mask for sparsity + mask = np.random.random((num_vectors, dimension)) < sparsity + vectors[mask] = 0 + + return vectors + + @staticmethod + def generate_correlated(num_vectors, dimension, correlation=0.8): + """Generate vectors with correlated dimensions.""" + # Create correlation matrix + base = np.random.randn(num_vectors, 1) + + vectors = [] + for i in range(dimension): + if i == 0: + vectors.append(base.flatten()) + else: + # Mix with random noise based on correlation + noise = np.random.randn(num_vectors) + correlated = correlation * base.flatten() + (1 - correlation) * noise + vectors.append(correlated) + + return np.array(vectors).T.astype(np.float32) + + generator = SyntheticDataGenerator() + + # Test clustered generation + vectors, labels = generator.generate_clustered(1000, 128, num_clusters=10) + assert vectors.shape == (1000, 128) + assert len(labels) == 1000 + assert len(np.unique(labels)) == 10 + + # Test sparse generation + sparse_vectors = generator.generate_sparse(100, 256, sparsity=0.9) + assert sparse_vectors.shape == (100, 256) + sparsity_ratio = np.sum(sparse_vectors == 0) / sparse_vectors.size + assert 0.85 < sparsity_ratio < 0.95 # Should be approximately 0.9 + + # Test correlated generation + correlated = generator.generate_correlated(100, 64, correlation=0.8) + assert correlated.shape == (100, 64) + + def test_vector_io_operations(self, test_data_dir): + """Test saving and loading vectors in different formats.""" + class VectorIO: + @staticmethod + def save_npy(vectors, filepath): + """Save vectors as NPY file.""" + np.save(filepath, vectors) + + @staticmethod + def load_npy(filepath): + """Load vectors from NPY file.""" + return np.load(filepath) + + @staticmethod + def save_hdf5(vectors, filepath, dataset_name="vectors"): + """Save vectors as HDF5 file.""" + with h5py.File(filepath, 'w') as f: + f.create_dataset(dataset_name, data=vectors, compression="gzip") + + @staticmethod + def load_hdf5(filepath, dataset_name="vectors"): + """Load vectors from HDF5 file.""" + with h5py.File(filepath, 'r') as f: + return f[dataset_name][:] + + @staticmethod + def save_binary(vectors, filepath): + """Save vectors as binary file.""" + vectors.tofile(filepath) + + @staticmethod + def load_binary(filepath, dtype=np.float32, shape=None): + """Load vectors from binary file.""" + vectors = np.fromfile(filepath, dtype=dtype) + if shape: + vectors = vectors.reshape(shape) + return vectors + + @staticmethod + def save_text(vectors, filepath): + """Save vectors as text file.""" + np.savetxt(filepath, vectors, fmt='%.6f') + + @staticmethod + def load_text(filepath): + """Load vectors from text file.""" + return np.loadtxt(filepath, dtype=np.float32) + + io_handler = VectorIO() + vectors = np.random.randn(100, 128).astype(np.float32) + + # Test NPY format + npy_path = test_data_dir / "vectors.npy" + io_handler.save_npy(vectors, npy_path) + loaded_npy = io_handler.load_npy(npy_path) + np.testing.assert_array_almost_equal(vectors, loaded_npy) + + # Test HDF5 format + hdf5_path = test_data_dir / "vectors.h5" + io_handler.save_hdf5(vectors, hdf5_path) + loaded_hdf5 = io_handler.load_hdf5(hdf5_path) + np.testing.assert_array_almost_equal(vectors, loaded_hdf5) + + # Test binary format + bin_path = test_data_dir / "vectors.bin" + io_handler.save_binary(vectors, bin_path) + loaded_bin = io_handler.load_binary(bin_path, shape=(100, 128)) + np.testing.assert_array_almost_equal(vectors, loaded_bin) + + # Test text format (smaller dataset for text) + small_vectors = vectors[:10] + txt_path = test_data_dir / "vectors.txt" + io_handler.save_text(small_vectors, txt_path) + loaded_txt = io_handler.load_text(txt_path) + np.testing.assert_array_almost_equal(small_vectors, loaded_txt, decimal=5) + + +class TestIndexConfiguration: + """Test index-specific configurations and parameters.""" + + def test_diskann_parameter_validation(self): + """Test DiskANN index parameter validation.""" + class DiskANNConfig: + VALID_METRICS = ["L2", "IP", "COSINE"] + + @staticmethod + def validate_params(params): + """Validate DiskANN parameters.""" + errors = [] + + # Check metric type + if params.get("metric_type") not in DiskANNConfig.VALID_METRICS: + errors.append(f"Invalid metric_type: {params.get('metric_type')}") + + # Check max_degree + max_degree = params.get("max_degree", 64) + if not (1 <= max_degree <= 128): + errors.append(f"max_degree must be between 1 and 128, got {max_degree}") + + # Check search_list_size + search_list = params.get("search_list_size", 200) + if not (100 <= search_list <= 1000): + errors.append(f"search_list_size must be between 100 and 1000, got {search_list}") + + # Check PQ parameters if present + if "pq_code_budget_gb" in params: + budget = params["pq_code_budget_gb"] + if budget <= 0: + errors.append(f"pq_code_budget_gb must be positive, got {budget}") + + return len(errors) == 0, errors + + @staticmethod + def get_default_params(num_vectors, dimension): + """Get default parameters based on dataset size.""" + if num_vectors < 1000000: + return { + "metric_type": "L2", + "max_degree": 32, + "search_list_size": 100 + } + elif num_vectors < 10000000: + return { + "metric_type": "L2", + "max_degree": 64, + "search_list_size": 200 + } + else: + return { + "metric_type": "L2", + "max_degree": 64, + "search_list_size": 300, + "pq_code_budget_gb": 0.2 + } + + # Test valid parameters + valid_params = { + "metric_type": "L2", + "max_degree": 64, + "search_list_size": 200 + } + + is_valid, errors = DiskANNConfig.validate_params(valid_params) + assert is_valid is True + assert len(errors) == 0 + + # Test invalid parameters + invalid_params = { + "metric_type": "INVALID", + "max_degree": 200, + "search_list_size": 50 + } + + is_valid, errors = DiskANNConfig.validate_params(invalid_params) + assert is_valid is False + assert len(errors) == 3 + + # Test default parameter generation + small_defaults = DiskANNConfig.get_default_params(100000, 128) + assert small_defaults["max_degree"] == 32 + + large_defaults = DiskANNConfig.get_default_params(20000000, 1536) + assert "pq_code_budget_gb" in large_defaults diff --git a/vdb_benchmark/tests/tests/verify_fixes.py b/vdb_benchmark/tests/tests/verify_fixes.py new file mode 100755 index 0000000..ec482a3 --- /dev/null +++ b/vdb_benchmark/tests/tests/verify_fixes.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +""" +Test Suite Verification Script +Verifies that all test fixes have been applied correctly +""" +import subprocess +import sys +import json +from pathlib import Path + +def run_single_test(test_path): + """Run a single test and return result.""" + result = subprocess.run( + [sys.executable, "-m", "pytest", test_path, "-v", "--tb=short"], + capture_output=True, + text=True + ) + return result.returncode == 0, result.stdout, result.stderr + +def main(): + """Run all previously failing tests to verify fixes.""" + + # List of previously failing tests + failing_tests = [ + "tests/test_compact_and_watch.py::TestMonitoring::test_collection_stats_monitoring", + "tests/test_config.py::TestConfigurationLoader::test_config_environment_variable_override", + "tests/test_database_connection.py::TestConnectionResilience::test_automatic_reconnection", + "tests/test_index_management.py::TestIndexManagement::test_index_status_check", + "tests/test_load_vdb.py::TestVectorLoading::test_insertion_with_error_handling", + "tests/test_load_vdb.py::TestVectorLoading::test_insertion_rate_monitoring", + "tests/test_simple_bench.py::TestBenchmarkConfiguration::test_workload_generation" + ] + + print("=" * 60) + print("VDB-Bench Test Suite - Verification of Fixes") + print("=" * 60) + print() + + results = [] + + for test in failing_tests: + print(f"Testing: {test}") + passed, stdout, stderr = run_single_test(test) + + results.append({ + "test": test, + "passed": passed, + "output": stdout if not passed else "" + }) + + if passed: + print(" ✅ PASSED") + else: + print(" ❌ FAILED") + print(f" Error: {stderr[:200]}") + print() + + # Summary + print("=" * 60) + print("Summary") + print("=" * 60) + + passed_count = sum(1 for r in results if r["passed"]) + failed_count = len(results) - passed_count + + print(f"Total Tests: {len(results)}") + print(f"Passed: {passed_count}") + print(f"Failed: {failed_count}") + + if failed_count == 0: + print("\n✅ All previously failing tests now pass!") + return 0 + else: + print("\n❌ Some tests are still failing. Please review the fixes.") + for result in results: + if not result["passed"]: + print(f" - {result['test']}") + return 1 + +if __name__ == "__main__": + sys.exit(main()) diff --git a/vdb_benchmark/tests/utils/__init__.py b/vdb_benchmark/tests/utils/__init__.py new file mode 100755 index 0000000..df966d6 --- /dev/null +++ b/vdb_benchmark/tests/utils/__init__.py @@ -0,0 +1,47 @@ +""" +Test utilities package for vdb-bench +""" + +from .test_helpers import ( + TestDataGenerator, + MockMilvusCollection, + PerformanceSimulator, + temporary_directory, + mock_time_progression, + create_test_yaml_config, + create_test_json_results, + assert_performance_within_bounds, + calculate_recall, + calculate_precision, + generate_random_string, + BenchmarkResultValidator +) + +from .mock_data import ( + MockDataGenerator, + BenchmarkDatasetGenerator, + QueryWorkloadGenerator, + MetricDataGenerator +) + +__all__ = [ + # Test helpers + 'TestDataGenerator', + 'MockMilvusCollection', + 'PerformanceSimulator', + 'temporary_directory', + 'mock_time_progression', + 'create_test_yaml_config', + 'create_test_json_results', + 'assert_performance_within_bounds', + 'calculate_recall', + 'calculate_precision', + 'generate_random_string', + 'BenchmarkResultValidator', + + # Mock data + 'MockDataGenerator', + 'BenchmarkDatasetGenerator', + 'QueryWorkloadGenerator', + 'MetricDataGenerator' +] diff --git a/vdb_benchmark/tests/utils/mock_data.py b/vdb_benchmark/tests/utils/mock_data.py new file mode 100755 index 0000000..da60e37 --- /dev/null +++ b/vdb_benchmark/tests/utils/mock_data.py @@ -0,0 +1,415 @@ +""" +Mock data generators for vdb-bench testing +""" +import numpy as np +import random +from typing import List, Dict, Any, Tuple, Optional +from datetime import datetime, timedelta +import json + + +class MockDataGenerator: + """Generate various types of mock data for testing.""" + + def __init__(self, seed: Optional[int] = None): + """Initialize with optional random seed for reproducibility.""" + if seed is not None: + random.seed(seed) + np.random.seed(seed) + + @staticmethod + def generate_sift_like_vectors(num_vectors: int, dimension: int = 128) -> np.ndarray: + """Generate SIFT-like vectors (similar to common benchmark datasets).""" + # SIFT vectors are typically L2-normalized and have specific distribution + vectors = np.random.randn(num_vectors, dimension).astype(np.float32) + + # Add some structure (make some dimensions more important) + important_dims = random.sample(range(dimension), k=dimension // 4) + vectors[:, important_dims] *= 3 + + # L2 normalize + norms = np.linalg.norm(vectors, axis=1, keepdims=True) + vectors = vectors / (norms + 1e-10) + + # Scale to typical SIFT range + vectors = vectors * 512 + + return vectors.astype(np.float32) + + @staticmethod + def generate_deep_learning_embeddings(num_vectors: int, + dimension: int = 768, + model_type: str = "bert") -> np.ndarray: + """Generate embeddings similar to deep learning models.""" + if model_type == "bert": + # BERT-like embeddings (768-dimensional) + vectors = np.random.randn(num_vectors, dimension).astype(np.float32) + # BERT embeddings typically have values in [-2, 2] range + vectors = np.clip(vectors * 0.5, -2, 2) + + elif model_type == "resnet": + # ResNet-like features (2048-dimensional typical) + vectors = np.random.randn(num_vectors, dimension).astype(np.float32) + # Apply ReLU-like sparsity + vectors[vectors < 0] = 0 + # L2 normalize + norms = np.linalg.norm(vectors, axis=1, keepdims=True) + vectors = vectors / (norms + 1e-10) + + elif model_type == "clip": + # CLIP-like embeddings (512-dimensional, normalized) + vectors = np.random.randn(num_vectors, dimension).astype(np.float32) + # Normalize to unit sphere + norms = np.linalg.norm(vectors, axis=1, keepdims=True) + vectors = vectors / (norms + 1e-10) + + else: + # Generic embeddings + vectors = np.random.randn(num_vectors, dimension).astype(np.float32) + + return vectors + + @staticmethod + def generate_time_series_vectors(num_vectors: int, + dimension: int = 100, + num_series: int = 10) -> Tuple[np.ndarray, List[int]]: + """Generate time series data as vectors with series labels.""" + vectors = [] + labels = [] + + for series_id in range(num_series): + # Generate base pattern for this series + base_pattern = np.sin(np.linspace(0, 4 * np.pi, dimension)) + base_pattern += np.random.randn(dimension) * 0.1 # Add noise + + # Generate variations of the pattern + series_vectors = num_vectors // num_series + for _ in range(series_vectors): + # Add temporal drift and noise + variation = base_pattern + np.random.randn(dimension) * 0.3 + variation += np.random.randn() * 0.1 # Global shift + + vectors.append(variation) + labels.append(series_id) + + # Handle remaining vectors + remaining = num_vectors - len(vectors) + for _ in range(remaining): + vectors.append(vectors[-1] + np.random.randn(dimension) * 0.1) + labels.append(labels[-1]) + + return np.array(vectors).astype(np.float32), labels + + @staticmethod + def generate_categorical_embeddings(num_vectors: int, + num_categories: int = 100, + dimension: int = 64) -> Tuple[np.ndarray, List[str]]: + """Generate embeddings for categorical data.""" + # Create embedding for each category + category_embeddings = np.random.randn(num_categories, dimension).astype(np.float32) + + # Normalize category embeddings + norms = np.linalg.norm(category_embeddings, axis=1, keepdims=True) + category_embeddings = category_embeddings / (norms + 1e-10) + + vectors = [] + categories = [] + + # Generate vectors by sampling categories + for _ in range(num_vectors): + cat_idx = random.randint(0, num_categories - 1) + + # Add small noise to category embedding + vector = category_embeddings[cat_idx] + np.random.randn(dimension) * 0.05 + + vectors.append(vector) + categories.append(f"category_{cat_idx}") + + return np.array(vectors).astype(np.float32), categories + + @staticmethod + def generate_multimodal_vectors(num_vectors: int, + text_dim: int = 768, + image_dim: int = 2048) -> Dict[str, np.ndarray]: + """Generate multimodal vectors (text + image embeddings).""" + # Generate text embeddings (BERT-like) + text_vectors = np.random.randn(num_vectors, text_dim).astype(np.float32) + text_vectors = np.clip(text_vectors * 0.5, -2, 2) + + # Generate image embeddings (ResNet-like) + image_vectors = np.random.randn(num_vectors, image_dim).astype(np.float32) + image_vectors[image_vectors < 0] = 0 # ReLU + norms = np.linalg.norm(image_vectors, axis=1, keepdims=True) + image_vectors = image_vectors / (norms + 1e-10) + + # Combined embeddings (concatenated and projected) + combined_dim = 512 + projection_matrix = np.random.randn(text_dim + image_dim, combined_dim).astype(np.float32) + projection_matrix /= np.sqrt(text_dim + image_dim) # Xavier initialization + + concatenated = np.hstack([text_vectors, image_vectors]) + combined_vectors = np.dot(concatenated, projection_matrix) + + # Normalize combined vectors + norms = np.linalg.norm(combined_vectors, axis=1, keepdims=True) + combined_vectors = combined_vectors / (norms + 1e-10) + + return { + "text": text_vectors, + "image": image_vectors, + "combined": combined_vectors + } + + +class BenchmarkDatasetGenerator: + """Generate datasets similar to common benchmarks.""" + + @staticmethod + def generate_ann_benchmark_dataset(dataset_type: str = "random", + num_train: int = 100000, + num_test: int = 10000, + dimension: int = 128, + num_neighbors: int = 100) -> Dict[str, Any]: + """Generate dataset similar to ANN-Benchmarks format.""" + + if dataset_type == "random": + train_vectors = np.random.randn(num_train, dimension).astype(np.float32) + test_vectors = np.random.randn(num_test, dimension).astype(np.float32) + + elif dataset_type == "clustered": + train_vectors = [] + num_clusters = 100 + vectors_per_cluster = num_train // num_clusters + + for _ in range(num_clusters): + center = np.random.randn(dimension) * 10 + cluster = center + np.random.randn(vectors_per_cluster, dimension) + train_vectors.append(cluster) + + train_vectors = np.vstack(train_vectors).astype(np.float32) + + # Test vectors from same distribution + test_vectors = [] + test_per_cluster = num_test // num_clusters + + for _ in range(num_clusters): + center = np.random.randn(dimension) * 10 + cluster = center + np.random.randn(test_per_cluster, dimension) + test_vectors.append(cluster) + + test_vectors = np.vstack(test_vectors).astype(np.float32) + + else: + raise ValueError(f"Unknown dataset type: {dataset_type}") + + # Generate ground truth (simplified - random for now) + ground_truth = np.random.randint(0, num_train, + (num_test, num_neighbors)) + + # Calculate distances for ground truth (simplified) + distances = np.random.random((num_test, num_neighbors)).astype(np.float32) + distances.sort(axis=1) # Ensure sorted by distance + + return { + "train": train_vectors, + "test": test_vectors, + "neighbors": ground_truth, + "distances": distances, + "dimension": dimension, + "metric": "euclidean" + } + + @staticmethod + def generate_streaming_dataset(initial_size: int = 10000, + dimension: int = 128, + stream_rate: int = 100, + drift_rate: float = 0.01) -> Dict[str, Any]: + """Generate dataset that simulates streaming/incremental scenarios.""" + # Initial dataset + initial_vectors = np.random.randn(initial_size, dimension).astype(np.float32) + + # Streaming batches with concept drift + stream_batches = [] + current_center = np.zeros(dimension) + + for batch_id in range(10): # 10 batches + # Drift the distribution center + current_center += np.random.randn(dimension) * drift_rate + + # Generate batch around drifted center + batch = current_center + np.random.randn(stream_rate, dimension) + stream_batches.append(batch.astype(np.float32)) + + return { + "initial": initial_vectors, + "stream_batches": stream_batches, + "dimension": dimension, + "stream_rate": stream_rate, + "drift_rate": drift_rate + } + + +class QueryWorkloadGenerator: + """Generate different types of query workloads.""" + + @staticmethod + def generate_uniform_workload(num_queries: int, + dimension: int, + seed: Optional[int] = None) -> np.ndarray: + """Generate uniformly distributed queries.""" + if seed: + np.random.seed(seed) + + return np.random.uniform(-1, 1, (num_queries, dimension)).astype(np.float32) + + @staticmethod + def generate_hotspot_workload(num_queries: int, + dimension: int, + num_hotspots: int = 5, + hotspot_ratio: float = 0.8) -> np.ndarray: + """Generate workload with hotspots (skewed distribution).""" + queries = [] + + # Generate hotspot centers + hotspots = np.random.randn(num_hotspots, dimension) * 10 + + num_hot_queries = int(num_queries * hotspot_ratio) + num_cold_queries = num_queries - num_hot_queries + + # Hot queries - concentrated around hotspots + for _ in range(num_hot_queries): + hotspot_idx = random.randint(0, num_hotspots - 1) + query = hotspots[hotspot_idx] + np.random.randn(dimension) * 0.1 + queries.append(query) + + # Cold queries - random distribution + cold_queries = np.random.randn(num_cold_queries, dimension) * 5 + queries.extend(cold_queries) + + # Shuffle to mix hot and cold queries + queries = np.array(queries) + np.random.shuffle(queries) + + return queries.astype(np.float32) + + @staticmethod + def generate_temporal_workload(num_queries: int, + dimension: int, + time_windows: int = 10) -> List[np.ndarray]: + """Generate workload that changes over time.""" + queries_per_window = num_queries // time_windows + workload_windows = [] + + # Start with initial distribution center + current_center = np.zeros(dimension) + + for window in range(time_windows): + # Drift the center over time + drift = np.random.randn(dimension) * 0.5 + current_center += drift + + # Generate queries for this time window + window_queries = current_center + np.random.randn(queries_per_window, dimension) + workload_windows.append(window_queries.astype(np.float32)) + + return workload_windows + + @staticmethod + def generate_mixed_workload(num_queries: int, + dimension: int) -> Dict[str, np.ndarray]: + """Generate mixed workload with different query types.""" + workload = {} + + # Point queries (exact vectors) + num_point = num_queries // 4 + workload["point"] = np.random.randn(num_point, dimension).astype(np.float32) + + # Range queries (represented as center + radius) + num_range = num_queries // 4 + range_centers = np.random.randn(num_range, dimension).astype(np.float32) + range_radii = np.random.uniform(0.1, 2.0, num_range).astype(np.float32) + workload["range"] = {"centers": range_centers, "radii": range_radii} + + # KNN queries (standard similarity search) + num_knn = num_queries // 4 + workload["knn"] = np.random.randn(num_knn, dimension).astype(np.float32) + + # Filtered queries (queries with metadata filters) + num_filtered = num_queries - num_point - num_range - num_knn + filtered_queries = np.random.randn(num_filtered, dimension).astype(np.float32) + filters = [{"category": random.choice(["A", "B", "C"])} for _ in range(num_filtered)] + workload["filtered"] = {"queries": filtered_queries, "filters": filters} + + return workload + + +class MetricDataGenerator: + """Generate realistic metric data for testing.""" + + @staticmethod + def generate_latency_distribution(num_samples: int = 1000, + distribution: str = "lognormal", + mean: float = 10, + std: float = 5) -> np.ndarray: + """Generate realistic latency distribution.""" + if distribution == "lognormal": + # Log-normal distribution (common for latencies) + log_mean = np.log(mean / np.sqrt(1 + (std / mean) ** 2)) + log_std = np.sqrt(np.log(1 + (std / mean) ** 2)) + latencies = np.random.lognormal(log_mean, log_std, num_samples) + + elif distribution == "exponential": + # Exponential distribution + latencies = np.random.exponential(mean, num_samples) + + elif distribution == "gamma": + # Gamma distribution + shape = (mean / std) ** 2 + scale = std ** 2 / mean + latencies = np.random.gamma(shape, scale, num_samples) + + else: + # Normal distribution (less realistic for latencies) + latencies = np.random.normal(mean, std, num_samples) + latencies = np.maximum(latencies, 0.1) # Ensure positive + + return latencies.astype(np.float32) + + @staticmethod + def generate_throughput_series(duration: int = 3600, # 1 hour in seconds + base_qps: float = 1000, + pattern: str = "steady") -> List[Tuple[float, float]]: + """Generate time series of throughput measurements.""" + series = [] + + if pattern == "steady": + for t in range(duration): + qps = base_qps + np.random.normal(0, base_qps * 0.05) + series.append((t, max(0, qps))) + + elif pattern == "diurnal": + # Simulate daily pattern + for t in range(duration): + # Use sine wave for daily pattern + hour = (t / 3600) % 24 + multiplier = 0.5 + 0.5 * np.sin(2 * np.pi * (hour - 6) / 24) + qps = base_qps * multiplier + np.random.normal(0, base_qps * 0.05) + series.append((t, max(0, qps))) + + elif pattern == "spike": + # Occasional spikes + for t in range(duration): + if random.random() < 0.01: # 1% chance of spike + qps = base_qps * random.uniform(2, 5) + else: + qps = base_qps + np.random.normal(0, base_qps * 0.05) + series.append((t, max(0, qps))) + + elif pattern == "degrading": + # Performance degradation over time + for t in range(duration): + degradation = 1 - (t / duration) * 0.5 # 50% degradation + qps = base_qps * degradation + np.random.normal(0, base_qps * 0.05) + series.append((t, max(0, qps))) + + return series diff --git a/vdb_benchmark/tests/utils/test_helpers.py b/vdb_benchmark/tests/utils/test_helpers.py new file mode 100755 index 0000000..1721ba9 --- /dev/null +++ b/vdb_benchmark/tests/utils/test_helpers.py @@ -0,0 +1,458 @@ +""" +Test helper utilities for vdb-bench tests +""" +import numpy as np +import time +import json +import yaml +from pathlib import Path +from typing import Dict, Any, List, Optional, Tuple +from unittest.mock import Mock, MagicMock +import random +import string +from contextlib import contextmanager +import tempfile +import shutil + + +class TestDataGenerator: + """Generate test data for various scenarios.""" + + @staticmethod + def generate_vectors(num_vectors: int, dimension: int, + distribution: str = "normal", + seed: Optional[int] = None) -> np.ndarray: + """Generate test vectors with specified distribution.""" + if seed is not None: + np.random.seed(seed) + + if distribution == "normal": + return np.random.randn(num_vectors, dimension).astype(np.float32) + elif distribution == "uniform": + return np.random.uniform(-1, 1, (num_vectors, dimension)).astype(np.float32) + elif distribution == "sparse": + vectors = np.random.randn(num_vectors, dimension).astype(np.float32) + mask = np.random.random((num_vectors, dimension)) < 0.9 + vectors[mask] = 0 + return vectors + elif distribution == "clustered": + vectors = [] + clusters = 10 + vectors_per_cluster = num_vectors // clusters + + for _ in range(clusters): + center = np.random.randn(dimension) * 10 + cluster_vectors = center + np.random.randn(vectors_per_cluster, dimension) * 0.5 + vectors.append(cluster_vectors) + + return np.vstack(vectors).astype(np.float32) + else: + raise ValueError(f"Unknown distribution: {distribution}") + + @staticmethod + def generate_ids(num_ids: int, start: int = 0) -> List[int]: + """Generate sequential IDs.""" + return list(range(start, start + num_ids)) + + @staticmethod + def generate_metadata(num_items: int) -> List[Dict[str, Any]]: + """Generate random metadata for vectors.""" + metadata = [] + + for i in range(num_items): + metadata.append({ + "id": i, + "category": random.choice(["A", "B", "C", "D"]), + "timestamp": time.time() + i, + "score": random.random(), + "tags": random.sample(["tag1", "tag2", "tag3", "tag4", "tag5"], + k=random.randint(1, 3)) + }) + + return metadata + + @staticmethod + def generate_ground_truth(num_queries: int, num_vectors: int, + top_k: int = 100) -> Dict[int, List[int]]: + """Generate ground truth for recall calculation.""" + ground_truth = {} + + for query_id in range(num_queries): + # Generate random ground truth IDs + true_ids = random.sample(range(num_vectors), + min(top_k, num_vectors)) + ground_truth[query_id] = true_ids + + return ground_truth + + @staticmethod + def generate_config(collection_name: str = "test_collection") -> Dict[str, Any]: + """Generate test configuration.""" + return { + "database": { + "host": "localhost", + "port": 19530, + "database": "default", + "timeout": 30 + }, + "dataset": { + "collection_name": collection_name, + "num_vectors": 10000, + "dimension": 128, + "distribution": "uniform", + "batch_size": 1000, + "num_shards": 2 + }, + "index": { + "index_type": "HNSW", + "metric_type": "L2", + "params": { + "M": 16, + "efConstruction": 200 + } + }, + "benchmark": { + "num_queries": 1000, + "top_k": 10, + "num_processes": 4, + "runtime": 60 + } + } + + +class MockMilvusCollection: + """Advanced mock Milvus collection for testing.""" + + def __init__(self, name: str, dimension: int = 128): + self.name = name + self.dimension = dimension + self.vectors = [] + self.ids = [] + self.num_entities = 0 + self.index = None + self.is_loaded = False + self.partitions = [] + self.schema = Mock() + self.description = f"Mock collection {name}" + + # Index-related attributes + self.index_progress = 0 + self.index_state = "NotExist" + self.index_params = None + + # Compaction-related + self.compaction_id = None + self.compaction_state = "Idle" + + # Search behavior + self.search_latency = 0.01 # Default 10ms + self.search_results = None + + def insert(self, data: List) -> Mock: + """Mock insert operation.""" + vectors = data[0] if isinstance(data[0], (list, np.ndarray)) else data + num_new = len(vectors) if hasattr(vectors, '__len__') else 1 + + self.vectors.extend(vectors) + new_ids = list(range(self.num_entities, self.num_entities + num_new)) + self.ids.extend(new_ids) + self.num_entities += num_new + + result = Mock() + result.primary_keys = new_ids + result.insert_count = num_new + + return result + + def search(self, data: List, anns_field: str, param: Dict, + limit: int = 10, **kwargs) -> List: + """Mock search operation.""" + time.sleep(self.search_latency) # Simulate latency + + if self.search_results: + return self.search_results + + # Generate mock results + results = [] + for query in data: + query_results = [] + for i in range(min(limit, 10)): + result = Mock() + result.id = random.randint(0, max(self.num_entities - 1, 0)) + result.distance = random.random() + query_results.append(result) + results.append(query_results) + + return results + + def create_index(self, field_name: str, index_params: Dict) -> bool: + """Mock index creation.""" + self.index_params = index_params + self.index_state = "InProgress" + self.index_progress = 0 + + # Simulate index building + self.index = Mock() + self.index.params = index_params + self.index.field_name = field_name + + return True + + def drop_index(self, field_name: str) -> None: + """Mock index dropping.""" + self.index = None + self.index_state = "NotExist" + self.index_progress = 0 + self.index_params = None + + def load(self) -> None: + """Mock collection loading.""" + self.is_loaded = True + + def release(self) -> None: + """Mock collection release.""" + self.is_loaded = False + + def flush(self) -> None: + """Mock flush operation.""" + pass # Simulate successful flush + + def compact(self) -> int: + """Mock compaction operation.""" + self.compaction_id = random.randint(1000, 9999) + self.compaction_state = "Executing" + return self.compaction_id + + def get_compaction_state(self, compaction_id: int) -> str: + """Mock getting compaction state.""" + return self.compaction_state + + def drop(self) -> None: + """Mock collection drop.""" + self.vectors = [] + self.ids = [] + self.num_entities = 0 + self.index = None + + def create_partition(self, partition_name: str) -> None: + """Mock partition creation.""" + if partition_name not in self.partitions: + self.partitions.append(partition_name) + + def has_partition(self, partition_name: str) -> bool: + """Check if partition exists.""" + return partition_name in self.partitions + + def get_stats(self) -> Dict[str, Any]: + """Get collection statistics.""" + return { + "row_count": self.num_entities, + "partitions": len(self.partitions), + "index_state": self.index_state, + "loaded": self.is_loaded + } + + +class PerformanceSimulator: + """Simulate performance metrics for testing.""" + + def __init__(self): + self.base_latency = 10 # Base latency in ms + self.base_qps = 1000 + self.variation = 0.2 # 20% variation + + def simulate_latency(self, num_samples: int = 100) -> List[float]: + """Generate simulated latency values.""" + latencies = [] + + for _ in range(num_samples): + # Add random variation + variation = random.uniform(1 - self.variation, 1 + self.variation) + latency = self.base_latency * variation + + # Occasionally add outliers + if random.random() < 0.05: # 5% outliers + latency *= random.uniform(2, 5) + + latencies.append(latency) + + return latencies + + def simulate_throughput(self, duration: int = 60) -> List[Tuple[float, float]]: + """Generate simulated throughput over time.""" + throughput_data = [] + current_time = 0 + + while current_time < duration: + # Simulate varying QPS + variation = random.uniform(1 - self.variation, 1 + self.variation) + qps = self.base_qps * variation + + # Occasionally simulate load spikes or drops + if random.random() < 0.1: # 10% chance of anomaly + if random.random() < 0.5: + qps *= 0.5 # Drop + else: + qps *= 1.5 # Spike + + throughput_data.append((current_time, qps)) + current_time += 1 + + return throughput_data + + def simulate_resource_usage(self, duration: int = 60) -> Dict[str, List[Tuple[float, float]]]: + """Simulate CPU and memory usage over time.""" + cpu_usage = [] + memory_usage = [] + + base_cpu = 50 + base_memory = 60 + + for t in range(duration): + # CPU usage + cpu = base_cpu + random.uniform(-10, 20) + cpu = max(0, min(100, cpu)) # Clamp to 0-100 + cpu_usage.append((t, cpu)) + + # Memory usage (more stable) + memory = base_memory + random.uniform(-5, 10) + memory = max(0, min(100, memory)) + memory_usage.append((t, memory)) + + # Gradually increase if simulating memory leak + if random.random() < 0.1: + base_memory += 0.5 + + return { + "cpu": cpu_usage, + "memory": memory_usage + } + + +@contextmanager +def temporary_directory(): + """Context manager for temporary directory.""" + temp_dir = tempfile.mkdtemp() + try: + yield Path(temp_dir) + finally: + shutil.rmtree(temp_dir, ignore_errors=True) + + +@contextmanager +def mock_time_progression(increments: List[float]): + """Mock time.time() with controlled progression.""" + time_values = [] + current = 0 + + for increment in increments: + current += increment + time_values.append(current) + + with patch('time.time', side_effect=time_values): + yield + + +def create_test_yaml_config(path: Path, config: Dict[str, Any]) -> None: + """Create a YAML configuration file for testing.""" + with open(path, 'w') as f: + yaml.dump(config, f, default_flow_style=False) + + +def create_test_json_results(path: Path, results: Dict[str, Any]) -> None: + """Create a JSON results file for testing.""" + with open(path, 'w') as f: + json.dump(results, f, indent=2) + + +def assert_performance_within_bounds(actual: float, expected: float, + tolerance: float = 0.1) -> None: + """Assert that performance metric is within expected bounds.""" + lower_bound = expected * (1 - tolerance) + upper_bound = expected * (1 + tolerance) + + assert lower_bound <= actual <= upper_bound, \ + f"Performance {actual} not within {tolerance*100}% of expected {expected}" + + +def calculate_recall(retrieved: List[int], relevant: List[int], k: int) -> float: + """Calculate recall@k metric.""" + retrieved_k = set(retrieved[:k]) + relevant_k = set(relevant[:k]) + + if not relevant_k: + return 0.0 + + intersection = retrieved_k.intersection(relevant_k) + return len(intersection) / len(relevant_k) + + +def calculate_precision(retrieved: List[int], relevant: List[int], k: int) -> float: + """Calculate precision@k metric.""" + retrieved_k = set(retrieved[:k]) + relevant_set = set(relevant) + + if not retrieved_k: + return 0.0 + + intersection = retrieved_k.intersection(relevant_set) + return len(intersection) / len(retrieved_k) + + +def generate_random_string(length: int = 10) -> str: + """Generate random string for testing.""" + return ''.join(random.choices(string.ascii_lowercase + string.digits, k=length)) + + +class BenchmarkResultValidator: + """Validate benchmark results for consistency.""" + + @staticmethod + def validate_metrics(metrics: Dict[str, Any]) -> Tuple[bool, List[str]]: + """Validate that metrics are reasonable.""" + errors = [] + + # Check required fields + required_fields = ["qps", "latency_p50", "latency_p95", "latency_p99"] + for field in required_fields: + if field not in metrics: + errors.append(f"Missing required field: {field}") + + # Check value ranges + if "qps" in metrics: + if metrics["qps"] <= 0: + errors.append("QPS must be positive") + if metrics["qps"] > 1000000: + errors.append("QPS seems unrealistically high") + + if "latency_p50" in metrics and "latency_p95" in metrics: + if metrics["latency_p50"] > metrics["latency_p95"]: + errors.append("P50 latency cannot be greater than P95") + + if "latency_p95" in metrics and "latency_p99" in metrics: + if metrics["latency_p95"] > metrics["latency_p99"]: + errors.append("P95 latency cannot be greater than P99") + + if "error_rate" in metrics: + if not (0 <= metrics["error_rate"] <= 1): + errors.append("Error rate must be between 0 and 1") + + return len(errors) == 0, errors + + @staticmethod + def validate_consistency(results: List[Dict[str, Any]]) -> Tuple[bool, List[str]]: + """Check consistency across multiple benchmark runs.""" + if len(results) < 2: + return True, [] + + errors = [] + + # Check for extreme variations + qps_values = [r["qps"] for r in results if "qps" in r] + if qps_values: + mean_qps = sum(qps_values) / len(qps_values) + for i, qps in enumerate(qps_values): + if abs(qps - mean_qps) / mean_qps > 0.5: # 50% variation + errors.append(f"Run {i} has QPS {qps} which varies >50% from mean {mean_qps}") + + return len(errors) == 0, errors diff --git a/vdb_benchmark/vdbbench/__init__.py b/vdb_benchmark/vdbbench/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vdb_benchmark/vdbbench/compact_and_watch.py b/vdb_benchmark/vdbbench/compact_and_watch.py new file mode 100644 index 0000000..b6fafa4 --- /dev/null +++ b/vdb_benchmark/vdbbench/compact_and_watch.py @@ -0,0 +1,292 @@ +import argparse +import logging +import os +import sys +import time + +from datetime import datetime, timedelta +from pymilvus import connections, Collection, utility + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) + +# Add the parent directory to sys.path to import config_loader +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from vdbbench.config_loader import load_config, merge_config_with_args + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Monitor Milvus collection compaction process") + parser.add_argument("--host", type=str, default="127.0.0.1", help="Milvus server host") + parser.add_argument("--port", type=str, default="19530", help="Milvus server port") + parser.add_argument("--collection", type=str, required=False, help="Collection name to compact and monitor") + parser.add_argument("--interval", type=int, default=5, help="Monitoring interval in seconds") + parser.add_argument("--compact", action="store_true", help="Perform compaction before monitoring") + parser.add_argument("--zero-threshold", type=int, default=90, + help="Time in seconds to wait with zero pending rows before considering complete") + parser.add_argument("--config", type=str, help="Path to YAML configuration file") + + args = parser.parse_args() + + # Track which arguments were explicitly set vs using defaults + args.is_default = { + 'host': args.host == "127.0.0.1", + 'port': args.port == "19530", + 'interval': args.interval == 5, + 'zero_threshold': args.zero_threshold == 90, + 'compact': not args.compact # Default is False + } + + # Load configuration from YAML if specified + config = {} + if args.config: + config = load_config(args.config) + args = merge_config_with_args(config, args) + + # Validate required parameters + if not args.collection: + parser.error("Collection name is required. Specify with --collection or in config file.") + + return args + + +def connect_to_milvus(host, port): + """Connect to Milvus server""" + try: + connections.connect( + "default", + host=host, + port=port, + max_receive_message_length=514_983_574, + max_send_message_length=514_983_574 + ) + logging.info(f"Connected to Milvus server at {host}:{port}") + return True + except Exception as e: + logging.error(f"Failed to connect to Milvus: {str(e)}") + return False + +def perform_compaction(collection_name): + """Perform compaction on the collection""" + try: + collection = Collection(name=collection_name) + logging.info(f"Starting compaction on collection: {collection_name}") + compaction_start = time.time() + collection.compact() + compaction_time = time.time() - compaction_start + logging.info(f"Compaction command completed in {compaction_time:.2f} seconds") + return True + except Exception as e: + logging.error(f"Failed to perform compaction: {str(e)}") + return False + +def monitor_progress(collection_name, interval=60, zero_threshold=300): + """Monitor the progress of index building/compaction""" + start_time = time.time() + prev_check_time = start_time + + try: + # Get initial progress + prev_progress = utility.index_building_progress(collection_name=collection_name) + initial_indexed_rows = prev_progress.get("indexed_rows", 0) + initial_pending_rows = prev_progress.get("pending_index_rows", 0) + total_rows = prev_progress.get("total_rows", 0) + + logging.info(f"Starting to monitor progress for collection: {collection_name}") + logging.info(f"Initial state: {initial_indexed_rows:,} of {total_rows:,} rows indexed") + logging.info(f"Initial pending rows: {initial_pending_rows:,}") + + # Track the phases + indexing_phase_complete = initial_indexed_rows >= total_rows + pending_phase_complete = False + + # Track time with zero pending rows + pending_zero_start_time = None + + while True: + time.sleep(interval) # Check at specified interval + current_time = time.time() + elapsed_time = current_time - start_time + time_since_last_check = current_time - prev_check_time + + try: + progress = utility.index_building_progress(collection_name=collection_name) + + # Calculate progress metrics + indexed_rows = progress.get("indexed_rows", 0) + total_rows = progress.get("total_rows", total_rows) # Use previous if not available + pending_rows = progress.get("pending_index_rows", 0) + + # Quick exit: + if pending_rows == 0 and indexed_rows == total_rows: + # Ensure the pending counter has started + if not pending_zero_start_time: + pending_zero_start_time = current_time + logging.info("No pending rows detected. Assuming indexing phase is complete.") + indexing_phase_complete = True + + # Calculate both overall and recent indexing rates + total_rows_indexed_since_start = indexed_rows - initial_indexed_rows + rows_since_last_check = indexed_rows - prev_progress.get("indexed_rows", indexed_rows) + + # Calculate pending rows reduction + pending_rows_reduction = prev_progress.get("pending_index_rows", pending_rows) - pending_rows + pending_reduction_rate = pending_rows_reduction / time_since_last_check if time_since_last_check > 0 else 0 + + # Calculate overall rate (based on total time since monitoring began) + if elapsed_time > 0: + # Calculate percent done regardless of whether new rows were indexed + percent_done = indexed_rows / total_rows * 100 if total_rows > 0 else 100 + + if total_rows_indexed_since_start > 0: + # Normal case: some rows have been indexed since we started monitoring + overall_indexing_rate = total_rows_indexed_since_start / elapsed_time # rows per second + remaining_rows = total_rows - indexed_rows + estimated_seconds_remaining = remaining_rows / overall_indexing_rate if overall_indexing_rate > 0 else float('inf') + + # Alternative estimate based on pending rows + pending_estimate = pending_rows / pending_reduction_rate if pending_reduction_rate > 0 and pending_rows > 0 else float('inf') + + # Calculate recent rate (for comparison) + recent_indexing_rate = rows_since_last_check / time_since_last_check if time_since_last_check > 0 else 0 + + # Format the estimated time remaining + eta = datetime.now() + timedelta(seconds=estimated_seconds_remaining) + eta_str = eta.strftime("%Y-%m-%d %H:%M:%S") + + # Format the pending-based estimate + pending_eta = datetime.now() + timedelta(seconds=pending_estimate) if pending_estimate != float('inf') else "Unknown" + if isinstance(pending_eta, datetime): + pending_eta_str = pending_eta.strftime("%Y-%m-%d %H:%M:%S") + else: + pending_eta_str = str(pending_eta) + + # Log progress with estimates + if not indexing_phase_complete: + # Still in initial indexing phase + logging.info( + f"Phase 1 - Building index: {percent_done:.2f}% complete... " + f"({indexed_rows:,}/{total_rows:,} rows) | " + f"Pending rows: {pending_rows:,} | " + f"Overall rate: {overall_indexing_rate:.2f} rows/sec | " + f"Recent rate: {recent_indexing_rate:.2f} rows/sec | " + f"ETA: {eta_str} | " + f"Est. remaining: {timedelta(seconds=int(estimated_seconds_remaining))}" + ) + else: + # In pending rows processing phase + if pending_rows > 0: + # Reset the zero pending timer if we see pending rows + pending_zero_start_time = None + + logging.info( + f"Phase 2 - Processing pending rows: {pending_rows:,} remaining | " + f"Reduction rate: {pending_reduction_rate:.2f} rows/sec | " + f"ETA: {pending_eta_str} | " + f"Est. remaining: {timedelta(seconds=int(pending_estimate)) if pending_estimate != float('inf') else 'Unknown'}" + ) + else: + # Handle zero pending rows case (same as below) + if pending_zero_start_time is None: + pending_zero_start_time = current_time + logging.info(f"No pending rows detected. Starting {zero_threshold//60}-minute confirmation timer.") + else: + zero_pending_time = current_time - pending_zero_start_time + logging.info(f"No pending rows for {zero_pending_time:.1f} seconds (waiting for {zero_threshold} seconds to confirm)") + + if zero_pending_time >= zero_threshold: + logging.info(f"No pending rows detected for {zero_threshold//60} minutes. Process is considered complete.") + pending_phase_complete = True + else: + # Special case: all rows were already indexed when we started monitoring + logging.info( + f"Progress: {percent_done:.2f}% complete... " + f"({indexed_rows:,}/{total_rows:,} rows) | " + f"Pending rows: {pending_rows:,}" + ) + + # If all rows are indexed and there are no pending rows, we might be done + if indexed_rows >= total_rows and pending_rows == 0: + if not indexing_phase_complete: + indexing_phase_complete = True + logging.info(f"Initial indexing phase complete! All {indexed_rows:,} rows have been indexed.") + + # Handle zero pending rows case + if pending_zero_start_time is None: + pending_zero_start_time = current_time + logging.info(f"No pending rows detected. Starting {zero_threshold}-second confirmation timer.") + else: + zero_pending_time = current_time - pending_zero_start_time + logging.info(f"No pending rows for {zero_pending_time:.1f} seconds (waiting for {zero_threshold} seconds to confirm)") + + if zero_pending_time >= zero_threshold: + logging.info(f"No pending rows detected for {zero_threshold} seconds. Process is considered complete.") + pending_phase_complete = True + else: + # If no time has elapsed (first iteration) + percent_done = indexed_rows / total_rows * 100 if total_rows > 0 else 0 + logging.info( + f"Progress: {percent_done:.2f}% complete... " + f"({indexed_rows:,}/{total_rows:,} rows) | " + f"Pending rows: {pending_rows:,} | " + f"Initial measurement, no progress data yet" + ) + + # Check if pending phase is complete + if not pending_phase_complete and pending_rows == 0: + # If we've already waited long enough with zero pending rows + if pending_zero_start_time is not None and (current_time - pending_zero_start_time) >= zero_threshold: + pending_phase_complete = True + logging.info(f"Pending rows processing complete! All pending rows have been processed.") + + # Check if both phases are complete + if (indexed_rows >= total_rows or indexing_phase_complete) and pending_phase_complete: + total_time = time.time() - start_time + logging.info(f"Process fully complete! Total time: {timedelta(seconds=int(total_time))}") + break + + # Update for next iteration + prev_progress = progress + prev_check_time = current_time + + except Exception as e: + logging.error(f"Error checking progress: {str(e)}") + time.sleep(5) # Short delay before retrying + + except Exception as e: + logging.error(f"Error in monitor_progress: {str(e)}") + return False + + return True + +def main(): + args = parse_args() + + # Connect to Milvus + if not connect_to_milvus(args.host, args.port): + return 1 + + # Perform compaction if requested + if args.compact: + if not perform_compaction(args.collection): + return 1 + + # Monitor progress + logging.info(f"Starting to monitor progress (checking every {args.interval} seconds)") + if not monitor_progress(args.collection, args.interval, args.zero_threshold): + return 1 + + logging.info("Monitoring completed successfully!") + return 0 + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file diff --git a/vdb_benchmark/vdbbench/config_loader.py b/vdb_benchmark/vdbbench/config_loader.py new file mode 100644 index 0000000..ba6449d --- /dev/null +++ b/vdb_benchmark/vdbbench/config_loader.py @@ -0,0 +1,60 @@ +import yaml +import os + +def load_config(config_file=None): + """ + Load configuration from a YAML file. + + Args: + config_file (str): Path to the YAML configuration file + + Returns: + dict: Configuration dictionary or empty dict if file not found + """ + if not config_file: + return {} + + path_exists = os.path.exists(config_file) + configs_path_exists = os.path.exists(os.path.join("configs", config_file)) + if path_exists or configs_path_exists: + config_file = config_file if path_exists else os.path.join("configs", config_file) + else: + print(f"ERROR: Configuration file not found: {config_file}") + return {} + + try: + with open(config_file, 'r') as f: + config = yaml.safe_load(f) + print(f"Loaded vdbbench configuration from {config_file}") + return config + except Exception as e: + print("ERROR - Error loading configuration file: {str(e)}") + return {} + + +def merge_config_with_args(config, args): + """ + Merge configuration from YAML with command line arguments. + Command line arguments take precedence over YAML configuration. + + Args: + config (dict): Configuration dictionary from YAML + args (Namespace): Parsed command line arguments + + Returns: + Namespace: Updated arguments with values from config where not specified in args + """ + # Convert args to a dictionary + args_dict = vars(args) + + # For each key in config, if the corresponding arg is None or has a default value, + # update it with the value from config + for section, params in config.items(): + for key, value in params.items(): + if key in args_dict and (args_dict[key] is None or + (hasattr(args, 'is_default') and + key in args.is_default and + args.is_default[key])): + args_dict[key] = value + + return args diff --git a/vdb_benchmark/vdbbench/configs/10m_diskann.yaml b/vdb_benchmark/vdbbench/configs/10m_diskann.yaml new file mode 100644 index 0000000..a25b681 --- /dev/null +++ b/vdb_benchmark/vdbbench/configs/10m_diskann.yaml @@ -0,0 +1,26 @@ +database: + host: 127.0.0.1 + port: 19530 + database: milvus + max_receive_message_length: 514_983_574 + max_send_message_length: 514_983_574 + +dataset: + collection_name: mlps_10m_10shards_1536dim_uniform_diskann + num_vectors: 10_000_000 + dimension: 1536 + distribution: uniform + chunk_size: 1_000_000 + batch_size: 1000 + num_shards: 10 + vector_dtype: FLOAT_VECTOR + +index: + index_type: DISKANN + metric_type: COSINE + #index_params + max_degree: 64 + search_list_size: 200 + +workflow: + compact: True diff --git a/vdb_benchmark/vdbbench/configs/10m_hnsw.yaml b/vdb_benchmark/vdbbench/configs/10m_hnsw.yaml new file mode 100644 index 0000000..da4228f --- /dev/null +++ b/vdb_benchmark/vdbbench/configs/10m_hnsw.yaml @@ -0,0 +1,26 @@ +database: + host: 127.0.0.1 + port: 19530 + database: milvus + max_receive_message_length: 514_983_574 + max_send_message_length: 514_983_574 + +dataset: + collection_name: mlps_10m_10shards_1536dim_uniform_hnsw + num_vectors: 10_000_000 + dimension: 1536 + distribution: uniform + chunk_size: 1_000_000 + batch_size: 1000 + num_shards: 10 + vector_dtype: FLOAT_VECTOR + +index: + index_type: HNSW + metric_type: COSINE + #index_params + M: 64 + ef_construction: 200 + +workflow: + compact: True diff --git a/vdb_benchmark/vdbbench/configs/1m_diskann.yaml b/vdb_benchmark/vdbbench/configs/1m_diskann.yaml new file mode 100644 index 0000000..34d5570 --- /dev/null +++ b/vdb_benchmark/vdbbench/configs/1m_diskann.yaml @@ -0,0 +1,26 @@ +database: + host: 127.0.0.1 + port: 19530 + database: milvus + max_receive_message_length: 514_983_574 + max_send_message_length: 514_983_574 + +dataset: + collection_name: mlps_1m_1shards_1536dim_uniform_diskann + num_vectors: 1_000_000 + dimension: 1536 + distribution: uniform + chunk_size: 100_000 + batch_size: 1000 + num_shards: 1 + vector_dtype: FLOAT_VECTOR + +index: + index_type: DISKANN + metric_type: COSINE + #index_params + max_degree: 64 + search_list_size: 200 + +workflow: + compact: True diff --git a/vdb_benchmark/vdbbench/configs/1m_hnsw.yaml b/vdb_benchmark/vdbbench/configs/1m_hnsw.yaml new file mode 100644 index 0000000..1aeb428 --- /dev/null +++ b/vdb_benchmark/vdbbench/configs/1m_hnsw.yaml @@ -0,0 +1,26 @@ +database: + host: 127.0.0.1 + port: 19530 + database: milvus + max_receive_message_length: 514_983_574 + max_send_message_length: 514_983_574 + +dataset: + collection_name: mlps_1m_1shards_1536dim_uniform_hnsw + num_vectors: 1_000_000 + dimension: 1536 + distribution: uniform + chunk_size: 100_000 + batch_size: 1000 + num_shards: 1 + vector_dtype: FLOAT_VECTOR + +index: + index_type: HNSW + metric_type: COSINE + #index_params + M: 64 + ef_construction: 200 + +workflow: + compact: True diff --git a/vdb_benchmark/vdbbench/list_collections.py b/vdb_benchmark/vdbbench/list_collections.py new file mode 100644 index 0000000..d6633cb --- /dev/null +++ b/vdb_benchmark/vdbbench/list_collections.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +""" +Milvus Collection Information Script + +This script connects to a Milvus instance and lists all collections with detailed information +including the number of vectors in each collection and index information. +""" + +import sys +import os +import argparse +import logging +from tabulate import tabulate +from typing import Dict, List, Any + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Add the parent directory to sys.path to import config_loader +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +try: + from pymilvus import connections, utility, Collection +except ImportError: + logger.error("Error: pymilvus package not found. Please install it with 'pip install pymilvus'") + sys.exit(1) + +try: + from tabulate import tabulate +except ImportError: + logger.error("Error: tabulate package not found. Please install it with 'pip install tabulate'") + sys.exit(1) + + +def parse_args(): + """Parse command line arguments""" + parser = argparse.ArgumentParser(description="List Milvus collections with detailed information") + parser.add_argument("--host", type=str, default="127.0.0.1", help="Milvus server host") + parser.add_argument("--port", type=str, default="19530", help="Milvus server port") + parser.add_argument("--format", type=str, choices=["table", "json"], default="table", + help="Output format (table or json)") + return parser.parse_args() + + +def connect_to_milvus(host, port): + """Connect to Milvus server""" + try: + connections.connect( + alias="default", + host=host, + port=port + ) + logger.info(f"Connected to Milvus server at {host}:{port}") + return True + except Exception as e: + logger.error(f"Failed to connect to Milvus server: {str(e)}") + return False + + +def get_collection_info(collection_name, release=True): + """Get detailed information about a collection""" + try: + collection = Collection(collection_name) + # collection.load() + + # Get basic collection info - using num_entities instead of get_statistics + row_count = collection.num_entities + # row_count = get_collection_info(collection_name)["row_count"] + + # Get schema information + schema = collection.schema + dimension = None + for field in schema.fields: + if field.dtype in [100, 101]: # FLOAT_VECTOR or BINARY_VECTOR + dimension = field.params.get("dim") + break + + # Get index information + index_info = [] + if collection.has_index(): + index = collection.index() + index_info.append({ + "field_name": index.field_name, + "index_type": index.params.get("index_type"), + "metric_type": index.params.get("metric_type"), + "params": index.params.get("params", {}) + }) + + # Get partition information + partitions = collection.partitions + partition_info = [{"name": p.name, "description": p.description} for p in partitions] + + return { + "name": collection_name, + "row_count": row_count, + "dimension": dimension, + "schema": str(schema), + "index_info": index_info, + "partitions": partition_info + } + except Exception as e: + logger.error(f"Error getting info for collection {collection_name}: {str(e)}") + return { + "name": collection_name, + "error": str(e) + } + finally: + # Release collection + if release: + try: + collection.release() + except: + pass + + +def main(): + """Main function""" + args = parse_args() + + # Connect to Milvus + if not connect_to_milvus(args.host, args.port): + return 1 + + # List all collections + try: + collection_names = utility.list_collections() + logger.info(f"Found {len(collection_names)} collections") + + if not collection_names: + logger.info("No collections found in the Milvus instance") + return 0 + + # Get detailed information for each collection + collections_info = [] + for name in collection_names: + logger.info(f"Getting information for collection: {name}") + info = get_collection_info(name) + collections_info.append(info) + + # Display information based on format + if args.format == "json": + import json + print(json.dumps(collections_info, indent=2)) + else: + # Table format + table_data = [] + for info in collections_info: + index_types = ", ".join([idx.get("index_type", "N/A") for idx in info.get("index_info", [])]) + metric_types = ", ".join([idx.get("metric_type", "N/A") for idx in info.get("index_info", [])]) + + row = [ + info["name"], + info.get("row_count", "N/A"), + info.get("dimension", "N/A"), + index_types, + metric_types, + len(info.get("partitions", [])) + ] + table_data.append(row) + + headers = ["Collection Name", "Vector Count", "Dimension", "Index Types", "Metric Types", "Partitions"] + print(tabulate(table_data, headers=headers, tablefmt="grid")) + + return 0 + + except Exception as e: + logger.error(f"Error listing collections: {str(e)}") + return 1 + finally: + # Disconnect from Milvus + try: + connections.disconnect("default") + logger.info("Disconnected from Milvus server") + except: + pass + + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file diff --git a/vdb_benchmark/vdbbench/load_vdb.py b/vdb_benchmark/vdbbench/load_vdb.py new file mode 100644 index 0000000..0a7a932 --- /dev/null +++ b/vdb_benchmark/vdbbench/load_vdb.py @@ -0,0 +1,370 @@ +import argparse +import logging +import sys +import os +import time +import numpy as np +from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility + +# Add the parent directory to sys.path to import config_loader +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from vdbbench.config_loader import load_config, merge_config_with_args +from vdbbench.compact_and_watch import monitor_progress + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +def parse_args(): + parser = argparse.ArgumentParser(description="Load vectors into Milvus database") + + # Connection parameters + parser.add_argument("--host", type=str, default="localhost", help="Milvus server host") + parser.add_argument("--port", type=str, default="19530", help="Milvus server port") + + # Collection parameters + parser.add_argument("--collection-name", type=str, help="Name of the collection to create") + parser.add_argument("--dimension", type=int, help="Vector dimension") + parser.add_argument("--num-shards", type=int, default=1, help="Number of shards for the collection") + parser.add_argument("--vector-dtype", type=str, default="float", choices=["FLOAT_VECTOR"], + help="Vector data type. Only FLOAT_VECTOR is supported for now") + parser.add_argument("--force", action="store_true", help="Force recreate collection if it exists") + + # Data generation parameters + parser.add_argument("--num-vectors", type=int, help="Number of vectors to generate") + parser.add_argument("--distribution", type=str, default="uniform", + choices=["uniform", "normal"], help="Distribution for vector generation") + parser.add_argument("--batch-size", type=int, default=10000, help="Batch size for insertion") + parser.add_argument("--chunk-size", type=int, default=1000000, help="Number of vectors to generate in each chunk (for memory management)") + + # Index parameters + parser.add_argument("--index-type", type=str, default="DISKANN", help="Index type") + parser.add_argument("--metric-type", type=str, default="COSINE", help="Metric type for index") + parser.add_argument("--max-degree", type=int, default=16, help="DiskANN MaxDegree parameter") + parser.add_argument("--search-list-size", type=int, default=200, help="DiskANN SearchListSize parameter") + parser.add_argument("--M", type=int, default=16, help="HNSW M parameter") + parser.add_argument("--ef-construction", type=int, default=200, help="HNSW efConstruction parameter") + + # Monitoring parameters + parser.add_argument("--monitor-interval", type=int, default=5, help="Interval in seconds for monitoring index building") + parser.add_argument("--compact", action="store_true", help="Perform compaction after loading") + + # Configuration file + parser.add_argument("--config", type=str, help="Path to YAML configuration file") + + # What-if option to print args and exit + parser.add_argument("--what-if", action="store_true", help="Print the arguments after processing and exit") + + # Debug option to set logging level to DEBUG + parser.add_argument("--debug", action="store_true", help="Enable debug logging") + + args = parser.parse_args() + + # Track which arguments were explicitly set vs using defaults + args.is_default = { + 'host': args.host == "localhost", + 'port': args.port == "19530", + 'num_shards': args.num_shards == 1, + 'vector_dtype': args.vector_dtype == "float", + 'distribution': args.distribution == "uniform", + 'batch_size': args.batch_size == 10000, + 'chunk_size': args.chunk_size == 1000000, + 'index_type': args.index_type == "DISKANN", + 'metric_type': args.metric_type == "COSINE", + 'max_degree': args.max_degree == 16, + 'search_list_size': args.search_list_size == 200, + 'M': args.M == 16, + 'ef_construction': args.ef_construction == 200, + 'monitor_interval': args.monitor_interval == 5, + 'compact': not args.compact, # Default is False + 'force': not args.force, # Default is False + 'what_if': not args.what_if, # Default is False + 'debug': not args.debug # Default is False + } + + # Set logging level to DEBUG if --debug is specified + if args.debug: + logger.setLevel(logging.DEBUG) + logger.debug("Debug logging enabled") + + # Load configuration from YAML if specified + if args.config: + config = load_config(args.config) + args = merge_config_with_args(config, args) + + # If what-if is specified, print the arguments and exit + if args.what_if: + logger.info("Running in what-if mode. Printing arguments and exiting.") + print("\nConfiguration after processing arguments and config file:") + print("=" * 60) + for key, value in vars(args).items(): + if key != 'is_default': # Skip the is_default dictionary + source = "default" if args.is_default.get(key, False) else "specified" + print(f"{key}: {value} ({source})") + print("=" * 60) + sys.exit(0) + + # Validate required parameters + required_params = ['collection_name', 'dimension', 'num_vectors'] + missing_params = [param for param in required_params if getattr(args, param.replace('-', '_'), None) is None] + + if missing_params: + parser.error(f"Missing required parameters: {', '.join(missing_params)}. " + f"Specify with command line arguments or in config file.") + + return args + + +def connect_to_milvus(host, port): + """Connect to Milvus server""" + try: + logger.debug(f"Connecting to Milvus server at {host}:{port}") + connections.connect( + "default", + host=host, + port=port, + max_receive_message_length=514_983_574, + max_send_message_length=514_983_574 + ) + logger.info(f"Connected to Milvus server at {host}:{port}") + return True + + except Exception as e: + logger.error(f"Error connecting to Milvus server: {str(e)}") + return False + + +def create_collection(collection_name, dim, num_shards, vector_dtype, force=False): + """Create a new collection with the specified parameters""" + try: + # Check if collection exists + if utility.has_collection(collection_name): + if force: + Collection(name=collection_name).drop() + logger.info(f"Dropped existing collection: {collection_name}") + else: + logger.warning(f"Collection '{collection_name}' already exists. Use --force to drop and recreate it.") + return None + + # Define vector data type + vector_type = DataType.FLOAT_VECTOR + + # Define collection schema + fields = [ + FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False), + FieldSchema(name="vector", dtype=vector_type, dim=dim) + ] + schema = CollectionSchema(fields, description="Benchmark Collection") + + # Create collection + collection = Collection(name=collection_name, schema=schema, num_shards=num_shards) + logger.info(f"Created collection '{collection_name}' with {dim} dimensions and {num_shards} shards") + + return collection + except Exception as e: + logger.error(f"Failed to create collection: {str(e)}") + return None + + +def generate_vectors(num_vectors, dim, distribution='uniform'): + """Generate random vectors based on the specified distribution""" + if distribution == 'uniform': + vectors = np.random.random((num_vectors, dim)).astype('float16') + elif distribution == 'normal': + vectors = np.random.normal(0, 1, (num_vectors, dim)).astype('float16') + elif distribution == 'zipfian': + # Simplified zipfian-like distribution + base = np.random.random((num_vectors, dim)).astype('float16') + skew = np.random.zipf(1.5, (num_vectors, 1)).astype('float16') + vectors = base * (skew / 10) + else: + vectors = np.random.random((num_vectors, dim)).astype('float16') + + # Normalize vectors + norms = np.linalg.norm(vectors, axis=1, keepdims=True) + normalized_vectors = vectors / norms + + return normalized_vectors.tolist() + + +def insert_data(collection, vectors, batch_size=10000): + """Insert vectors into the collection in batches""" + total_vectors = len(vectors) + num_batches = (total_vectors + batch_size - 1) // batch_size + + start_time = time.time() + total_inserted = 0 + + for i in range(num_batches): + batch_start = i * batch_size + batch_end = min((i + 1) * batch_size, total_vectors) + batch_size_actual = batch_end - batch_start + + # Prepare batch data + ids = list(range(batch_start, batch_end)) + batch_vectors = vectors[batch_start:batch_end] + + # Insert batch + try: + collection.insert([ids, batch_vectors]) + total_inserted += batch_size_actual + + # Log progress + progress = total_inserted / total_vectors * 100 + elapsed = time.time() - start_time + rate = total_inserted / elapsed if elapsed > 0 else 0 + + logger.info(f"Inserted batch {i+1}/{num_batches}: {progress:.2f}% complete, " + f"rate: {rate:.2f} vectors/sec") + + except Exception as e: + logger.error(f"Error inserting batch {i+1}: {str(e)}") + + return total_inserted, time.time() - start_time + + +def flush_collection(collection): + # Flush the collection + flush_start = time.time() + collection.flush() + flush_time = time.time() - flush_start + logger.info(f"Flush completed in {flush_time:.2f} seconds") + + +def create_index(collection, index_params): + """Create an index on the collection""" + try: + start_time = time.time() + logger.info(f"Creating index with parameters: {index_params}") + collection.create_index("vector", index_params) + index_creation_time = time.time() - start_time + logger.info(f"Index creation command completed in {index_creation_time:.2f} seconds") + return True + except Exception as e: + logger.error(f"Failed to create index: {str(e)}") + return False + + +def main(): + args = parse_args() + + # Connect to Milvus + if not connect_to_milvus(args.host, args.port): + logger.error("Failed to connect to Milvus.") + return 1 + + logger.debug(f'Determining datatype for vector representation.') + # Determine vector data type + try: + # Check if FLOAT16 is available in newer versions of pymilvus + if hasattr(DataType, 'FLOAT16'): + logger.debug(f'Using FLOAT16 data type for vector representation.")') + vector_dtype = DataType.FLOAT16 if args.vector_dtype == 'float16' else DataType.FLOAT_VECTOR + else: + # Fall back to supported data types + logger.warning("FLOAT16 data type not available in this version of pymilvus. Using FLOAT_VECTOR instead.") + vector_dtype = DataType.FLOAT_VECTOR + except Exception as e: + logger.warning(f"Error determining vector data type: {str(e)}. Using FLOAT_VECTOR as default.") + vector_dtype = DataType.FLOAT_VECTOR + + # Create collection + collection = create_collection( + collection_name=args.collection_name, + dim=args.dimension, + num_shards=args.num_shards, + vector_dtype=vector_dtype, + force=args.force + ) + + if collection is None: + return 1 + + # Create index with updated parameters + index_params = { + "index_type": args.index_type, + "metric_type": args.metric_type, + "params": {} + } + + # Update only the parameters based on index_type + if args.index_type == "HNSW": + index_params["params"] = { + "M": args.M, + "efConstruction": args.ef_construction + } + elif args.index_type == "DISKANN": + index_params["params"] = { + "MaxDegree": args.max_degree, + "SearchListSize": args.search_list_size + } + else: + raise ValueError(f"Unsupported index_type: {args.index_type}") + + logger.debug(f'Creating index. This should be immediate on an empty collection') + if not create_index(collection, index_params): + return 1 + + # Generate vectors + logger.info( + f"Generating {args.num_vectors} vectors with {args.dimension} dimensions using {args.distribution} distribution") + start_gen_time = time.time() + + # Split vector generation into chunks if num_vectors is large + if args.num_vectors > args.chunk_size: + logger.info(f"Large vector count detected. Generating in chunks of {args.chunk_size:,} vectors") + vectors = [] + remaining = args.num_vectors + chunks_processed = 0 + + while remaining > 0: + chunk_size = min(args.chunk_size, remaining) + logger.info(f"Generating chunk {chunks_processed+1}: {chunk_size:,} vectors") + chunk_start = time.time() + chunk_vectors = generate_vectors(chunk_size, args.dimension, args.distribution) + chunk_time = time.time() - chunk_start + + logger.info(f"Generated chunk {chunks_processed} ({chunk_size:,} vectors) in {chunk_time:.2f} seconds. " + f"Progress: {(args.num_vectors - remaining):,}/{args.num_vectors:,} vectors " + f"({(args.num_vectors - remaining) / args.num_vectors * 100:.1f}%)") + + # Insert data + logger.info(f"Inserting {args.num_vectors} vectors into collection '{args.collection_name}'") + total_inserted, insert_time = insert_data(collection, chunk_vectors, args.batch_size) + logger.info(f"Inserted {total_inserted} vectors in {insert_time:.2f} seconds") + + remaining -= chunk_size + chunks_processed += 1 + else: + # For smaller vector counts, generate all at once + vectors = generate_vectors(args.num_vectors, args.dimension, args.distribution) + # Insert data + logger.info(f"Inserting {args.num_vectors} vectors into collection '{args.collection_name}'") + total_inserted, insert_time = insert_data(collection, vectors, args.batch_size) + logger.info(f"Inserted {total_inserted} vectors in {insert_time:.2f} seconds") + + gen_time = time.time() - start_gen_time + logger.info(f"Generated all {args.num_vectors:,} vectors in {gen_time:.2f} seconds") + + flush_collection(collection) + + # Monitor index building + logger.info(f"Starting to monitor index building progress (checking every {args.monitor_interval} seconds)") + monitor_progress(args.collection_name, args.monitor_interval, zero_threshold=10) + + if args.compact: + logger.info(f"Compacting collection '{args.collection_name}'") + collection.compact() + monitor_progress(args.collection_name, args.monitor_interval, zero_threshold=30) + logger.info(f"Collection '{args.collection_name}' compacted successfully.") + + # Summary + logger.info("Benchmark completed successfully!") + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/vdb_benchmark/vdbbench/simple_bench.py b/vdb_benchmark/vdbbench/simple_bench.py new file mode 100644 index 0000000..b679cd1 --- /dev/null +++ b/vdb_benchmark/vdbbench/simple_bench.py @@ -0,0 +1,668 @@ +#!/usr/bin/env python3 +""" +Milvus Vector Database Benchmark Script + +This script executes random vector queries against a Milvus collection using multiple processes. +It measures and reports query latency statistics. +""" + +import argparse +import multiprocessing as mp +import numpy as np +import os +import time +import json +import csv +import uuid +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Any, Optional, Tuple, Union +import signal +import sys +from tabulate import tabulate + +from vdbbench.config_loader import load_config, merge_config_with_args +from vdbbench.list_collections import get_collection_info + +try: + from pymilvus import connections, Collection, utility +except ImportError: + print("Error: pymilvus package not found. Please install it with 'pip install pymilvus'") + sys.exit(1) + +STAGGER_INTERVAL_SEC = 0.1 + +# Global flag for graceful shutdown +shutdown_flag = mp.Value('i', 0) + +# CSV header fields +csv_fields = [ + "process_id", + "batch_id", + "timestamp", + "batch_size", + "batch_time_seconds", + "avg_query_time_seconds", + "success" +] + + +def signal_handler(sig, frame): + """Handle interrupt signals to gracefully shut down worker processes""" + print("\nReceived interrupt signal. Shutting down workers gracefully...") + with shutdown_flag.get_lock(): + shutdown_flag.value = 1 + + +def read_disk_stats() -> Dict[str, Dict[str, int]]: + """ + Read disk I/O statistics from /proc/diskstats + + Returns: + Dictionary mapping device names to their read/write statistics + """ + stats = {} + try: + with open('/proc/diskstats', 'r') as f: + for line in f: + parts = line.strip().split() + if len(parts) >= 14: # Ensure we have enough fields + device = parts[2] + # Fields based on kernel documentation + # https://www.kernel.org/doc/Documentation/ABI/testing/procfs-diskstats + sectors_read = int(parts[5]) # sectors read + sectors_written = int(parts[9]) # sectors written + + # 1 sector = 512 bytes + bytes_read = sectors_read * 512 + bytes_written = sectors_written * 512 + + stats[device] = { + "bytes_read": bytes_read, + "bytes_written": bytes_written + } + return stats + except FileNotFoundError: + print("Warning: /proc/diskstats not available (non-Linux system)") + return {} + except Exception as e: + print(f"Error reading disk stats: {e}") + return {} + + +def format_bytes(bytes_value: int) -> str: + """Format bytes into human-readable format with appropriate units""" + units = ['B', 'KB', 'MB', 'GB', 'TB'] + unit_index = 0 + value = float(bytes_value) + + while value > 1024 and unit_index < len(units) - 1: + value /= 1024 + unit_index += 1 + + return f"{value:.2f} {units[unit_index]}" + + +def calculate_disk_io_diff(start_stats: Dict[str, Dict[str, int]], + end_stats: Dict[str, Dict[str, int]]) -> Dict[str, Dict[str, int]]: + """Calculate the difference in disk I/O between start and end measurements""" + diff_stats = {} + + for device in end_stats: + if device in start_stats: + diff_stats[device] = { + "bytes_read": end_stats[device]["bytes_read"] - start_stats[device]["bytes_read"], + "bytes_written": end_stats[device]["bytes_written"] - start_stats[device]["bytes_written"] + } + + return diff_stats + + +def generate_random_vector(dim: int) -> List[float]: + """Generate a random normalized vector of the specified dimension""" + vec = np.random.random(dim).astype(np.float32) + return (vec / np.linalg.norm(vec)).tolist() + + +def connect_to_milvus(host: str, port: str) -> connections: + """Establish connection to Milvus server""" + try: + connections.connect(alias="default", host=host, port=port) + return connections + except Exception as e: + print(f"Failed to connect to Milvus: {e}") + return False + + +def execute_batch_queries(process_id: int, host: str, port: str, collection_name: str, vector_dim: int, batch_size: int, + report_count: int, max_queries: Optional[int], runtime_seconds: Optional[int], output_dir: str, + shutdown_flag: mp.Value) -> None: + """ + Execute batches of vector queries and log results to disk + + Args: + process_id: ID of the current process + host: Milvus server host + port: Milvus server port + collection_name: Name of the collection to query + vector_dim: Dimension of vectors + batch_size: Number of queries to execute in each batch + max_queries: Maximum number of queries to execute (None for unlimited) + runtime_seconds: Maximum runtime in seconds (None for unlimited) + output_dir: Directory to save results + shutdown_flag: Shared value to signal process termination + """ + print(f'Process {process_id} initialized') + # Connect to Milvus + connections = connect_to_milvus(host, port) + if not connections: + print(f'Process {process_id} - No milvus connection') + return + + # Get collection + try: + collection = Collection(collection_name) + print(f'Process {process_id} - Loading collection') + collection.load() + except Exception as e: + print(f"Process {process_id}: Failed to load collection: {e}") + return + + # Prepare output file + output_file = Path(output_dir) / f"milvus_benchmark_p{process_id}.csv" + sys.stdout.write(f"Process {process_id}: Writing results to {output_file}\r\n") + # Create output directory if it doesn't exist + os.makedirs(os.path.dirname(output_file), exist_ok=True) + + # Track execution + start_time = time.time() + query_count = 0 + batch_count = 0 + + sys.stdout.write(f"Process {process_id}: Starting benchmark ...\r\n") + sys.stdout.flush() + + try: + with open(output_file, 'w') as f: + writer = csv.DictWriter(f, fieldnames=csv_fields) + writer.writeheader() + while True: + # Check if we should terminate + with shutdown_flag.get_lock(): + if shutdown_flag.value == 1: + break + + # Check termination conditions + current_time = time.time() + elapsed_time = current_time - start_time + + if runtime_seconds is not None and elapsed_time >= runtime_seconds: + break + + if max_queries is not None and query_count >= max_queries: + break + + # Generate batch of query vectors + batch_vectors = [generate_random_vector(vector_dim) for _ in range(batch_size)] + + # Execute batch and measure time + batch_start = time.time() + try: + search_params = {"metric_type": "COSINE", "params": {"ef": 200}} + results = collection.search( + data=batch_vectors, + anns_field="vector", + param=search_params, + limit=10, + output_fields=["id"] + ) + batch_end = time.time() + batch_success = True + except Exception as e: + print(f"Process {process_id}: Search error: {e}") + batch_end = time.time() + batch_success = False + + # Record batch results + batch_time = batch_end - batch_start + batch_count += 1 + query_count += batch_size + + # Log batch results to file + batch_data = { + "process_id": process_id, + "batch_id": batch_count, + "timestamp": current_time, + "batch_size": batch_size, + "batch_time_seconds": batch_time, + "avg_query_time_seconds": batch_time / batch_size, + "success": batch_success + } + + writer.writerow(batch_data) + f.flush() # Ensure data is written to disk immediately + + # Print progress + if batch_count % report_count == 0: + sys.stdout.write(f"Process {process_id}: Completed {query_count} queries in {elapsed_time:.2f} seconds.\r\n") + sys.stdout.flush() + + except Exception as e: + print(f"Process {process_id}: Error during benchmark: {e}") + + finally: + # Disconnect from Milvus + try: + connections.disconnect("default") + except: + pass + + print( + f"Process {process_id}: Finished. Executed {query_count} queries in {time.time() - start_time:.2f} seconds", flush=True) + + +def calculate_statistics(results_dir: str) -> Dict[str, Union[str, int, float, Dict[str, int]]]: + """Calculate statistics from benchmark results""" + import pandas as pd + + # Find all result files + file_paths = list(Path(results_dir).glob("milvus_benchmark_p*.csv")) + + if not file_paths: + return {"error": "No benchmark result files found"} + + # Read and concatenate all CSV files into a single DataFrame + dfs = [] + for file_path in file_paths: + try: + df = pd.read_csv(file_path) + if not df.empty: + dfs.append(df) + except Exception as e: + print(f"Error reading result file {file_path}: {e}") + + if not dfs: + return {"error": "No valid data found in benchmark result files"} + + # Concatenate all dataframes + all_data = pd.concat(dfs, ignore_index=True) + all_data.sort_values('timestamp', inplace=True) + + # Calculate start and end times + file_start_time = min(all_data['timestamp']) + file_end_time = max(all_data['timestamp'] + all_data['batch_time_seconds']) + total_time_seconds = file_end_time - file_start_time + + # Each row represents a batch, so we need to expand based on batch_size + all_latencies = [] + for _, row in all_data.iterrows(): + query_time_ms = row['avg_query_time_seconds'] * 1000 + all_latencies.extend([query_time_ms] * row['batch_size']) + + # Convert batch times to milliseconds + batch_times_ms = all_data['batch_time_seconds'] * 1000 + + # Calculate statistics + latencies = np.array(all_latencies) + batch_times = np.array(batch_times_ms) + total_queries = len(latencies) + + stats = { + "total_queries": total_queries, + "total_time_seconds": total_time_seconds, + "min_latency_ms": float(np.min(latencies)), + "max_latency_ms": float(np.max(latencies)), + "mean_latency_ms": float(np.mean(latencies)), + "median_latency_ms": float(np.median(latencies)), + "p95_latency_ms": float(np.percentile(latencies, 95)), + "p99_latency_ms": float(np.percentile(latencies, 99)), + "p999_latency_ms": float(np.percentile(latencies, 99.9)), + "p9999_latency_ms": float(np.percentile(latencies, 99.99)), + "throughput_qps": float(total_queries / total_time_seconds) if total_time_seconds > 0 else 0, + + # Batch time statistics + "batch_count": len(batch_times), + "min_batch_time_ms": float(np.min(batch_times)) if len(batch_times) > 0 else 0, + "max_batch_time_ms": float(np.max(batch_times)) if len(batch_times) > 0 else 0, + "mean_batch_time_ms": float(np.mean(batch_times)) if len(batch_times) > 0 else 0, + "median_batch_time_ms": float(np.median(batch_times)) if len(batch_times) > 0 else 0, + "p95_batch_time_ms": float(np.percentile(batch_times, 95)) if len(batch_times) > 0 else 0, + "p99_batch_time_ms": float(np.percentile(batch_times, 99)) if len(batch_times) > 0 else 0, + "p999_batch_time_ms": float(np.percentile(batch_times, 99.9)) if len(batch_times) > 0 else 0, + "p9999_batch_time_ms": float(np.percentile(batch_times, 99.99)) if len(batch_times) > 0 else 0 + } + + return stats + + +def load_database(host: str, port: str, collection_name: str, reload=False) -> Union[dict, None]: + print(f'Connecting to Milvus server at {host}:{port}...', flush=True) + connections = connect_to_milvus(host, port) + if not connections: + print(f'Unable to connect to Milvus server', flush=True) + return None + + # Connect to Milvus + try: + collection = Collection(collection_name) + except Exception as e: + print(f"Unable to connect to Milvus collection {collection_name}: {e}", flush=True) + return None + + try: + # Get the load state of the collection: + state = utility.load_state(collection_name) + if reload or state.name != "Loaded": + if reload: + print(f'Reloading the collection {collection_name}...') + else: + print(f'Loading the collection {collection_name}...') + start_load_time = time.time() + collection.load() + load_time = time.time() - start_load_time + print(f'Collection {collection_name} loaded in {load_time:.2f} seconds', flush=True) + if not reload and state.name == "Loaded": + print(f'Collection {collection_name} already reloaded and not reloading...') + + except Exception as e: + print(f'Unable to load collection {collection_name}: {e}') + return None + + print(f'Getting collection statistics...', flush=True) + collection_info = get_collection_info(collection_name, release=False) + table_data = [] + + index_types = ", ".join([idx.get("index_type", "N/A") for idx in collection_info.get("index_info", [])]) + metric_types = ", ".join([idx.get("metric_type", "N/A") for idx in collection_info.get("index_info", [])]) + + row = [ + collection_info["name"], + collection_info.get("row_count", "N/A"), + collection_info.get("dimension", "N/A"), + index_types, + metric_types, + len(collection_info.get("partitions", [])) + ] + table_data.append(row) + + headers = ["Collection Name", "Vector Count", "Dimension", "Index Types", "Metric Types", "Partitions"] + print(f'\nTabulating information...', flush=True) + tabulated_data = tabulate(table_data, headers=headers, tablefmt="grid") + print(tabulated_data, flush=True) + + return collection_info + + +def main(): + parser = argparse.ArgumentParser(description="Milvus Vector Database Benchmark") + + parser.add_argument("--config", type=str, help="Path to vdbbench config file") + + # Required parameters + parser.add_argument("--processes", type=int, help="Number of parallel processes") + parser.add_argument("--batch-size", type=int, help="Number of queries per batch") + parser.add_argument("--vector-dim", type=int, default=1536, help="Vector dimension") + parser.add_argument("--report-count", type=int, default=10, help="Number of queries between logging results") + + # Database parameters + parser.add_argument("--host", type=str, default="localhost", help="Milvus server host") + parser.add_argument("--port", type=str, default="19530", help="Milvus server port") + parser.add_argument("--collection-name", type=str, help="Collection name to query") + + # Termination conditions (at least one must be specified) + termination_group = parser.add_argument_group("termination conditions (at least one required)") + termination_group.add_argument("--runtime", type=int, help="Maximum runtime in seconds") + termination_group.add_argument("--queries", type=int, help="Total number of queries to execute") + + # Output directory + parser.add_argument("--output-dir", type=str, help="Directory to save benchmark results") + parser.add_argument("--json-output", action="store_true", help="Print benchmark results as JSON document") + + args = parser.parse_args() + + # Validate termination conditions + if args.runtime is None and args.queries is None: + parser.error("At least one termination condition (--runtime or --queries) must be specified") + + # Register signal handlers for graceful shutdown + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + print("") + print("=" * 50) + print("OUTPUT CONFIGURATION", flush=True) + print("=" * 50, flush=True) + + # Load config from YAML if specified + if args.config: + config = load_config(args.config) + args = merge_config_with_args(config, args) + + # Create output directory + if not args.output_dir: + output_dir = "vdbbench_results" + datetime_str = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = os.path.join(output_dir, datetime_str) + else: + output_dir = args.output_dir + + os.makedirs(output_dir, exist_ok=True) + + # Save benchmark configuration + config = { + "timestamp": datetime.now().isoformat(), + "processes": args.processes, + "batch_size": args.batch_size, + "report_count": args.report_count, + "vector_dim": args.vector_dim, + "host": args.host, + "port": args.port, + "collection_name": args.collection_name, + "runtime_seconds": args.runtime, + "total_queries": args.queries + } + + print(f"Results will be saved to: {output_dir}") + print(f'Writing configuration to {output_dir}/config.json') + with open(os.path.join(output_dir, "config.json"), 'w') as f: + json.dump(config, f, indent=2) + + print("") + print("=" * 50) + print("Database Verification and Loading", flush=True) + print("=" * 50) + + connections = connect_to_milvus(args.host, args.port) + print(f'Verifing database connection and loading collection') + if collection_info := load_database(args.host, args.port, args.collection_name): + print(f"\nCOLLECTION INFORMATION: {collection_info}") + # Having an active connection in the main thread when we fork seems to cause problems + connections.disconnect("default") + else: + print("Unable to load the specified collection") + sys.exit(1) + + # Read initial disk stats + print(f'\nCollecting initial disk statistics...') + start_disk_stats = read_disk_stats() + + # Calculate queries per process if total queries specified + max_queries_per_process = None + if args.queries is not None: + max_queries_per_process = args.queries // args.processes + # Add remainder to the first process + remainder = args.queries % args.processes + + # Start worker processes + processes = [] + stagger_interval_secs = 1 / args.processes + + print("") + print("=" * 50) + print("Benchmark Execution", flush=True) + print("=" * 50) + if max_queries_per_process is not None: + print(f"Starting benchmark with {args.processes} processes and {max_queries_per_process} queries per process") + else: + print(f'Starting benchmark with {args.processes} processes and running for {args.runtime} seconds') + if args.processes > 1: + print(f"Staggering benchmark execution by {stagger_interval_secs} seconds between processes") + try: + for i in range(args.processes): + if i > 0: + time.sleep(stagger_interval_secs) + # Adjust queries for the first process if there's a remainder + process_max_queries = None + if max_queries_per_process is not None: + process_max_queries = max_queries_per_process + (remainder if i == 0 else 0) + + p = mp.Process( + target=execute_batch_queries, + args=( + i, + args.host, + args.port, + args.collection_name, + args.vector_dim, + args.batch_size, + args.report_count, + process_max_queries, + args.runtime, + output_dir, + shutdown_flag + ) + ) + print(f'Starting process {i}...') + p.start() + processes.append(p) + + # Wait for all processes to complete + for p in processes: + p.join() + except Exception as e: + print(f"Error during benchmark execution: {e}") + # Signal all processes to terminate + with shutdown_flag.get_lock(): + shutdown_flag.value = 1 + + # Wait for processes to terminate + for p in processes: + if p.is_alive(): + p.join(timeout=5) + if p.is_alive(): + p.terminate() + else: + print(f'Running single process benchmark...') + execute_batch_queries(0, args.host, args.port, args.collection_name, args.vector_dim, args.batch_size, + args.report_count, args.queries, args.runtime, output_dir, shutdown_flag) + + # Read final disk stats + print('Reading final disk statistics...') + end_disk_stats = read_disk_stats() + + # Calculate disk I/O during benchmark + disk_io_diff = calculate_disk_io_diff(start_disk_stats, end_disk_stats) + + # Calculate and print statistics + print("\nCalculating benchmark statistics...") + stats = calculate_statistics(output_dir) + + # Add disk I/O statistics to the stats dictionary + if disk_io_diff: + # Calculate totals across all devices + total_bytes_read = sum(dev_stats["bytes_read"] for dev_stats in disk_io_diff.values()) + total_bytes_written = sum(dev_stats["bytes_written"] for dev_stats in disk_io_diff.values()) + + # Add disk I/O totals to stats + stats["disk_io"] = { + "total_bytes_read": total_bytes_read, + "total_bytes_read_per_sec": total_bytes_read / stats["total_time_seconds"], + "total_bytes_written": total_bytes_written, + "total_read_formatted": format_bytes(total_bytes_read), + "total_write_formatted": format_bytes(total_bytes_written), + "devices": {} + } + + # Add per-device breakdown + for device, io_stats in disk_io_diff.items(): + bytes_read = io_stats["bytes_read"] + bytes_written = io_stats["bytes_written"] + if bytes_read > 0 or bytes_written > 0: # Only include devices with activity + stats["disk_io"]["devices"][device] = { + "bytes_read": bytes_read, + "bytes_written": bytes_written, + "read_formatted": format_bytes(bytes_read), + "write_formatted": format_bytes(bytes_written) + } + else: + stats["disk_io"] = {"error": "Disk I/O statistics not available"} + + # Save statistics to file + with open(os.path.join(output_dir, "statistics.json"), 'w') as f: + json.dump(stats, f, indent=2) + + if args.json_output: + print("\nBenchmark statistics as JSON:") + print(json.dumps(stats)) + else: + # Print summary + print("\n" + "=" * 50) + print("BENCHMARK SUMMARY") + print("=" * 50) + print(f"Total Queries: {stats.get('total_queries', 0)}") + print(f"Total Batches: {stats.get('batch_count', 0)}") + print(f'Total Runtime: {stats.get("total_time_seconds", 0):.2f} seconds') + + # Print query time statistics + print("\nQUERY STATISTICS") + print("-" * 50) + + print(f"Mean Latency: {stats.get('mean_latency_ms', 0):.2f} ms") + print(f"Median Latency: {stats.get('median_latency_ms', 0):.2f} ms") + print(f"95th Percentile: {stats.get('p95_latency_ms', 0):.2f} ms") + print(f"99th Percentile: {stats.get('p99_latency_ms', 0):.2f} ms") + print(f"99.9th Percentile: {stats.get('p999_latency_ms', 0):.2f} ms") + print(f"99.99th Percentile: {stats.get('p9999_latency_ms', 0):.2f} ms") + print(f"Throughput: {stats.get('throughput_qps', 0):.2f} queries/second") + + # Print batch time statistics + print("\nBATCH STATISTICS") + print("-" * 50) + + print(f"Mean Batch Time: {stats.get('mean_batch_time_ms', 0):.2f} ms") + print(f"Median Batch Time: {stats.get('median_batch_time_ms', 0):.2f} ms") + print(f"95th Percentile: {stats.get('p95_batch_time_ms', 0):.2f} ms") + print(f"99th Percentile: {stats.get('p99_batch_time_ms', 0):.2f} ms") + print(f"99.9th Percentile: {stats.get('p999_batch_time_ms', 0):.2f} ms") + print(f"99.99th Percentile: {stats.get('p9999_batch_time_ms', 0):.2f} ms") + print(f"Max Batch Time: {stats.get('max_batch_time_ms', 0):.2f} ms") + print(f"Batch Throughput: {1000 / stats.get('mean_batch_time_ms', float('inf')):.2f} batches/second") + + # Print disk I/O statistics + print("\nDISK I/O DURING BENCHMARK") + print("-" * 50) + if disk_io_diff: + # Calculate totals across all devices + total_bytes_read = sum(dev_stats["bytes_read"] for dev_stats in disk_io_diff.values()) + total_bytes_written = sum(dev_stats["bytes_written"] for dev_stats in disk_io_diff.values()) + + print(f"Total Bytes Read: {format_bytes(total_bytes_read)}") + print(f"Total Bytes Written: {format_bytes(total_bytes_written)}") + print("\nPer-Device Breakdown:") + + for device, io_stats in disk_io_diff.items(): + bytes_read = io_stats["bytes_read"] + bytes_written = io_stats["bytes_written"] + if bytes_read > 0 or bytes_written > 0: # Only show devices with activity + print(f" {device}:") + print(f" Read: {format_bytes(bytes_read)}") + print(f" Write: {format_bytes(bytes_written)}") + else: + print("Disk I/O statistics not available") + + print("\nDetailed results saved to:", output_dir) + print("=" * 50) + + +if __name__ == "__main__": + main() \ No newline at end of file