Skip to content

Commit 4d22016

Browse files
authored
Fix binary vector types and conversion in $vectorSearch stage (#2871)
1 parent 7c5dbc1 commit 4d22016

File tree

4 files changed

+63
-38
lines changed

4 files changed

+63
-38
lines changed

lib/Doctrine/ODM/MongoDB/Aggregation/Stage/VectorSearch.php

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
use Doctrine\ODM\MongoDB\Aggregation\Stage;
99
use Doctrine\ODM\MongoDB\Persisters\DocumentPersister;
1010
use Doctrine\ODM\MongoDB\Query\Expr;
11+
use Doctrine\ODM\MongoDB\Types\Type;
1112
use InvalidArgumentException;
1213
use MongoDB\BSON\Binary;
1314
use MongoDB\BSON\Decimal128;
@@ -81,7 +82,7 @@ public function getExpression(): array
8182
}
8283

8384
if ($this->queryVector !== null) {
84-
$params['queryVector'] = $this->queryVector;
85+
$params['queryVector'] = Type::getType($this->persister->getClassMetadata()->fieldMappings[$this->path]['type'] ?? Type::RAW)->convertToDatabaseValue($this->queryVector);
8586
}
8687

8788
return [$this->getStageName() => $params];

lib/Doctrine/ODM/MongoDB/Types/AbstractVectorType.php

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -77,27 +77,17 @@ public function closureToMongo(): string
7777
return str_replace('%%vectorType%%', $this->getVectorType()->name, <<<'PHP'
7878
if ($value === null) {
7979
$return = null;
80-
return;
81-
}
82-
83-
if (\is_array($value)) {
80+
} elseif (\is_array($value)) {
8481
$return = \MongoDB\BSON\Binary::fromVector($value, \MongoDB\BSON\VectorType::%%vectorType%%);
85-
return;
86-
}
87-
88-
if (! $value instanceof \MongoDB\BSON\Binary) {
82+
} elseif (! $value instanceof \MongoDB\BSON\Binary) {
8983
throw new InvalidArgumentException(sprintf('Invalid data type %s received for vector field, expected null, array or MongoDB\BSON\Binary', get_debug_type($value)));
90-
}
91-
92-
if ($value->getType() !== \MongoDB\BSON\Binary::TYPE_VECTOR) {
84+
} elseif ($value->getType() !== \MongoDB\BSON\Binary::TYPE_VECTOR) {
9385
throw new InvalidArgumentException(sprintf('Invalid binary data of type %d received for vector field, expected binary type %d', $value->getType(), \MongoDB\BSON\Binary::TYPE_VECTOR));
94-
}
95-
96-
if ($value->getVectorType() !== \MongoDB\BSON\VectorType::%%vectorType%%) {
86+
} elseif ($value->getVectorType() !== \MongoDB\BSON\VectorType::%%vectorType%%) {
9787
throw new \InvalidArgumentException(sprintf('Invalid binary vector data of vector type %s received for vector field, expected vector type %%vectorType%%', $value->getVectorType()->name));
88+
} else {
89+
$return = $value;
9890
}
99-
100-
$return = $value;
10191
PHP);
10292
}
10393

@@ -106,27 +96,17 @@ public function closureToPHP(): string
10696
return str_replace('%%vectorType%%', $this->getVectorType()->name, <<<'PHP'
10797
if ($value === null) {
10898
$return = null;
109-
return;
110-
}
111-
112-
if (\is_array($value)) {
99+
} elseif (\is_array($value)) {
113100
$return = $value;
114-
return;
115-
}
116-
117-
if (! $value instanceof \MongoDB\BSON\Binary) {
101+
} elseif (! $value instanceof \MongoDB\BSON\Binary) {
118102
throw new \InvalidArgumentException(sprintf('Invalid data of type "%s" received for vector field', get_debug_type($value)));
119-
}
120-
121-
if ($value->getType() !== \MongoDB\BSON\Binary::TYPE_VECTOR) {
103+
} elseif ($value->getType() !== \MongoDB\BSON\Binary::TYPE_VECTOR) {
122104
throw new \InvalidArgumentException(sprintf('Invalid binary data of type %d received for vector field', $value->getType()));
123-
}
124-
125-
if ($value->getVectorType() !== \MongoDB\BSON\VectorType::%%vectorType%%) {
105+
} elseif ($value->getVectorType() !== \MongoDB\BSON\VectorType::%%vectorType%%) {
126106
throw new \InvalidArgumentException(sprintf('Invalid binary vector data of vector type %s received for vector field, expected vector type %%vectorType%%', $value->getVectorType()->name));
107+
} else {
108+
$return = $value->toArray();
127109
}
128-
129-
$return = $value->toArray();
130110
PHP);
131111
}
132112

tests/Doctrine/ODM/MongoDB/Tests/Functional/VectorSearchTest.php

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
namespace Doctrine\ODM\MongoDB\Tests\Functional;
66

77
use Doctrine\ODM\MongoDB\Tests\BaseTestCase;
8+
use Doctrine\ODM\MongoDB\Types\Type;
89
use Documents\VectorEmbedding;
10+
use MongoDB\BSON\Binary;
911
use MongoDB\Driver\WriteConcern;
1012
use PHPUnit\Framework\Attributes\Group;
13+
use PHPUnit\Framework\Attributes\RequiresPhpExtension;
1114

