diff --git a/benchmarks-jmh/scripts/test_node_setup.sh b/benchmarks-jmh/scripts/test_node_setup.sh index 3a1ee95d5..8c5d9258d 100644 --- a/benchmarks-jmh/scripts/test_node_setup.sh +++ b/benchmarks-jmh/scripts/test_node_setup.sh @@ -43,5 +43,5 @@ java --enable-native-access=ALL-UNNAMED \ --add-modules=jdk.incubator.vector \ -XX:+HeapDumpOnOutOfMemoryError \ -Xmx14G -Djvector.experimental.enable_native_vectorization=true \ - -jar target/benchmarks-jmh-4.0.0-beta.3-SNAPSHOT.jar + -jar target/benchmarks-jmh-4.0.0-rc.4-SNAPSHOT.jar diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/ParallelWriteBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/ParallelWriteBenchmark.java new file mode 100644 index 000000000..8637710b0 --- /dev/null +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/ParallelWriteBenchmark.java @@ -0,0 +1,287 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.github.jbellis.jvector.bench; + +import io.github.jbellis.jvector.disk.ReaderSupplierFactory; +import io.github.jbellis.jvector.graph.GraphIndexBuilder; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; +import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; +import io.github.jbellis.jvector.graph.NodesIterator; +import io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter; +import io.github.jbellis.jvector.graph.disk.OrdinalMapper; +import io.github.jbellis.jvector.graph.disk.feature.Feature; +import io.github.jbellis.jvector.graph.disk.feature.FeatureId; +import io.github.jbellis.jvector.graph.disk.feature.FusedADC; +import io.github.jbellis.jvector.graph.disk.feature.NVQ; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; +import io.github.jbellis.jvector.quantization.NVQuantization; +import io.github.jbellis.jvector.quantization.PQVectors; +import io.github.jbellis.jvector.quantization.ProductQuantization; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.*; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntFunction; + +import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED; + +/** + * JMH benchmark that mirrors the ParallelWriteExample: it builds a graph from vectors, then + * writes the graph to disk sequentially and in parallel using NVQ + FUSED_ADC features, + * and verifies that the outputs are identical. 
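 * <p>
 * Run sketch via the JMH shaded jar (jar name follows the script above; the class-name filter is illustrative):
 * <pre>{@code
 * java --add-modules=jdk.incubator.vector -jar target/benchmarks-jmh-4.0.0-rc.4-SNAPSHOT.jar ParallelWriteBenchmark
 * }</pre>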
+ */ +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +@Fork(value = 1, jvmArgsAppend = {"--add-modules=jdk.incubator.vector", "--enable-preview", "-Djvector.experimental.enable_native_vectorization=false"}) +@Warmup(iterations = 1) +@Measurement(iterations = 2) +@Threads(1) +public class ParallelWriteBenchmark { + private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport(); + + @Param({"100000"}) + int numBaseVectors; + + @Param({"1024"}) + int dimension; + + // Graph build parameters + final int M = 32; + final int efConstruction = 100; + final float neighborOverflow = 1.2f; + final float alpha = 1.2f; + final boolean addHierarchy = false; + final boolean refineFinalGraph = true; + + // Dataset and index state + private RandomAccessVectorValues floatVectors; + private PQVectors pqVectors; + private ImmutableGraphIndex graph; + + // Feature state reused between iterations + private NVQ nvqFeature; + private FusedADC fusedAdcFeature; + private OrdinalMapper identityMapper; + private Map> inlineSuppliers; + + // Paths + private Path tempDir; + private final AtomicInteger fileCounter = new AtomicInteger(); + + @Setup(Level.Trial) + public void setup() throws IOException { + // Generate random vectors + final var baseVectors = new ArrayList>(numBaseVectors); + for (int i = 0; i < numBaseVectors; i++) { + baseVectors.add(createRandomVector(dimension)); + } + floatVectors = new ListRandomAccessVectorValues(baseVectors, dimension); + + // Compute PQ compression + final int pqM = Math.max(1, dimension / 8); + final boolean centerData = true; // for EUCLIDEAN + final var pq = ProductQuantization.compute(floatVectors, pqM, 256, centerData, UNWEIGHTED); + pqVectors = (PQVectors) pq.encodeAll(floatVectors); + + // Build graph using PQ build score provider + final var bsp = BuildScoreProvider.pqBuildScoreProvider(VectorSimilarityFunction.EUCLIDEAN, pqVectors); + try (var builder = new GraphIndexBuilder(bsp, floatVectors.dimension(), M, efConstruction, + neighborOverflow, alpha, addHierarchy, refineFinalGraph)) { + graph = builder.build(floatVectors); + } + + // Prepare features + int nSubVectors = floatVectors.dimension() == 2 ? 
1 : 2; + var nvq = NVQuantization.compute(floatVectors, nSubVectors); + nvqFeature = new NVQ(nvq); + fusedAdcFeature = new FusedADC(graph.maxDegree(), pqVectors.getCompressor()); + + inlineSuppliers = new EnumMap<>(FeatureId.class); + inlineSuppliers.put(FeatureId.NVQ_VECTORS, ordinal -> new NVQ.State(nvq.encode(floatVectors.getVector(ordinal)))); + + identityMapper = new OrdinalMapper.IdentityMapper(floatVectors.size() - 1); + + // Temp directory for outputs + tempDir = Files.createTempDirectory("parallel-write-bench"); + } + + @TearDown(Level.Trial) + public void tearDown() throws IOException { + if (tempDir != null) { + // Best-effort cleanup of all files created + try (var stream = Files.list(tempDir)) { + stream.forEach(p -> { + try { Files.deleteIfExists(p); } catch (IOException ignored) {} + }); + } + Files.deleteIfExists(tempDir); + } + } + + @Benchmark + public void writeSequentialThenParallelAndVerify(Blackhole blackhole) throws IOException { + // Unique output files per invocation + int idx = fileCounter.getAndIncrement(); + Path sequentialPath = tempDir.resolve("graph-sequential-" + idx); + Path parallelPath = tempDir.resolve("graph-parallel-" + idx); + + long startSeq = System.nanoTime(); + writeGraph(graph, sequentialPath, false); + long seqTime = System.nanoTime() - startSeq; + + long startPar = System.nanoTime(); + writeGraph(graph, parallelPath, true); + long parTime = System.nanoTime() - startPar; + + // Report times and speedup for this invocation + double seqMs = seqTime / 1_000_000.0; + double parMs = parTime / 1_000_000.0; + double speedup = parTime == 0 ? Double.NaN : seqTime / (double) parTime; + System.out.printf("Sequential write: %.2f ms, Parallel write: %.2f ms, Speedup: %.2fx%n", seqMs, parMs, speedup); + + // Load and verify identical + OnDiskGraphIndex sequentialIndex = OnDiskGraphIndex.load(ReaderSupplierFactory.open(sequentialPath)); + OnDiskGraphIndex parallelIndex = OnDiskGraphIndex.load(ReaderSupplierFactory.open(parallelPath)); + try { + verifyIndicesIdentical(sequentialIndex, parallelIndex); + } finally { + sequentialIndex.close(); + parallelIndex.close(); + } + + // Consume sizes to prevent DCE + blackhole.consume(Files.size(sequentialPath)); + blackhole.consume(Files.size(parallelPath)); + + // Cleanup files after each invocation to limit disk usage + Files.deleteIfExists(sequentialPath); + Files.deleteIfExists(parallelPath); + } + + private void writeGraph(ImmutableGraphIndex graph, + Path path, + boolean parallel) throws IOException { + try (var writer = new OnDiskGraphIndexWriter.Builder(graph, path) + .withParallelWrites(parallel) + .with(nvqFeature) + .with(fusedAdcFeature) + .withMapper(identityMapper) + .build()) { + var view = graph.getView(); + Map> writeSuppliers = new EnumMap<>(FeatureId.class); + writeSuppliers.put(FeatureId.NVQ_VECTORS, inlineSuppliers.get(FeatureId.NVQ_VECTORS)); + writeSuppliers.put(FeatureId.FUSED_ADC, ordinal -> new FusedADC.State(view, pqVectors, ordinal)); + + writer.write(writeSuppliers); + view.close(); + } + } + + private static void verifyIndicesIdentical(OnDiskGraphIndex index1, OnDiskGraphIndex index2) throws IOException { + // Basic properties + if (index1.getMaxLevel() != index2.getMaxLevel()) { + throw new AssertionError("Max levels differ: " + index1.getMaxLevel() + " vs " + index2.getMaxLevel()); + } + if (index1.getIdUpperBound() != index2.getIdUpperBound()) { + throw new AssertionError("ID upper bounds differ: " + index1.getIdUpperBound() + " vs " + index2.getIdUpperBound()); + } + if 
(!index1.getFeatureSet().equals(index2.getFeatureSet())) { + throw new AssertionError("Feature sets differ: " + index1.getFeatureSet() + " vs " + index2.getFeatureSet()); + } + + try (var view1 = index1.getView(); var view2 = index2.getView()) { + if (!view1.entryNode().equals(view2.entryNode())) { + throw new AssertionError("Entry nodes differ: " + view1.entryNode() + " vs " + view2.entryNode()); + } + for (int level = 0; level <= index1.getMaxLevel(); level++) { + if (index1.size(level) != index2.size(level)) { + throw new AssertionError("Layer " + level + " sizes differ: " + index1.size(level) + " vs " + index2.size(level)); + } + if (index1.getDegree(level) != index2.getDegree(level)) { + throw new AssertionError("Layer " + level + " degrees differ: " + index1.getDegree(level) + " vs " + index2.getDegree(level)); + } + + // Collect node IDs in arrays + java.util.List nodeList1 = new java.util.ArrayList<>(); + java.util.List nodeList2 = new java.util.ArrayList<>(); + NodesIterator nodes1 = index1.getNodes(level); + while (nodes1.hasNext()) nodeList1.add(nodes1.nextInt()); + NodesIterator nodes2 = index2.getNodes(level); + while (nodes2.hasNext()) nodeList2.add(nodes2.nextInt()); + if (!nodeList1.equals(nodeList2)) { + throw new AssertionError("Layer " + level + " has different node sets"); + } + + // Compare neighbors + for (int nodeId : nodeList1) { + NodesIterator neighbors1 = view1.getNeighborsIterator(level, nodeId); + NodesIterator neighbors2 = view2.getNeighborsIterator(level, nodeId); + if (neighbors1.size() != neighbors2.size()) { + throw new AssertionError("Layer " + level + " node " + nodeId + " neighbor counts differ: " + neighbors1.size() + " vs " + neighbors2.size()); + } + int[] n1 = new int[neighbors1.size()]; + int[] n2 = new int[neighbors2.size()]; + for (int i = 0; i < n1.length; i++) { + n1[i] = neighbors1.nextInt(); + n2[i] = neighbors2.nextInt(); + } + if (!Arrays.equals(n1, n2)) { + throw new AssertionError("Layer " + level + " node " + nodeId + " has different neighbor sets"); + } + } + } + + // Optional vector checks (layer 0) + if (index1.getFeatureSet().contains(FeatureId.INLINE_VECTORS) || + index1.getFeatureSet().contains(FeatureId.NVQ_VECTORS)) { + int vectorsChecked = 0; + int maxToCheck = Math.min(100, index1.size(0)); + NodesIterator nodes = index1.getNodes(0); + while (nodes.hasNext() && vectorsChecked < maxToCheck) { + int node = nodes.nextInt(); + if (index1.getFeatureSet().contains(FeatureId.INLINE_VECTORS)) { + var vec1 = view1.getVector(node); + var vec2 = view2.getVector(node); + if (!vec1.equals(vec2)) { + throw new AssertionError("Node " + node + " vectors differ"); + } + } + vectorsChecked++; + } + } + } + } + + private VectorFloat createRandomVector(int dimension) { + VectorFloat vector = VECTOR_TYPE_SUPPORT.createFloatVector(dimension); + for (int i = 0; i < dimension; i++) { + vector.set(i, (float) Math.random()); + } + return vector; + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferIndexWriter.java new file mode 100644 index 000000000..c9d184133 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferIndexWriter.java @@ -0,0 +1,271 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.disk; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * An IndexWriter implementation backed by a ByteBuffer for in-memory record building. + * This allows existing Feature.writeInline() implementations to write to memory buffers + * that can later be bulk-written to disk. + *

+ * Byte order is set to BIG_ENDIAN to match Java's DataOutput specification and ensure + * cross-platform compatibility. + *
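+ * <p>
+ * A minimal usage sketch (capacity and values are illustrative):
+ * <pre>{@code
+ * var writer = ByteBufferIndexWriter.create(1024, false); // heap buffer, BIG_ENDIAN
+ * writer.writeInt(42);
+ * writer.writeFloat(1.5f);
+ * ByteBuffer record = writer.cloneBuffer(); // independent copy, ready to read
+ * }</pre>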

+ * Not thread-safe. Each thread should use its own instance. + */ +public class ByteBufferIndexWriter implements IndexWriter { + private final ByteBuffer buffer; + private final int initialPosition; + + /** + * Creates a writer that writes to the given buffer. + * The buffer's byte order is set to BIG_ENDIAN to match DataOutput behavior. + * + * @param buffer the buffer to write to + * @param autoClear if true, automatically clears the buffer before writing + */ + public ByteBufferIndexWriter(ByteBuffer buffer, boolean autoClear) { + this.buffer = buffer; + if (autoClear) { + buffer.clear(); + } + this.buffer.order(ByteOrder.BIG_ENDIAN); + this.initialPosition = buffer.position(); + } + + /** + * Creates a writer that writes to the given buffer, automatically clearing it first. + * The buffer's byte order is set to BIG_ENDIAN to match DataOutput behavior. + * This is the most common usage pattern and is equivalent to: + * {@code new ByteBufferIndexWriter(buffer, true)} + * + * @param buffer the buffer to write to (will be cleared) + */ + public ByteBufferIndexWriter(ByteBuffer buffer) { + this(buffer, true); + } + + /** + * Creates a new {@code ByteBufferIndexWriter} with the specified capacity. + *

+ * If {@code offHeap} is {@code true}, a direct (off-heap) {@link ByteBuffer} is used; + * otherwise, a heap-based buffer is used. + * + * @param capacity the buffer capacity in bytes + * @param offHeap if {@code true}, use a direct (off-heap) buffer; otherwise, use a heap buffer + * @return a new {@code ByteBufferIndexWriter} backed by a buffer of the specified type and capacity + */ + public static ByteBufferIndexWriter create(int capacity, boolean offHeap) { + if (offHeap) { + return allocateDirect(capacity); + } else { + return allocate(capacity); + } + } + + /** + * Creates a writer with a new heap ByteBuffer of the given capacity. + * The buffer uses BIG_ENDIAN byte order. + */ + private static ByteBufferIndexWriter allocate(int capacity) { + ByteBuffer buffer = ByteBuffer.allocate(capacity); + buffer.order(ByteOrder.BIG_ENDIAN); + return new ByteBufferIndexWriter(buffer); + } + + /** + * Creates a writer with a new direct ByteBuffer of the given capacity. + * The buffer uses BIG_ENDIAN byte order. + */ + private static ByteBufferIndexWriter allocateDirect(int capacity) { + ByteBuffer buffer = ByteBuffer.allocateDirect(capacity); + buffer.order(ByteOrder.BIG_ENDIAN); + return new ByteBufferIndexWriter(buffer); + } + + /** + * Returns the underlying buffer. The buffer's position will be at the end of written data. + */ + public ByteBuffer getBuffer() { + return buffer; + } + + /** + * Returns a read-only view of the written data (from initial position to current position). + */ + public ByteBuffer getWrittenData() { + int currentPos = buffer.position(); + buffer.position(initialPosition); + ByteBuffer slice = buffer.slice(); + slice.limit(currentPos - initialPosition); + buffer.position(currentPos); + return slice.asReadOnlyBuffer(); + } + + /** + * Resets the buffer position to the initial position, allowing reuse. + */ + public void reset() { + // Reset for next use + buffer.clear(); + buffer.position(initialPosition); + } + + /** + * Returns an independent copy of the written data as a new ByteBuffer. + * The returned buffer is ready to read (position=0, limit=written data length). + * The writer's buffer is automatically reset and ready for reuse. + *
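+ * <p>
+ * Reuse sketch, mirroring how NodeRecordTask drives the writer (illustrative):
+ * <pre>{@code
+ * writer.reset();                          // clear before each record
+ * writer.writeInt(ordinal);
+ * ByteBuffer copy = writer.cloneBuffer();  // hand off the copy; reuse the writer
+ * }</pre>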

+ * This method handles all buffer management:
+ * <ul>
+ *   <li>Flips the buffer to prepare for reading (sets limit=position, position=initialPosition)</li>
+ *   <li>Allocates and creates a copy of the data</li>
+ *   <li>Resets the buffer for the next write operation</li>
+ * </ul>

+ * This is the recommended way to extract data from the writer when the buffer + * will be reused (e.g., in thread-local scenarios). + * + * @return a new ByteBuffer containing a copy of the written data + */ + public ByteBuffer cloneBuffer() { + // Calculate the amount of data written + int bytesWritten = buffer.position() - initialPosition; + + // Set limit to current position and position to initial for reading + int savedPosition = buffer.position(); + buffer.position(initialPosition); + buffer.limit(savedPosition); + + // Create independent copy + ByteBuffer copy = ByteBuffer.allocate(bytesWritten); + copy.put(buffer); + copy.flip(); + + return copy; + } + + /** + * Returns the number of bytes written since construction or last reset. + * + * @return bytes written + */ + public int bytesWritten() { + return buffer.position() - initialPosition; + } + + @Override + public long position() { + return buffer.position() - initialPosition; + } + + @Override + public void close() { + // No-op for ByteBuffer + } + + // DataOutput methods + + @Override + public void write(int b) { + buffer.put((byte) b); + } + + @Override + public void write(byte[] b) { + buffer.put(b); + } + + @Override + public void write(byte[] b, int off, int len) { + buffer.put(b, off, len); + } + + @Override + public void writeBoolean(boolean v) { + buffer.put((byte) (v ? 1 : 0)); + } + + @Override + public void writeByte(int v) { + buffer.put((byte) v); + } + + @Override + public void writeShort(int v) { + buffer.putShort((short) v); + } + + @Override + public void writeChar(int v) { + buffer.putChar((char) v); + } + + @Override + public void writeInt(int v) { + buffer.putInt(v); + } + + @Override + public void writeLong(long v) { + buffer.putLong(v); + } + + @Override + public void writeFloat(float v) { + buffer.putFloat(v); + } + + @Override + public void writeDouble(double v) { + buffer.putDouble(v); + } + + @Override + public void writeBytes(String s) { + int len = s.length(); + for (int i = 0; i < len; i++) { + buffer.put((byte) s.charAt(i)); + } + } + + @Override + public void writeChars(String s) { + int len = s.length(); + for (int i = 0; i < len; i++) { + buffer.putChar(s.charAt(i)); + } + } + + @Override + public void writeUTF(String s) throws IOException { + // Use standard DataOutputStream UTF encoding + byte[] bytes = s.getBytes("UTF-8"); + int utflen = bytes.length; + // UTF format stores the string length as a 2-byte (16-bit) unsigned integer prefix, + // which has a maximum value of 65535 + if (utflen > 65535) { + throw new IOException("encoded string too long: " + utflen + " bytes"); + } + + buffer.putShort((short) utflen); + buffer.put(bytes); + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/GraphIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/GraphIndexWriter.java index ac67900fe..68bad7d10 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/GraphIndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/GraphIndexWriter.java @@ -16,16 +16,31 @@ package io.github.jbellis.jvector.graph.disk; +import io.github.jbellis.jvector.disk.IndexWriter; +import io.github.jbellis.jvector.disk.RandomAccessWriter; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; import io.github.jbellis.jvector.graph.disk.feature.Feature; import io.github.jbellis.jvector.graph.disk.feature.FeatureId; import java.io.Closeable; +import java.io.FileNotFoundException; import java.io.IOException; +import 
java.nio.file.Path; import java.util.Map; import java.util.function.IntFunction; /** - * writes a graph index to a target + * Interface for writing graph indices to various storage targets. + *

+ * Implementations support different strategies for writing graph data, + * including random access, sequential, and parallel writing modes. + * Use {@link #getBuilderFor(GraphIndexWriterTypes, ImmutableGraphIndex, IndexWriter)} + * or {@link #getBuilderFor(GraphIndexWriterTypes, ImmutableGraphIndex, Path)} + * factory methods to obtain appropriate builder instances. + * + * @see GraphIndexWriterTypes + * @see OnDiskGraphIndexWriter + * @see OnDiskSequentialGraphIndexWriter */ public interface GraphIndexWriter extends Closeable { /** @@ -38,4 +53,62 @@ public interface GraphIndexWriter extends Closeable { * @param featureStateSuppliers a map of FeatureId to a function that returns a Feature.State */ void write(Map> featureStateSuppliers) throws IOException; + + /** + * Factory method to obtain a builder for the specified writer type with an IndexWriter. + *

+ * This overload accepts any IndexWriter, but certain types have specific requirements:
+ * <ul>
+ *   <li>ON_DISK_PARALLEL requires a RandomAccessWriter (throws IllegalArgumentException otherwise);
+ *       to enable async parallel writes, use the Path overload instead</li>
+ *   <li>ON_DISK_SEQUENTIAL accepts any IndexWriter</li>
+ * </ul>
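+ * <p>
+ * Sketch (assumes an existing graph and IndexWriter):
+ * <pre>{@code
+ * var builder = GraphIndexWriter.getBuilderFor(GraphIndexWriterTypes.ON_DISK_SEQUENTIAL, graph, out);
+ * }</pre>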
+ *
+ * @param type the type of writer to create
+ * @param graphIndex the graph index to write
+ * @param out the output writer
+ * @return a builder for the specified writer type
+ * @throws IllegalArgumentException if the type requires a specific writer type that wasn't provided
+ */
+ static AbstractGraphIndexWriter.Builder<? extends AbstractGraphIndexWriter<? extends IndexWriter>, ? extends IndexWriter>
+ getBuilderFor(GraphIndexWriterTypes type, ImmutableGraphIndex graphIndex, IndexWriter out) {
+     switch (type) {
+         case ON_DISK_PARALLEL:
+             if (!(out instanceof RandomAccessWriter)) {
+                 throw new IllegalArgumentException("ON_DISK_PARALLEL requires a RandomAccessWriter");
+             }
+             return new OnDiskGraphIndexWriter.Builder(graphIndex, (RandomAccessWriter) out);
+         case ON_DISK_SEQUENTIAL:
+             return new OnDiskSequentialGraphIndexWriter.Builder(graphIndex, out);
+         default:
+             throw new IllegalArgumentException("Unknown GraphIndexWriterType: " + type);
+     }
+ }
+
+ /**
+ * Factory method to obtain a builder for the specified writer type with a file Path.
+ *

+ * This overload accepts a Path and is required for:
+ * <ul>
+ *   <li>ON_DISK_PARALLEL - enables async I/O for improved throughput</li>
+ * </ul>
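+ * <p>
+ * Sketch (the path is illustrative):
+ * <pre>{@code
+ * var builder = GraphIndexWriter.getBuilderFor(GraphIndexWriterTypes.ON_DISK_PARALLEL, graph, Path.of("graph.index"));
+ * }</pre>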
+ * Other writer types should use the {@link #getBuilderFor(GraphIndexWriterTypes, ImmutableGraphIndex, IndexWriter)}
+ * overload instead.
+ *
+ * @param type the type of writer to create (currently only ON_DISK_PARALLEL is supported)
+ * @param graphIndex the graph index to write
+ * @param out the output file path
+ * @return a builder for the specified writer type
+ * @throws FileNotFoundException if the file cannot be created or opened
+ * @throws IllegalArgumentException if the type is not supported via this method
+ */
+ static AbstractGraphIndexWriter.Builder<? extends AbstractGraphIndexWriter<? extends IndexWriter>, ? extends IndexWriter>
+ getBuilderFor(GraphIndexWriterTypes type, ImmutableGraphIndex graphIndex, Path out) throws FileNotFoundException {
+     switch (type) {
+         case ON_DISK_PARALLEL:
+             return new OnDiskGraphIndexWriter.Builder(graphIndex, out);
+         default:
+             throw new IllegalArgumentException("Unknown GraphIndexWriterType: " + type);
+     }
+ }
 }
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/GraphIndexWriterTypes.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/GraphIndexWriterTypes.java
new file mode 100644
index 000000000..b6c0b8816
--- /dev/null
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/GraphIndexWriterTypes.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.github.jbellis.jvector.graph.disk;
+
+/**
+ * Enum defining the available types of graph index writers.
+ *

+ * Different writer types offer different tradeoffs between performance, + * compatibility, and features. + */ +public enum GraphIndexWriterTypes { + /** + * Sequential on-disk writer optimized for write-once scenarios. + * Writes all data sequentially without seeking back, making it suitable + * for cloud storage or systems that optimize for sequential I/O. + * Writes header as footer. Does not support incremental updates. + * Accepts any IndexWriter. + */ + ON_DISK_SEQUENTIAL, + + /** + * Parallel on-disk writer that uses asynchronous I/O for improved throughput. + * Builds records in parallel across multiple threads and writes them + * asynchronously using AsynchronousFileChannel. + * Requires a Path to be provided for async file channel access. + */ + ON_DISK_PARALLEL +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NodeRecordTask.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NodeRecordTask.java new file mode 100644 index 000000000..b3089f10a --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NodeRecordTask.java @@ -0,0 +1,181 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph.disk; + +import io.github.jbellis.jvector.disk.ByteBufferIndexWriter; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; +import io.github.jbellis.jvector.graph.disk.feature.Feature; +import io.github.jbellis.jvector.graph.disk.feature.FeatureId; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.function.IntFunction; + +/** + * A task that builds L0 records for a range of nodes in memory. + *

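+ * <p>
+ * Each L0 record has a fixed layout, derived from the fields written in call():
+ * <pre>
+ * [int newOrdinal][inline features][int neighborCount][maxDegree ints of neighbors, padded with -1]
+ * </pre>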
+ * This task is designed to be executed in a thread pool, with each worker thread + * owning its own ImmutableGraphIndex.View for thread-safe neighbor iteration. + * Each task processes a contiguous range of ordinals to reduce task creation overhead. + */ +class NodeRecordTask implements Callable> { + private final int startOrdinal; // Inclusive + private final int endOrdinal; // Exclusive + private final OrdinalMapper ordinalMapper; + private final ImmutableGraphIndex graph; + private final ImmutableGraphIndex.View view; + private final List inlineFeatures; + private final Map> featureStateSuppliers; + private final int recordSize; + private final long baseOffset; // Base file offset for L0 (offsets calculated per-ordinal) + private final ByteBuffer buffer; + + /** + * Result of building a node record. + */ + static class Result { + final int newOrdinal; + final long fileOffset; + final ByteBuffer data; + + Result(int newOrdinal, long fileOffset, ByteBuffer data) { + this.newOrdinal = newOrdinal; + this.fileOffset = fileOffset; + this.data = data; + } + } + + NodeRecordTask(int startOrdinal, + int endOrdinal, + OrdinalMapper ordinalMapper, + ImmutableGraphIndex graph, + ImmutableGraphIndex.View view, + List inlineFeatures, + Map> featureStateSuppliers, + int recordSize, + long baseOffset, + ByteBuffer buffer) { + this.startOrdinal = startOrdinal; + this.endOrdinal = endOrdinal; + this.ordinalMapper = ordinalMapper; + this.graph = graph; + this.view = view; + this.inlineFeatures = inlineFeatures; + this.featureStateSuppliers = featureStateSuppliers; + this.recordSize = recordSize; + this.baseOffset = baseOffset; + this.buffer = buffer; + } + + @Override + public List call() throws Exception { + List results = new ArrayList<>(endOrdinal - startOrdinal); + + // Reuse writer and buffer across all ordinals in this range + var writer = new ByteBufferIndexWriter(buffer); + + for (int newOrdinal = startOrdinal; newOrdinal < endOrdinal; newOrdinal++) { + // Calculate file offset for this ordinal + long fileOffset = baseOffset + (long) newOrdinal * recordSize; + + // Reset buffer for this ordinal + writer.reset(); + + var originalOrdinal = ordinalMapper.newToOld(newOrdinal); + + // Write node ordinal + writer.writeInt(newOrdinal); + + // Handle OMITTED nodes (holes in ordinal space) + if (originalOrdinal == OrdinalMapper.OMITTED) { + // Write placeholder: skip inline features and write empty neighbor list + for (var feature : inlineFeatures) { + // Write zeros for missing features + for (int i = 0; i < feature.featureSize(); i++) { + writer.writeByte(0); + } + } + writer.writeInt(0); // neighbor count + for (int n = 0; n < graph.getDegree(0); n++) { + writer.writeInt(-1); // padding + } + } else { + // Validate node exists + if (!graph.containsNode(originalOrdinal)) { + throw new IllegalStateException( + String.format("Ordinal mapper mapped new ordinal %s to non-existing node %s", + newOrdinal, originalOrdinal)); + } + + // Write inline features + for (var feature : inlineFeatures) { + var supplier = featureStateSuppliers.get(feature.id()); + if (supplier == null) { + // Write zeros for missing supplier + for (int i = 0; i < feature.featureSize(); i++) { + writer.writeByte(0); + } + } else { + feature.writeInline(writer, supplier.apply(originalOrdinal)); + } + } + + // Write neighbors + var neighbors = view.getNeighborsIterator(0, originalOrdinal); + if (neighbors.size() > graph.getDegree(0)) { + throw new IllegalStateException( + String.format("Node %d has more neighbors %d than the graph's max 
degree %d -- run Builder.cleanup()!", + originalOrdinal, neighbors.size(), graph.getDegree(0))); + } + + writer.writeInt(neighbors.size()); + int n = 0; + for (; n < neighbors.size(); n++) { + var newNeighborOrdinal = ordinalMapper.oldToNew(neighbors.nextInt()); + if (newNeighborOrdinal < 0 || newNeighborOrdinal > ordinalMapper.maxOrdinal()) { + throw new IllegalStateException( + String.format("Neighbor ordinal out of bounds: %d/%d", + newNeighborOrdinal, ordinalMapper.maxOrdinal())); + } + writer.writeInt(newNeighborOrdinal); + } + + // Pad to max degree + for (; n < graph.getDegree(0); n++) { + writer.writeInt(-1); + } + } + + // Verify we wrote exactly the expected amount + if (writer.bytesWritten() != recordSize) { + throw new IllegalStateException( + String.format("Record size mismatch for ordinal %d: expected %d bytes, wrote %d bytes", + newOrdinal, recordSize, writer.bytesWritten())); + } + + // Writer handles flip, copy, and reset internally + // The copy ensures thread-local buffer can be safely reused for the next ordinal + ByteBuffer dataCopy = writer.cloneBuffer(); + results.add(new Result(newOrdinal, fileOffset, dataCopy)); + } + + return results; + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java index a8515c191..851934df9 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java @@ -66,17 +66,66 @@ */ public class OnDiskGraphIndexWriter extends AbstractGraphIndexWriter { private final long startOffset; + private volatile boolean useParallelWrites = false; + private final Path filePath; // Required for parallel writes + private final int parallelWorkerThreads; + private final boolean parallelUseDirectBuffers; + /** + * Constructs an OnDiskGraphIndexWriter with all parameters including optional file path + * and parallel write configuration. + * + * @param randomAccessWriter the writer to use for output + * @param version the format version to write + * @param startOffset the starting offset in the file + * @param graph the graph to write + * @param oldToNewOrdinals mapper for ordinal renumbering + * @param dimension the vector dimension + * @param features the features to include + * @param filePath file path required for parallel writes (can be null for sequential writes) + * @param parallelWorkerThreads number of worker threads for parallel writes (0 = use available processors) + * @param parallelUseDirectBuffers whether to use direct ByteBuffers for parallel writes + */ OnDiskGraphIndexWriter(RandomAccessWriter randomAccessWriter, int version, long startOffset, ImmutableGraphIndex graph, OrdinalMapper oldToNewOrdinals, int dimension, - EnumMap features) + EnumMap features, + Path filePath, + int parallelWorkerThreads, + boolean parallelUseDirectBuffers) { super(randomAccessWriter, version, graph, oldToNewOrdinals, dimension, features); this.startOffset = startOffset; + this.filePath = filePath; + this.parallelWorkerThreads = parallelWorkerThreads; + this.parallelUseDirectBuffers = parallelUseDirectBuffers; + } + + /** + * Constructs an OnDiskGraphIndexWriter without a file path. + * Parallel writes will not be available without a file path. + * Uses default parallel write configuration. 
+ * + * @param randomAccessWriter the writer to use for output + * @param version the format version to write + * @param startOffset the starting offset in the file + * @param graph the graph to write + * @param oldToNewOrdinals mapper for ordinal renumbering + * @param dimension the vector dimension + * @param features the features to include + */ + OnDiskGraphIndexWriter(RandomAccessWriter randomAccessWriter, + int version, + long startOffset, + ImmutableGraphIndex graph, + OrdinalMapper oldToNewOrdinals, + int dimension, + EnumMap features) + { + this(randomAccessWriter, version, startOffset, graph, oldToNewOrdinals, dimension, features, null, 0, false); } /** @@ -155,6 +204,87 @@ public synchronized void write(Map> featur writeHeader(view); // sets position to start writing features + // Write L0 records (either parallel or sequential) + if (useParallelWrites) { + writeL0RecordsParallel(featureStateSuppliers); + } else { + writeL0RecordsSequential(view, featureStateSuppliers); + } + + // We will use the abstract method because no random access is needed + writeSparseLevels(view); + + // We will use the abstract method because no random access is needed + writeSeparatedFeatures(featureStateSuppliers); + + if (version >= 5) { + writeFooter(view, out.position()); + } + final var endOfGraphPosition = out.position(); + + // Write the header again with updated offsets + writeHeader(view); + out.seek(endOfGraphPosition); + out.flush(); + view.close(); + } + + /** + * Writes L0 records using parallel workers with asynchronous file I/O. + *

+ * Records are written asynchronously using AsynchronousFileChannel for improved throughput + * while maintaining correct ordering. This method parallelizes record building across + * multiple threads and writes results in sequential order. + *

+ * Requires filePath to have been provided during construction. + * + * @param featureStateSuppliers suppliers for feature state data + * @throws IOException if an I/O error occurs + */ + private void writeL0RecordsParallel(Map> featureStateSuppliers) throws IOException { + if (filePath == null) { + throw new IllegalStateException("Parallel writes require a file path. Use Builder(ImmutableGraphIndex, Path) constructor."); + } + + // Flush writer before async writes to ensure buffered data is on disk + // This is critical when using AsynchronousFileChannel in parallel with BufferedRandomAccessWriter + out.flush(); + long baseOffset = out.position(); + + var config = new ParallelGraphWriter.Config( + parallelWorkerThreads, + parallelUseDirectBuffers, + 4 // Default task multiplier (4x cores) + ); + + try (var parallelWriter = new ParallelGraphWriter( + out, + graph, + inlineFeatures, + config, + filePath)) { + + parallelWriter.writeL0Records( + ordinalMapper, + inlineFeatures, + featureStateSuppliers, + baseOffset + ); + + // Update maxOrdinalWritten + maxOrdinalWritten = ordinalMapper.maxOrdinal(); + + // Seek to end of L0 region + long endOffset = baseOffset + (long) (ordinalMapper.maxOrdinal() + 1) * parallelWriter.getRecordSize(); + out.seek(endOffset); + } + } + + /** + * Writes L0 records sequentially (original implementation). + */ + private void writeL0RecordsSequential(ImmutableGraphIndex.View view, + Map> featureStateSuppliers) throws IOException { // for each graph node, write the associated features, followed by its neighbors at L0 for (int newOrdinal = 0; newOrdinal <= ordinalMapper.maxOrdinal(); newOrdinal++) { var originalOrdinal = ordinalMapper.newToOld(newOrdinal); @@ -211,23 +341,6 @@ public synchronized void write(Map> featur out.writeInt(-1); } } - - // We will use the abstract method because no random access is needed - writeSparseLevels(view); - - // We will use the abstract method because no random access is needed - writeSeparatedFeatures(featureStateSuppliers); - - // Write the header again with updated offsets - if (version >= 5) { - writeFooter(view, out.position()); - } - - final var endOfGraphPosition = out.position(); - writeHeader(view); - out.seek(endOfGraphPosition); - out.flush(); - view.close(); } /** @@ -243,20 +356,53 @@ public synchronized void writeHeader(ImmutableGraphIndex.View view) throws IOExc out.flush(); } - /** CRC32 checksum of bytes written since the starting offset */ + /** CRC32 checksum of bytes written since the starting offset + * Note on parallel writes and footer handling: + * When parallel writes are enabled (via {@link #setParallelWrites(boolean)}), it is the caller's responsibility + * to ensure that all parallel write operations have fully completed before writing the footer (e.g., checksum). + * The footer must only be written after all data has been flushed and no further writes are in progress, + * to avoid data corruption or incomplete checksums. This class does not currently coordinate or synchronize + * footer writing with parallel operations. Parallel writes are experimental and should be used with caution. + */ public synchronized long checksum() throws IOException { long endOffset = out.position(); return out.checksum(startOffset, endOffset); } + /** + * Enables parallel writes for L0 records. This can significantly improve throughput + * for large graphs by parallelizing record building across multiple cores. + *

+ * Note: This is currently experimental. The sequential path is the default. + * + * @param enabled whether to enable parallel writes + */ + public void setParallelWrites(boolean enabled) { + this.useParallelWrites = enabled; + } + /** * Builder for {@link OnDiskGraphIndexWriter}, with optional features. */ public static class Builder extends AbstractGraphIndexWriter.Builder { private long startOffset = 0L; + private boolean useParallelWrites = false; + /** + * The current implementation of this Builder allows for a RandomAccessWriter to be passed to the constructor in + * order to allow the backing of any IndexWriter and not tying to a particular implementation. However in this + * case the class we are in is literally named "OnDiskGraphIndexWriter" and is built to store the graph index + * on a file on disk. As RandomAccessWriter does not allow for the extraction of the backing Path there is no way + * to use async i/o to write the file without modifying the way the OnDiskGraphIndexWriter.Builder is constructed. + * Hence the addition here of a Path variable. In the future it would be an optimization to deprecate the constructor + * that uses a RandomAccessWriter and only allow the one that takes a Path. For now this allows for backwards compatibility. + */ + private Path filePath = null; + private int parallelWorkerThreads = 0; + private boolean parallelUseDirectBuffers = false; public Builder(ImmutableGraphIndex graphIndex, Path outPath) throws FileNotFoundException { this(graphIndex, new BufferedRandomAccessWriter(outPath)); + this.filePath = outPath; } public Builder(ImmutableGraphIndex graphIndex, RandomAccessWriter out) { @@ -272,9 +418,45 @@ public Builder withStartOffset(long startOffset) { return this; } + /** + * Enable parallel writes for L0 records. Can significantly improve throughput for large graphs. + */ + public Builder withParallelWrites(boolean enabled) { + this.useParallelWrites = enabled; + return this; + } + + /** + * Set the number of worker threads for parallel writes. + * + * @param workerThreads number of worker threads (0 = use available processors) + * @return this builder + */ + public Builder withParallelWorkerThreads(int workerThreads) { + this.parallelWorkerThreads = workerThreads; + return this; + } + + /** + * Set whether to use direct ByteBuffers for parallel writes. + * Direct buffers can provide better performance for large records but use off-heap memory. 
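+ * <p>
+ * Combined parallel configuration via the builder (a sketch; values are illustrative):
+ * <pre>{@code
+ * var writer = new OnDiskGraphIndexWriter.Builder(graph, outPath)
+ *         .withParallelWrites(true)
+ *         .withParallelWorkerThreads(8)
+ *         .withParallelDirectBuffers(true)
+ *         .build();
+ * }</pre>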
+ * + * @param useDirectBuffers whether to use direct ByteBuffers + * @return this builder + */ + public Builder withParallelDirectBuffers(boolean useDirectBuffers) { + this.parallelUseDirectBuffers = useDirectBuffers; + return this; + } + @Override - protected OnDiskGraphIndexWriter reallyBuild(int dimension) throws IOException { - return new OnDiskGraphIndexWriter(out, version, startOffset, graphIndex, ordinalMapper, dimension, features); + protected OnDiskGraphIndexWriter reallyBuild(int dimension) { + var writer = new OnDiskGraphIndexWriter( + out, version, startOffset, graphIndex, ordinalMapper, dimension, features, filePath, + parallelWorkerThreads, parallelUseDirectBuffers + ); + writer.setParallelWrites(useParallelWrites); + return writer; } } } \ No newline at end of file diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/ParallelGraphWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/ParallelGraphWriter.java new file mode 100644 index 000000000..1173b30a8 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/ParallelGraphWriter.java @@ -0,0 +1,358 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph.disk; + +import io.github.jbellis.jvector.disk.RandomAccessWriter; +import io.github.jbellis.jvector.graph.ImmutableGraphIndex; +import io.github.jbellis.jvector.graph.disk.feature.Feature; +import io.github.jbellis.jvector.graph.disk.feature.FeatureId; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.AsynchronousFileChannel; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; + +import java.util.EnumSet; +import java.util.List; +import java.util.Objects; +import java.util.Map; +import java.util.ArrayList; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntFunction; + +/** + * Orchestrates parallel writing of L0 node records to disk using asynchronous file I/O. + *

+ * This class manages:
+ * - A thread pool for building node records in parallel
+ * - Per-thread ImmutableGraphIndex.View instances for thread-safe neighbor iteration
+ * - A buffer pool to avoid excessive allocation
+ * - Asynchronous file channel writes that maintain correct ordering
+ *

+ * Usage:
+ * <pre>{@code
+ * try (var parallelWriter = new ParallelGraphWriter(writer, graph, inlineFeatures, config, filePath)) {
+ *     parallelWriter.writeL0Records(ordinalMapper, inlineFeatures, featureStateSuppliers, baseOffset);
+ * }
+ * }</pre>
+ */ +class ParallelGraphWriter implements AutoCloseable { + private final RandomAccessWriter writer; + private final ImmutableGraphIndex graph; + private final ExecutorService executor; + private final ThreadLocal viewPerThread; + private final ThreadLocal bufferPerThread; + private final CopyOnWriteArrayList allViews = new CopyOnWriteArrayList<>(); + private final int recordSize; + private final Path filePath; + private final int taskMultiplier; + private static final AtomicInteger threadCounter = new AtomicInteger(0); + + /** + * Configuration for parallel writing. + */ + static class Config { + final int workerThreads; + final boolean useDirectBuffers; + final int taskMultiplier; + + /** + * @param workerThreads number of worker threads for building records (0 = use available processors) + * @param useDirectBuffers whether to use direct ByteBuffers (can be faster for large records) + * @param taskMultiplier multiplier for number of tasks relative to worker threads + * (4x = good balance for most use cases, higher = more fine-grained parallelism) + */ + public Config(int workerThreads, boolean useDirectBuffers, int taskMultiplier) { + this.workerThreads = workerThreads <= 0 ? Runtime.getRuntime().availableProcessors() : workerThreads; + this.useDirectBuffers = useDirectBuffers; + this.taskMultiplier = taskMultiplier <= 0 ? 4 : taskMultiplier; + } + + /** + * Returns a default configuration suitable for most use cases. + * Uses available CPU cores, heap buffers, and 4x task multiplier. + * + * @return default configuration + */ + public static Config defaultConfig() { + return new Config(0, false, 4); + } + } + + /** + * Creates a parallel writer. + * + * @param writer the underlying writer + * @param graph the graph being written + * @param inlineFeatures the inline features to write + * @param config parallelization configuration + * @param filePath file path for async writes (required, cannot be null) + */ + public ParallelGraphWriter(RandomAccessWriter writer, + ImmutableGraphIndex graph, + List inlineFeatures, + Config config, + Path filePath) { + this.writer = writer; + this.graph = graph; + this.filePath = Objects.requireNonNull(filePath); + this.taskMultiplier = config.taskMultiplier; + this.executor = Executors.newFixedThreadPool(config.workerThreads, + r -> { + Thread t = new Thread(r); + t.setName("ParallelGraphWriter-Worker-" + threadCounter.getAndIncrement()); + t.setDaemon(false); + return t; + }); + + // Compute fixed record size for L0 + this.recordSize = Integer.BYTES // node ordinal + + inlineFeatures.stream().mapToInt(Feature::featureSize).sum() + + Integer.BYTES // neighbor count + + graph.getDegree(0) * Integer.BYTES; // neighbors + padding + + // Thread-local views for safe neighbor iteration + // CopyOnWriteArrayList handles concurrent additions safely + this.viewPerThread = ThreadLocal.withInitial(() -> { + var view = graph.getView(); + allViews.add(view); + return view; + }); + + // Thread-local buffers to avoid allocation overhead + // Use BIG_ENDIAN to match Java DataOutput specification + final int bufferSize = recordSize; + final boolean useDirect = config.useDirectBuffers; + this.bufferPerThread = ThreadLocal.withInitial(() -> { + ByteBuffer buffer = useDirect ? ByteBuffer.allocateDirect(bufferSize) : ByteBuffer.allocate(bufferSize); + buffer.order(java.nio.ByteOrder.BIG_ENDIAN); + return buffer; + }); + } + + /** + * Writes all L0 node records in parallel using asynchronous file I/O with range-based task batching. 
+ * Each record is written at its precomputed file offset, so index correctness does not depend
+ * on the completion order of the asynchronous writes. The implementation divides the ordinal
+ * space into ranges that are processed by a fixed number of tasks, reducing overhead compared to
+ * per-ordinal task creation.
+ *
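+ * <p>
+ * Worked example (illustrative numbers): with 1,000,000 ordinals, 8 cores, and the default
+ * 4x multiplier, numTasks = 1,000,000 / 32 = 31,250, and each task covers 32 ordinals.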

+ * The number of tasks is determined by available CPU cores multiplied by a configurable multiplier. + * This provides good load balancing while minimizing task creation and management overhead. + * + * @param ordinalMapper maps between old and new ordinals + * @param inlineFeatures the inline features to write + * @param featureStateSuppliers suppliers for feature state + * @param baseOffset the file offset where L0 records start + * @throws IOException if an IO error occurs + */ + public void writeL0Records(OrdinalMapper ordinalMapper, + List inlineFeatures, + Map> featureStateSuppliers, + long baseOffset) throws IOException { + int maxOrdinal = ordinalMapper.maxOrdinal(); + int totalOrdinals = maxOrdinal + 1; + + // Calculate optimal number of tasks based on cores and task multiplier + int numCores = Runtime.getRuntime().availableProcessors(); + int numTasks = Math.min((totalOrdinals / (numCores * taskMultiplier)), totalOrdinals); + + // Calculate ordinals per task (ceiling division to cover all ordinals) + int ordinalsPerTask = (totalOrdinals + numTasks - 1) / numTasks; + + List>> futures = new ArrayList<>(numTasks); + + // Submit range-based tasks + for (int i = 0; i < numTasks; i++) { + int startOrdinal = i * ordinalsPerTask; + int endOrdinal = Math.min(startOrdinal + ordinalsPerTask, totalOrdinals); + + // Skip if range is empty (can happen with final task) + if (startOrdinal >= totalOrdinals) { + break; + } + + final int start = startOrdinal; + final int end = endOrdinal; + + Future> future = executor.submit(() -> { + var view = viewPerThread.get(); + var buffer = bufferPerThread.get(); + + var task = new NodeRecordTask( + start, // Start of range (inclusive) + end, // End of range (exclusive) + ordinalMapper, + graph, + view, + inlineFeatures, + featureStateSuppliers, + recordSize, + baseOffset, // Base offset (task calculates per-ordinal offsets) + buffer + ); + + return task.call(); + }); + + futures.add(future); + } + + // Write all records async + writeRecordsAsync(futures); + } + + + /** + * Writes records asynchronously using AsynchronousFileChannel for improved throughput. + * Records are written in sequential order by iterating through the task futures and their + * results, which ensures that even though record building is parallelized, writes occur in + * the correct order. Creates a dedicated thread pool for async I/O operations and properly + * cleans up resources. 
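+ * <p>
+ * For example, with 8 I/O threads up to 16 writes (numThreads * 2) are in flight before the
+ * pipeline drains and refills.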
+ * + * @param futures the completed record building tasks (each containing a list of results) + * @throws IOException if an I/O error occurs + */ + private void writeRecordsAsync(List>> futures) throws IOException { + var opts = EnumSet.of(StandardOpenOption.WRITE, StandardOpenOption.READ); + int numThreads = Math.min(Runtime.getRuntime().availableProcessors(), 32); + ExecutorService fileWritePool = null; + + try { + fileWritePool = new ThreadPoolExecutor( + numThreads, numThreads, + 0L, TimeUnit.MILLISECONDS, + new LinkedBlockingQueue<>(), + r -> { + var t = new Thread(r, "graphnode-writer"); + t.setDaemon(true); + return t; + }); + + // Use a bounded list to allow multiple concurrent async writes while providing backpressure + // Buffer size is 2x the I/O thread pool size to keep the pipeline full + int maxConcurrentWrites = numThreads * 2; + List> pendingWrites = new ArrayList<>(maxConcurrentWrites); + + try (var afc = AsynchronousFileChannel.open(filePath, opts, fileWritePool)) { + // Iterate through task futures (in order) + for (Future> future : futures) { + List results = future.get(); + + // Write each result in the batch + for (NodeRecordTask.Result result : results) { + // Submit async write and track the future + // result.data is already a copy made in NodeRecordTask to avoid + // race conditions with thread-local buffer reuse + Future writeFuture = afc.write(result.data, result.fileOffset); + pendingWrites.add(writeFuture); + + // When buffer is full, wait for all pending writes to complete + if (pendingWrites.size() >= maxConcurrentWrites) { + for (Future wf : pendingWrites) { + wf.get(); // Wait for write completion + } + pendingWrites.clear(); + } + } + } + + // Wait for any remaining pending writes + for (Future wf : pendingWrites) { + wf.get(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Interrupted while writing records", e); + } catch (ExecutionException e) { + throw unwrapExecutionException(e); + } + } finally { + if (fileWritePool != null) { + fileWritePool.shutdown(); + try { + if (!fileWritePool.awaitTermination(60, TimeUnit.SECONDS)) { + fileWritePool.shutdownNow(); + } + } catch (InterruptedException e) { + fileWritePool.shutdownNow(); + Thread.currentThread().interrupt(); + } + } + } + } + + /** + * Unwraps ExecutionException to throw the underlying cause. + * Handles IOException, RuntimeException, and wraps other exceptions. + * + * @param e the execution exception to unwrap + * @return an IOException wrapping the cause + * @throws RuntimeException if the cause is a RuntimeException + */ + private IOException unwrapExecutionException(ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof IOException) { + return (IOException) cause; + } else if (cause instanceof RuntimeException) { + throw (RuntimeException) cause; + } else { + throw new RuntimeException("Error building node record", cause); + } + } + + /** + * Returns the computed record size for L0 nodes. 
+ */ + public int getRecordSize() { + return recordSize; + } + + @Override + public void close() throws IOException { + try { + // Shutdown executor + executor.shutdown(); + try { + if (!executor.awaitTermination(60, TimeUnit.SECONDS)) { + executor.shutdownNow(); + } + } catch (InterruptedException e) { + executor.shutdownNow(); + Thread.currentThread().interrupt(); + } + + // Close all views (CopyOnWriteArrayList is safe for concurrent iteration) + for (var view : allViews) { + view.close(); + } + allViews.clear(); + } catch (IOException e) { + throw e; + } catch (Exception e) { + throw new IOException("Error closing parallel writer", e); + } + } +} diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java index a4d62645f..098b07816 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java @@ -300,6 +300,9 @@ private static BuilderWithSuppliers builderWithSuppliers(Set features var identityMapper = new OrdinalMapper.IdentityMapper(floatVectors.size() - 1); var builder = new OnDiskGraphIndexWriter.Builder(onHeapGraph, outPath); builder.withMapper(identityMapper); + + // Enable parallel writes for improved throughput + builder.withParallelWrites(true); Map> suppliers = new EnumMap<>(FeatureId.class); for (var featureId : features) { switch (featureId) { diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/graph/disk/ParallelWriteExample.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/graph/disk/ParallelWriteExample.java new file mode 100644 index 000000000..b9ce75cc8 --- /dev/null +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/graph/disk/ParallelWriteExample.java @@ -0,0 +1,383 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/graph/disk/ParallelWriteExample.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/graph/disk/ParallelWriteExample.java
new file mode 100644
index 000000000..b9ce75cc8
--- /dev/null
+++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/graph/disk/ParallelWriteExample.java
@@ -0,0 +1,383 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.github.jbellis.jvector.graph.disk;
+
+import io.github.jbellis.jvector.disk.ReaderSupplierFactory;
+import io.github.jbellis.jvector.example.util.DataSet;
+import io.github.jbellis.jvector.example.util.DataSetLoader;
+import io.github.jbellis.jvector.graph.GraphIndexBuilder;
+import io.github.jbellis.jvector.graph.ImmutableGraphIndex;
+import io.github.jbellis.jvector.graph.NodesIterator;
+import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
+import io.github.jbellis.jvector.graph.disk.feature.Feature;
+import io.github.jbellis.jvector.graph.disk.feature.FeatureId;
+import io.github.jbellis.jvector.graph.disk.feature.FusedADC;
+import io.github.jbellis.jvector.graph.disk.feature.NVQ;
+import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
+import io.github.jbellis.jvector.quantization.NVQuantization;
+import io.github.jbellis.jvector.quantization.PQVectors;
+import io.github.jbellis.jvector.quantization.ProductQuantization;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.EnumMap;
+import java.util.Map;
+import java.util.function.IntFunction;
+
+import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED;
+
+/**
+ * Example demonstrating how to use parallel writes with OnDiskGraphIndexWriter.
+ * <p>
+ * Usage patterns:
+ * <pre>{@code
+ * // Sequential (default):
+ * var writer = new OnDiskGraphIndexWriter.Builder(graph, outputPath)
+ *     .with(inlineVectors)
+ *     .build();
+ * writer.write(featureSuppliers);
+ *
+ * // Parallel:
+ * var writer = new OnDiskGraphIndexWriter.Builder(graph, outputPath)
+ *     .with(inlineVectors)
+ *     .withParallelWrites(true)  // Enable parallel writes
+ *     .build();
+ * writer.write(featureSuppliers);
+ * }</pre>
+ */
+public class ParallelWriteExample {
+
+    /**
+     * Verifies that two OnDiskGraphIndex instances are identical in structure and content.
+     * Compares graph structure (nodes, neighbors) and feature data (vectors).
+     */
+    private static void verifyIndicesIdentical(OnDiskGraphIndex index1, OnDiskGraphIndex index2) throws IOException {
+        System.out.println("\n=== Verifying Graph Indices ===");
+
+        // Check basic properties
+        if (index1.getMaxLevel() != index2.getMaxLevel()) {
+            throw new AssertionError(String.format("Max levels differ: %d vs %d",
+                    index1.getMaxLevel(), index2.getMaxLevel()));
+        }
+        System.out.printf("✓ Max level matches: %d%n", index1.getMaxLevel());
+
+        if (index1.getIdUpperBound() != index2.getIdUpperBound()) {
+            throw new AssertionError(String.format("ID upper bounds differ: %d vs %d",
+                    index1.getIdUpperBound(), index2.getIdUpperBound()));
+        }
+        System.out.printf("✓ ID upper bound matches: %d%n", index1.getIdUpperBound());
+
+        if (!index1.getFeatureSet().equals(index2.getFeatureSet())) {
+            throw new AssertionError(String.format("Feature sets differ: %s vs %s",
+                    index1.getFeatureSet(), index2.getFeatureSet()));
+        }
+        System.out.printf("✓ Feature sets match: %s%n", index1.getFeatureSet());
+
+        // Check each layer
+        try (var view1 = index1.getView(); var view2 = index2.getView()) {
+            // Check entry nodes (accessed through views)
+            if (!view1.entryNode().equals(view2.entryNode())) {
+                throw new AssertionError(String.format("Entry nodes differ: %s vs %s",
+                        view1.entryNode(), view2.entryNode()));
+            }
+            System.out.printf("✓ Entry node matches: %s%n", view1.entryNode());
+            for (int level = 0; level <= index1.getMaxLevel(); level++) {
+                if (index1.size(level) != index2.size(level)) {
+                    throw new AssertionError(String.format("Layer %d sizes differ: %d vs %d",
+                            level, index1.size(level), index2.size(level)));
+                }
+
+                if (index1.getDegree(level) != index2.getDegree(level)) {
+                    throw new AssertionError(String.format("Layer %d degrees differ: %d vs %d",
+                            level, index1.getDegree(level), index2.getDegree(level)));
+                }
+
+                // Collect all node IDs from both indices into lists
+                java.util.List<Integer> nodeList1 = new java.util.ArrayList<>();
+                java.util.List<Integer> nodeList2 = new java.util.ArrayList<>();
+
+                NodesIterator nodes1 = index1.getNodes(level);
+                while (nodes1.hasNext()) {
+                    nodeList1.add(nodes1.nextInt());
+                }
+
+                NodesIterator nodes2 = index2.getNodes(level);
+                while (nodes2.hasNext()) {
+                    nodeList2.add(nodes2.nextInt());
+                }
+
+                // Verify same set of nodes
+                if (!nodeList1.equals(nodeList2)) {
+                    // Find differences
+                    java.util.Set<Integer> set1 = new java.util.HashSet<>(nodeList1);
+                    java.util.Set<Integer> set2 = new java.util.HashSet<>(nodeList2);
+
+                    java.util.Set<Integer> onlyIn1 = new java.util.HashSet<>(set1);
+                    onlyIn1.removeAll(set2);
+
+                    java.util.Set<Integer> onlyIn2 = new java.util.HashSet<>(set2);
+                    onlyIn2.removeAll(set1);
+
+                    System.out.printf("Layer %d node count: sequential=%d, parallel=%d%n",
+                            level, nodeList1.size(), nodeList2.size());
+
+                    if (!onlyIn1.isEmpty()) {
+                        var sample1 = onlyIn1.stream().limit(10).collect(java.util.stream.Collectors.toList());
+                        System.out.printf("  Nodes only in sequential (first 10): %s%n", sample1);
+                    }
+                    if (!onlyIn2.isEmpty()) {
+                        var sample2 = onlyIn2.stream().limit(10).collect(java.util.stream.Collectors.toList());
+                        System.out.printf("  Nodes only in parallel (first 10): %s%n", sample2);
+                    }
+
+                    // Sample some nodes from each to see the pattern
+                    System.out.printf("  First 20 nodes in sequential: %s%n",
+                            nodeList1.stream().limit(20).collect(java.util.stream.Collectors.toList()));
+                    System.out.printf("  First 20 nodes in parallel: %s%n",
+                            nodeList2.stream().limit(20).collect(java.util.stream.Collectors.toList()));
+
+                    throw new AssertionError(String.format("Layer %d has different node sets: sequential has %d nodes, parallel has %d nodes, %d nodes differ",
+                            level, nodeList1.size(), nodeList2.size(), onlyIn1.size() + onlyIn2.size()));
+                }
+
+                // Compare neighbors for each node
+                int differentNeighbors = 0;
+                for (int nodeId : nodeList1) {
+                    NodesIterator neighbors1 = view1.getNeighborsIterator(level, nodeId);
+                    NodesIterator neighbors2 = view2.getNeighborsIterator(level, nodeId);
+
+                    if (neighbors1.size() != neighbors2.size()) {
+                        throw new AssertionError(String.format("Layer %d node %d neighbor counts differ: %d vs %d",
+                                level, nodeId, neighbors1.size(), neighbors2.size()));
+                    }
+
+                    int[] n1 = new int[neighbors1.size()];
+                    int[] n2 = new int[neighbors2.size()];
+                    for (int i = 0; i < n1.length; i++) {
+                        n1[i] = neighbors1.nextInt();
+                        n2[i] = neighbors2.nextInt();
+                    }
+
+                    if (!Arrays.equals(n1, n2)) {
+                        differentNeighbors++;
+                        if (differentNeighbors <= 3) {
+                            System.out.printf("  ✗ Layer %d node %d has different neighbor sets: %s vs %s%n",
+                                    level, nodeId, Arrays.toString(n1), Arrays.toString(n2));
+                        }
+                    }
+                }
+
+                if (differentNeighbors > 0) {
+                    throw new AssertionError(String.format("Layer %d: %d/%d nodes have different neighbor sets",
+                            level, differentNeighbors, nodeList1.size()));
+                }
+
+                System.out.printf("✓ Layer %d structure matches (%d nodes, degree %d)%n",
+                        level, index1.size(level), index1.getDegree(level));
+            }
+
+            // Compare vectors if present (only check layer 0)
+            if (index1.getFeatureSet().contains(FeatureId.INLINE_VECTORS) ||
+                    index1.getFeatureSet().contains(FeatureId.NVQ_VECTORS)) {
+
+                int vectorsChecked = 0;
+                int maxToCheck = Math.min(100, index1.size(0)); // Check up to 100 vectors as a sample
+
+                NodesIterator nodes = index1.getNodes(0);
+                while (nodes.hasNext() && vectorsChecked < maxToCheck) {
+                    int node = nodes.nextInt();
+
+                    if (index1.getFeatureSet().contains(FeatureId.INLINE_VECTORS)) {
+                        var vec1 = view1.getVector(node);
+                        var vec2 = view2.getVector(node);
+
+                        if (!vec1.equals(vec2)) {
+                            throw new AssertionError(String.format("Node %d vectors differ", node));
+                        }
+                    }
+
+                    vectorsChecked++;
+                }
+
+                System.out.printf("✓ Sampled %d vectors, all match%n", vectorsChecked);
+            }
+        }
+
+        System.out.println("✓ All checks passed - indices are identical!");
+    }
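verifyIndicesIdentical compares the two indices structurally, which remains valid even if the two write paths could legitimately produce different byte layouts. When the files are expected to be byte-for-byte identical, the JDK's Files.mismatch (Java 12+) is a much cheaper first-pass check; a minimal sketch (the helper name is illustrative):

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

class ByteCompare {
    // Files.mismatch returns the offset of the first differing byte, or -1 if none.
    static boolean sameBytes(Path a, Path b) throws IOException {
        return Files.mismatch(a, b) == -1L;
    }
}
```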
+
+    /**
+     * Benchmark comparison between sequential and parallel writes using NVQ + FUSED_ADC features.
+     * This matches the configuration used in Grid.buildOnDisk for realistic performance testing.
+     */
+    public static void benchmarkComparison(ImmutableGraphIndex graph,
+                                           Path sequentialPath,
+                                           Path parallelPath,
+                                           RandomAccessVectorValues floatVectors,
+                                           PQVectors pqVectors) throws IOException {
+
+        int nSubVectors = floatVectors.dimension() == 2 ? 1 : 2;
+        var nvq = NVQuantization.compute(floatVectors, nSubVectors);
+        var pq = pqVectors.getCompressor();
+
+        // Create features: NVQ + FUSED_ADC
+        var nvqFeature = new NVQ(nvq);
+        var fusedAdcFeature = new FusedADC(graph.maxDegree(), pq);
+
+        // Build suppliers for inline features (NVQ only - FUSED_ADC needs neighbors)
+        Map<FeatureId, IntFunction<Feature.State>> inlineSuppliers = new EnumMap<>(FeatureId.class);
+        inlineSuppliers.put(FeatureId.NVQ_VECTORS, ordinal -> new NVQ.State(nvq.encode(floatVectors.getVector(ordinal))));
+
+        // The FUSED_ADC supplier needs a graph view, provided at write time
+        var identityMapper = new OrdinalMapper.IdentityMapper(floatVectors.size() - 1);
+
+        // Sequential write
+        System.out.printf("Writing with NVQ + FUSED_ADC features...%n");
+        long sequentialStart = System.nanoTime();
+        try (var writer = new OnDiskGraphIndexWriter.Builder(graph, sequentialPath)
+                .withParallelWrites(false)
+                .with(nvqFeature)
+                .with(fusedAdcFeature)
+                .withMapper(identityMapper)
+                .build()) {
+
+            var view = graph.getView();
+            Map<FeatureId, IntFunction<Feature.State>> writeSuppliers = new EnumMap<>(FeatureId.class);
+            writeSuppliers.put(FeatureId.NVQ_VECTORS, inlineSuppliers.get(FeatureId.NVQ_VECTORS));
+            writeSuppliers.put(FeatureId.FUSED_ADC, ordinal -> new FusedADC.State(view, pqVectors, ordinal));
+
+            writer.write(writeSuppliers);
+            view.close();
+        }
+        long sequentialTime = System.nanoTime() - sequentialStart;
+        System.out.printf("Sequential write: %.2f ms%n", sequentialTime / 1_000_000.0);
+
+        // Parallel write
+        long parallelStart = System.nanoTime();
+        try (var writer = new OnDiskGraphIndexWriter.Builder(graph, parallelPath)
+                .withParallelWrites(true)
+                .with(nvqFeature)
+                .with(fusedAdcFeature)
+                .withMapper(identityMapper)
+                .build()) {
+
+            var view = graph.getView();
+            Map<FeatureId, IntFunction<Feature.State>> writeSuppliers = new EnumMap<>(FeatureId.class);
+            writeSuppliers.put(FeatureId.NVQ_VECTORS, inlineSuppliers.get(FeatureId.NVQ_VECTORS));
+            writeSuppliers.put(FeatureId.FUSED_ADC, ordinal -> new FusedADC.State(view, pqVectors, ordinal));
+
+            writer.write(writeSuppliers);
+            view.close();
+        }
+        long parallelTime = System.nanoTime() - parallelStart;
+
+        System.out.printf("Parallel write: %.2f ms%n", parallelTime / 1_000_000.0);
+        System.out.printf("Speedup: %.2fx%n", (double) sequentialTime / parallelTime);
+    }
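Note that each mode above is timed with a single wall-clock measurement, so JIT warmup and page-cache effects can skew the comparison; a JMH harness is the more rigorous approach. For quick ad-hoc runs, repeating each write and taking the median reduces noise; a sketch under that assumption (class name and iteration count are illustrative):

```java
import java.util.Arrays;

class MedianTimer {
    // Run a timed action n times and return the median duration in nanoseconds.
    static long medianNanos(int n, Runnable action) {
        long[] times = new long[n];
        for (int i = 0; i < n; i++) {
            long t0 = System.nanoTime();
            action.run();
            times[i] = System.nanoTime() - t0;
        }
        Arrays.sort(times);
        return times[n / 2];
    }
}
```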
+
+    /**
+     * Main method to run a benchmark test of sequential vs parallel writes.
+     *
+     * Usage: java ParallelWriteExample [dataset-name]
+     *
+     * Example: java ParallelWriteExample cohere-english-v3-100k
+     *
+     * If no dataset is provided, uses "cohere-english-v3-100k" by default.
+     */
+    public static void main(String[] args) throws IOException {
+        String datasetName = args.length > 0 ? args[0] : "cohere-english-v3-100k";
+
+        System.out.println("Loading dataset: " + datasetName);
+        DataSet ds = DataSetLoader.loadDataSet(datasetName);
+        System.out.printf("Loaded %d vectors of dimension %d%n", ds.baseVectors.size(), ds.getDimension());
+
+        var floatVectors = ds.getBaseRavv();
+
+        // Build PQ compression (matching the Grid.buildOnDisk pattern)
+        System.out.println("Computing PQ compression...");
+        int pqM = floatVectors.dimension() / 8; // m = dimension / 8
+        boolean centerData = ds.similarityFunction == io.github.jbellis.jvector.vector.VectorSimilarityFunction.EUCLIDEAN;
+        var pq = ProductQuantization.compute(floatVectors, pqM, 256, centerData, UNWEIGHTED);
+        var pqVectors = (PQVectors) pq.encodeAll(floatVectors);
+        System.out.printf("PQ compression: %d subspaces, 256 clusters%n", pqM);
+
+        // Graph build parameters (matching typical benchmark settings)
+        int M = 32;
+        int efConstruction = 100;
+        float neighborOverflow = 1.2f;
+        float alpha = 1.2f;
+        boolean addHierarchy = false;
+        boolean refineFinalGraph = true;
+
+        System.out.printf("Building graph with PQ-compressed vectors (M=%d, efConstruction=%d)...%n", M, efConstruction);
+        long buildStart = System.nanoTime();
+
+        var bsp = BuildScoreProvider.pqBuildScoreProvider(ds.similarityFunction, pqVectors);
+        var builder = new GraphIndexBuilder(bsp, floatVectors.dimension(), M, efConstruction,
+                neighborOverflow, alpha, addHierarchy, refineFinalGraph);
+
+        // Build graph using parallel construction for much better performance
+        var graph = builder.build(floatVectors);
+        long buildTime = System.nanoTime() - buildStart;
+        System.out.printf("Graph built in %.2fs%n", buildTime / 1_000_000_000.0);
+        System.out.printf("Graph has %d nodes%n", graph.size(0));
+
+        // Create temporary paths for writing
+        Path tempDir = Files.createTempDirectory("parallel-write-test");
+        Path sequentialPath = tempDir.resolve("graph-sequential");
+        Path parallelPath = tempDir.resolve("graph-parallel");
+
+        try {
+            System.out.println("\n=== Testing Write Performance ===");
+
+            // Run benchmark comparison
+            benchmarkComparison(graph, sequentialPath, parallelPath, floatVectors, pqVectors);
+
+            // Report file sizes
+            long seqSize = Files.size(sequentialPath);
+            long parSize = Files.size(parallelPath);
+            System.out.printf("%nFile sizes: Sequential=%.2f MB, Parallel=%.2f MB%n",
+                    seqSize / 1024.0 / 1024.0,
+                    parSize / 1024.0 / 1024.0);
+
+            // === Read Phase: Load and verify both indices ===
+            System.out.println("\n=== Testing Read Correctness ===");
+            System.out.println("Loading sequential index...");
+            OnDiskGraphIndex sequentialIndex = OnDiskGraphIndex.load(ReaderSupplierFactory.open(sequentialPath));
+            System.out.println("Loading parallel index...");
+            OnDiskGraphIndex parallelIndex = OnDiskGraphIndex.load(ReaderSupplierFactory.open(parallelPath));
+
+            // Verify that both indices are identical
+            verifyIndicesIdentical(sequentialIndex, parallelIndex);
+
+            // Close the loaded indices
+            sequentialIndex.close();
+            parallelIndex.close();
+
+        } finally {
+            // Cleanup
+            builder.close();
+            Files.deleteIfExists(sequentialPath);
+            Files.deleteIfExists(parallelPath);
+            Files.deleteIfExists(tempDir);
+        }
+
+        System.out.println("\n✅ Test complete - sequential and parallel writes produce identical results!");
+    }
+}
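One caveat in the cleanup path: Files.deleteIfExists(tempDir) only succeeds because the two known output files are deleted first; if a future change leaves extra artifacts in the directory, the delete fails with DirectoryNotEmptyException. A best-effort recursive cleanup sketch using only JDK calls (the helper name is illustrative):

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Comparator;

class TempCleanup {
    // Delete a directory tree bottom-up (children before parents).
    static void deleteRecursively(Path root) throws IOException {
        try (var paths = Files.walk(root)) {
            paths.sorted(Comparator.reverseOrder())
                 .forEach(p -> {
                     try {
                         Files.deleteIfExists(p);
                     } catch (IOException e) {
                         throw new UncheckedIOException(e);
                     }
                 });
        }
    }
}
```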