Skip to content

Commit 1762bd0

Browse files
committed
fead: Implement In-Tree Embedding Similarity Matching
Signed-off-by: Sophie8 <[email protected]>
1 parent f66c341 commit 1762bd0

File tree

1 file changed

+158
-0
lines changed

1 file changed

+158
-0
lines changed

src/semantic-router/pkg/api/server.go

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,44 @@ type BatchSimilarityRequest struct {
162162
LatencyPriority float32 `json:"latency_priority,omitempty"` // 0.0-1.0, only for "auto" model
163163
}
164164

165+
// KeywordSimilarityMatch represents a single match in batch similarity matching
166+
type KeywordSimilarityMatch struct {
167+
Index int `json:"index"` // Index of the candidate in the input array
168+
SimilarityThreshold float32 `json:"similarity_threshold"` // threshold as matched for the cosine similarity score between query and the keyword
169+
Keyword string `json:"text"` // The keyword to calculate similarity with
170+
}
171+
172+
// KeywordSimilarityMatchrResponse represents a single match in batch similarity matching
173+
type KeywordSimilarityMatchResponse struct {
174+
Index int `json:"index"` // Index of the candidate in the input array
175+
SimilarityThreshold float32 `json:"similarity_threshold"` // threshold as matched for the cosine similarity score between query and the keyword
176+
SimilarityCalculated float32 `json:"similarity_calculated"` // threshold as matched for the cosine similarity score between query and the keyword
177+
Keyword string `json:"text"` // The keyword to calculate similarity with
178+
Matched bool `json:"matched"` // The query matched the keyword or not
179+
ModelUsed string `json:"model_used"` // "qwen3", "gemma", or "unknown"
180+
ProcessingTimeMs float32 `json:"processing_time_ms"` // Processing time in milliseconds
181+
}
182+
183+
// BatchEmbeddingSimilarityMatchRequest represents a request to find the similarity between a query and configurable keywords
184+
type BatchEmbeddingSimilarityMatchRequest struct {
185+
Query string `json:"query"` // Query text
186+
Keywords []string `json:"keywords"` // Array of keyword texts
187+
Model string `json:"model,omitempty"` // "auto" (default), "qwen3", "gemma"
188+
Dimension int `json:"dimension,omitempty"` // Target dimension: 768 (default), 512, 256, 128
189+
SimilarityThresholds []KeywordSimilarityMatch `json:"similarity_thresholds"` // Configurable thresholds per keyword (e.g. keyword A: 80%, keyword B: 60%)
190+
AggregationMethod string `json:"aggregation_method"` // Aggregation method to pick the best matched category, support max now. Placeholder for further extension
191+
QualityPriority float32 `json:"quality_priority,omitempty"` // 0.0-1.0, only for "auto" model
192+
LatencyPriority float32 `json:"latency_priority,omitempty"` // 0.0-1.0, only for "auto" model
193+
}
194+
195+
// BatchEmbeddingSimilarityMatchResponse represents a response to find the similarity between a query and configurable keywords
196+
type BatchEmbeddingSimilarityMatchResponse struct {
197+
Query string `json:"query"` // Query text
198+
KeywordMatches []KeywordSimilarityMatchResponse `json:"keyword_matches"` // Array of KeywordSimilarityMatchResponse
199+
AggregationMethod string `json:"aggregation_method"` // Aggregation method to pick the best matched category, support max now. Placeholder for further extension
200+
BestMatchedCategory string `json:"best_matched_category"` // The best matched category based on the aggregation method above
201+
}
202+
165203
// BatchSimilarityMatch represents a single match in batch similarity matching
166204
type BatchSimilarityMatch struct {
167205
Index int `json:"index"` // Index of the candidate in the input array
@@ -691,6 +729,12 @@ func (s *ClassificationAPIServer) handleCombinedClassification(w http.ResponseWr
691729
s.writeErrorResponse(w, http.StatusNotImplemented, "NOT_IMPLEMENTED", "Combined classification not implemented yet")
692730
}
693731

732+
// Placeholder funtion to fusion all the designed internal signal providers: Keyword matcher, reges scanner, embedding similarity, BERT classifier
733+
// func (s *ClassificationAPIServer) handleAllInTreeSinganlProviders(w http.ResponseWriter, _ *http.Request) {
734+
// response, err := s.classificationSvc.handleBatchClassification(..)
735+
736+
// }
737+
694738
func (s *ClassificationAPIServer) handleBatchClassification(w http.ResponseWriter, r *http.Request) {
695739
// Record batch classification request
696740
metrics.RecordBatchClassificationRequest("unified")
@@ -1602,3 +1646,117 @@ func (s *ClassificationAPIServer) handleBatchSimilarity(w http.ResponseWriter, r
16021646

16031647
s.writeJSONResponse(w, http.StatusOK, response)
16041648
}
1649+
1650+
// handleBatchInTreeSimilarityMatching handles batch embedding based similarity matching requests
1651+
func (s *ClassificationAPIServer) handleBatchInTreeSimilarityMatching(w http.ResponseWriter, r *http.Request) {
1652+
// Parse request
1653+
var req BatchEmbeddingSimilarityMatchRequest
1654+
if err := s.parseJSONRequest(r, &req); err != nil {
1655+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", err.Error())
1656+
return
1657+
}
1658+
1659+
// Validate input
1660+
if req.Query == "" {
1661+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", "query must be provided")
1662+
return
1663+
}
1664+
if len(req.Keywords) == 0 {
1665+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", "keyword array cannot be empty")
1666+
return
1667+
}
1668+
1669+
// Set defaults
1670+
if req.Model == "" {
1671+
req.Model = "auto"
1672+
}
1673+
if req.Dimension == 0 {
1674+
req.Dimension = 768 // Default to full dimension
1675+
}
1676+
if req.Model == "auto" && req.QualityPriority == 0 && req.LatencyPriority == 0 {
1677+
req.QualityPriority = 0.5
1678+
req.LatencyPriority = 0.5
1679+
}
1680+
1681+
// Validate dimension
1682+
validDimensions := map[int]bool{128: true, 256: true, 512: true, 768: true, 1024: true}
1683+
if !validDimensions[req.Dimension] {
1684+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_DIMENSION",
1685+
fmt.Sprintf("dimension must be one of: 128, 256, 512, 768, 1024 (got %d)", req.Dimension))
1686+
return
1687+
}
1688+
1689+
// Calculate batch similarity
1690+
result, err := candle_binding.CalculateSimilarityBatch(
1691+
req.Query,
1692+
req.Keywords,
1693+
0, // return scores for all the keywords
1694+
req.Model,
1695+
req.Dimension,
1696+
)
1697+
if err != nil {
1698+
s.writeErrorResponse(w, http.StatusInternalServerError, "BATCH_SIMILARITY_FAILED",
1699+
fmt.Sprintf("failed to calculate batch similarity: %v", err))
1700+
return
1701+
}
1702+
1703+
// Build embedding based similarity response
1704+
matches := make([]KeywordSimilarityMatchResponse, len(result.Matches))
1705+
for i, match := range result.Matches {
1706+
if match.Similarity >= req.SimilarityThresholds[i].SimilarityThreshold {
1707+
matches[i] = KeywordSimilarityMatchResponse{
1708+
Index: match.Index,
1709+
SimilarityThreshold: req.SimilarityThresholds[i].SimilarityThreshold,
1710+
SimilarityCalculated: match.Similarity,
1711+
Keyword: req.Keywords[match.Index],
1712+
Matched: true,
1713+
ModelUsed: result.ModelType,
1714+
ProcessingTimeMs: result.ProcessingTimeMs,
1715+
}
1716+
} else {
1717+
matches[i] = KeywordSimilarityMatchResponse{
1718+
Index: match.Index,
1719+
SimilarityThreshold: req.SimilarityThresholds[i].SimilarityThreshold,
1720+
SimilarityCalculated: match.Similarity,
1721+
Keyword: req.Keywords[match.Index],
1722+
Matched: false,
1723+
ModelUsed: result.ModelType,
1724+
ProcessingTimeMs: result.ProcessingTimeMs,
1725+
}
1726+
}
1727+
}
1728+
// Validate input
1729+
if req.AggregationMethod != "" && req.AggregationMethod != "max" {
1730+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", "Aggregation method only supports max now")
1731+
return
1732+
}
1733+
var aggregationMethod string
1734+
// Set default value
1735+
if req.AggregationMethod == "" {
1736+
aggregationMethod = "max"
1737+
}
1738+
// Support mean/max/any aggregation methods to find the best match
1739+
var bestMatchedCategory string
1740+
var bestScore float32
1741+
if aggregationMethod == "max" {
1742+
// pick the most matched category based on max of all cosine similarity scores
1743+
for _, match := range matches {
1744+
if match.SimilarityCalculated > bestScore {
1745+
bestScore = match.SimilarityCalculated
1746+
bestMatchedCategory = match.Keyword
1747+
}
1748+
}
1749+
}
1750+
// Make Response
1751+
response := BatchEmbeddingSimilarityMatchResponse{
1752+
Query: req.Query,
1753+
KeywordMatches: matches,
1754+
AggregationMethod: aggregationMethod,
1755+
BestMatchedCategory: bestMatchedCategory,
1756+
}
1757+
1758+
observability.Infof("Calculated batch embedding similarity: query='%s', %d keywords, (model: %s, took: %.2fms)",
1759+
req.Query, len(req.Keywords), result.ModelType, result.ProcessingTimeMs)
1760+
1761+
s.writeJSONResponse(w, http.StatusOK, response)
1762+
}

0 commit comments

Comments
 (0)