1215
#[Group('atlas')]
1316
class VectorSearchTest extends BaseTestCase
@@ -45,9 +48,21 @@ public function testAtlasVectorSearch(): void
4548
// Index must be created after data insertion, so the index status is not immediately "READY"
4649
$schemaManager->createDocumentSearchIndexes(VectorEmbedding::class);
4750

48-
// Wait for search index to be ready (Atlas Local needs time to build the index)
51+
// Wait for the search index to be ready (Atlas Local needs time to build the index)
4952
$schemaManager->waitForSearchIndexes([VectorEmbedding::class]);
5053

54+
$results = $this->dm->createQueryBuilder(VectorEmbedding::class)->getQuery()->toArray();
55+
$this->assertCount(3, $results, 'All documents should be present in the collection');
56+
57+
foreach ($results as $result) {
58+
$this->assertInstanceOf(VectorEmbedding::class, $result);
59+
60+
$this->assertIsArray($result->vectorFloat);
61+
$this->assertCount(3, $result->vectorFloat);
62+
$this->assertIsArray($result->vectorInt);
63+
$this->assertCount(3, $result->vectorInt);
64+
}
65+
5166
$results = $this->dm->createAggregationBuilder(VectorEmbedding::class)
5267
->vectorSearch()
5368
->index('default')
@@ -68,6 +83,7 @@ public function testAtlasVectorSearch(): void
6883

6984
// Test with filter
7085
$results = ($builder = $this->dm->createAggregationBuilder(VectorEmbedding::class))
86+
->hydrate(VectorEmbedding::class)
7187
->vectorSearch()
7288
->index('vector_int')
7389
->queryVector([1, 1, 3])
@@ -79,8 +95,28 @@ public function testAtlasVectorSearch(): void
7995

8096
$this->assertCount(2, $results);
8197
foreach ($results as $result) {
82-
$this->assertIsArray($result);
83-
$this->assertEquals('active', $result['filterField'], 'Filtered results should only contain active documents');
98+
$this->assertInstanceOf(VectorEmbedding::class, $result);
99+
$this->assertEquals('active', $result->filterField, 'Filtered results should only contain active documents');
84100
}
85101
}
102+
103+
#[RequiresPhpExtension('mongodb', '>= 2.2')]
104+
public function testAtlasVectorSearchWithBinaryType(): void
105+
{
106+
$cm = $this->dm->getClassMetadata(VectorEmbedding::class);
107+
108+
$cm->fieldMappings['vectorFloat']['type'] = Type::VECTOR_FLOAT32;
109+
$cm->fieldMappings['vectorInt']['type'] = Type::VECTOR_INT8;
110+
111+
// Change the collection name to avoid conflicts with asynchronous index building
112+
$cm->collection .= '_binary_type';
113+
114+
$this->testAtlasVectorSearch();
115+
116+
// Ensure that the vectors are stored in as binary vectors
117+
$doc = $this->dm->getDocumentCollection(VectorEmbedding::class)->findOne(['filterField' => 'active']);
118+
$this->assertIsArray($doc);
119+
$this->assertInstanceOf(Binary::class, $doc['vectorInt']);
120+
$this->assertInstanceOf(Binary::class, $doc['db_vector_float']);
121+
}
86122
}

tests/Doctrine/ODM/MongoDB/Tests/Types/VectorTypeTest.php

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212
use PHPUnit\Framework\Attributes\RequiresPhpExtension;
1313
use PHPUnit\Framework\TestCase;
1414

15+
use function get_debug_type;
16+
1517
#[RequiresPhpExtension('mongodb', '>= 2.2')]
1618
class VectorTypeTest extends TestCase
1719
{
1820
#[DataProvider('providePhpVectors')]
1921
public function testConvertToDatabaseValue(string $name, mixed $value, mixed $expectedValue): void
2022
{
21-
$this->assertEquals($expectedValue, Type::getType($name)->convertToDatabaseValue($value));
23+
$this->assertSameTypeAndValue($expectedValue, Type::getType($name)->convertToDatabaseValue($value));
2224
}
2325

2426
#[DataProvider('providePhpVectors')]
@@ -27,7 +29,7 @@ public function testClosureToDatabase(string $name, mixed $value, mixed $expecte
2729
$return = $this;
2830
eval(Type::getType($name)->closureToMongo());
2931

30-
$this->assertEquals($expectedValue, $return);
32+
$this->assertSameTypeAndValue($expectedValue, $return);
3133
}
3234

3335
/** @return iterable<array{0: Type::VECTOR_*, 1: mixed, 2: mixed}> */
@@ -149,4 +151,10 @@ public static function providePHPValueException(): iterable
149151
yield [new Binary("\x03\x00\x01\x02\x03", Binary::TYPE_GENERIC), 'Invalid binary data of type 0 received for vector field'];
150152
yield [Binary::fromVector([1, 2, 3], VectorType::Int8), 'Invalid binary vector data of vector type Int8 received for vector field, expected vector type Float32'];
151153
}
154+
155+
private function assertSameTypeAndValue(mixed $expected, mixed $actual): void
156+
{
157+
$this->assertSame(get_debug_type($expected), get_debug_type($actual));
158+
$this->assertEquals($expected, $actual);
159+
}
152160
}

0 commit comments

Comments
 (0)