diff --git a/changes/6492.feature.md b/changes/6492.feature.md new file mode 100644 index 00000000000..5a8b2d5e288 --- /dev/null +++ b/changes/6492.feature.md @@ -0,0 +1 @@ +Introduce `BaseMinilangFilterConverter` and `BaseMinilangOrderParser` abstract base classes for converting minilang filter expressions and order strings to domain-specific Filter and OrderingOption objects diff --git a/src/ai/backend/common/exception.py b/src/ai/backend/common/exception.py index 0c889b6a6b8..0e2f69f1254 100644 --- a/src/ai/backend/common/exception.py +++ b/src/ai/backend/common/exception.py @@ -790,3 +790,55 @@ def error_code(cls) -> ErrorCode: operation=ErrorOperation.READ, error_detail=ErrorDetail.NOT_FOUND, ) + + +class InvalidParameter(BackendAIError, web.HTTPBadRequest): + error_type = "https://api.backend.ai/probs/invalid-parameter" + error_title = "Invalid Parameter" + + @classmethod + def error_code(cls) -> ErrorCode: + return ErrorCode( + domain=ErrorDomain.BACKENDAI, + operation=ErrorOperation.REQUEST, + error_detail=ErrorDetail.BAD_REQUEST, + ) + + +class ASTParsingFailed(BackendAIError, web.HTTPBadRequest): + error_type = "https://api.backend.ai/probs/ast-parsing-failed" + error_title = "AST Parsing Failed" + + @classmethod + def error_code(cls) -> ErrorCode: + return ErrorCode( + domain=ErrorDomain.BACKENDAI, + operation=ErrorOperation.PARSING, + error_detail=ErrorDetail.BAD_REQUEST, + ) + + +class UnsupportedOperation(BackendAIError, web.HTTPBadRequest): + error_type = "https://api.backend.ai/probs/unsupported-operation" + error_title = "Unsupported Operation" + + @classmethod + def error_code(cls) -> ErrorCode: + return ErrorCode( + domain=ErrorDomain.BACKENDAI, + operation=ErrorOperation.REQUEST, + error_detail=ErrorDetail.BAD_REQUEST, + ) + + +class UnsupportedFieldType(BackendAIError, web.HTTPBadRequest): + error_type = "https://api.backend.ai/probs/unsupported-field-type" + error_title = "Unsupported Field Type" + + @classmethod + def error_code(cls) -> ErrorCode: + return ErrorCode( + domain=ErrorDomain.BACKENDAI, + operation=ErrorOperation.REQUEST, + error_detail=ErrorDetail.BAD_REQUEST, + ) diff --git a/src/ai/backend/manager/api/types.py b/src/ai/backend/manager/api/types.py index 0a33b8acabf..b3a945ffcb5 100644 --- a/src/ai/backend/manager/api/types.py +++ b/src/ai/backend/manager/api/types.py @@ -1,20 +1,32 @@ from __future__ import annotations +from abc import ABC, abstractmethod +from enum import StrEnum from typing import ( TYPE_CHECKING, + Any, AsyncContextManager, Awaitable, Callable, + Generic, Iterable, Mapping, + Self, Tuple, + TypeVar, ) import aiohttp_cors from aiohttp import web from aiohttp.typedefs import Middleware +from lark import Tree, UnexpectedCharacters, UnexpectedToken +from lark.lexer import Token from typing_extensions import TypeAlias +from ai.backend.common.exception import ASTParsingFailed, InvalidParameter, UnsupportedOperation +from ai.backend.manager.api.gql.base import StringFilter +from ai.backend.manager.models.minilang.queryfilter import _parser as parser + if TYPE_CHECKING: from .context import RootContext @@ -32,3 +44,385 @@ ] CleanupContext: TypeAlias = Callable[["RootContext"], AsyncContextManager[None]] + + +class BaseMinilangFilterConverter(ABC): + """ + Abstract base class for converting minilang filter expressions to domain Filter objects. + + This class converts string-based filter expressions (minilang) into structured Filter + objects by parsing the AST and mapping fields to domain-specific filter types. + + Domain-specific Filter classes MUST inherit from this and implement + the _create_from_condition() method. + + Conversion flow: + 1. Parse minilang string expression → AST (Lark Tree) + 2. Convert AST → Filter object structure with field mappings + 3. Support logical operations (AND, OR) for complex filters + + Example: + @dataclass + class AgentFilter(BaseMinilangFilterConverter): + id: Optional[StringFilter] = None + status: Optional[AgentStatusFilter] = None + + AND: Optional[list[Self]] = None + OR: Optional[list[Self]] = None + NOT: Optional[list[Self]] = None + + @classmethod + def from_minilang(cls, expr: str) -> AgentFilter: + '''Convert minilang expression like "id ilike '%abc%' & status == 'ALIVE'" to AgentFilter''' + from ai.backend.manager.models.minilang.queryfilter import _parser as parser + ast = parser.parse(expr) + return cls._from_ast(ast) + + @classmethod + def _create_from_condition(cls, field: str, operator: str, value: Any) -> Self: + match field.lower(): + case "id": + return cls(id=cls._create_string_filter(operator, value)) + case _: + raise ValueError(f"Unsupported filter field: {field}") + """ + + @staticmethod + def _parse_value(value_tree: Tree | Token | list | Any) -> Any: + """ + Extract the actual value from the AST value node. + + Args: + value_tree: AST node representing a value (string, number, atom, array) + + Returns: + Parsed Python value (str, int, float, bool, None, or list) + """ + if isinstance(value_tree, Token): + # Direct token (e.g., string, number) + token_value: str = str(value_tree.value) + match value_tree.type: + case "ESCAPED_STRING": + # Remove quotes and unescape + return token_value[1:-1].replace('\\"', '"') + case "SIGNED_NUMBER": + # Parse number + if "." in token_value: + return float(token_value) + return int(token_value) + case "ATOM": + # Atom values like null, true, false + match token_value: + case "null": + return None + case "true": + return True + case "false": + return False + case _: + return value_tree.value + case _: + raise ASTParsingFailed(f"Unknown token type in value: {value_tree.type}") + + if isinstance(value_tree, Tree): + match value_tree.data: + case "string": + # String value + token = value_tree.children[0] + if isinstance(token, Token): + token_value = str(token.value) + return token_value[1:-1].replace('\\"', '"') + case "number": + # Number value + token = value_tree.children[0] + if isinstance(token, Token): + token_value = str(token.value) + if "." in token_value: + return float(token_value) + return int(token_value) + case "atom": + # Atomic value + token = value_tree.children[0] + if isinstance(token, Token): + token_value = str(token.value) + match token_value: + case "null": + return None + case "true": + return True + case "false": + return False + case "array": + # Array value + return [ + BaseMinilangFilterConverter._parse_value(child) + for child in value_tree.children + ] + case "value": + # Nested value wrapper + return BaseMinilangFilterConverter._parse_value(value_tree.children[0]) + case _: + raise ASTParsingFailed(f"Unknown tree type in value: {value_tree.data}") + + if isinstance(value_tree, list): + # Array of values + return [BaseMinilangFilterConverter._parse_value(v) for v in value_tree] + + return value_tree + + @staticmethod + def _create_string_filter(operator: str, value: Any) -> StringFilter: + """ + Create StringFilter from operator and value. + + Args: + operator: Operator string (e.g., "==", "!=", "like", "ilike", "contains") + value: String value to filter + + Returns: + StringFilter with appropriate filter type + """ + str_value = str(value) + + match operator: + case "==": + return StringFilter(equals=str_value) + case "!=": + return StringFilter(not_equals=str_value) + case "contains": + return StringFilter(contains=str_value) + case "like": + # Parse LIKE pattern: %value% -> contains, value% -> starts_with, %value -> ends_with + if str_value.startswith("%") and str_value.endswith("%"): + return StringFilter(contains=str_value[1:-1]) + if str_value.endswith("%"): + return StringFilter(starts_with=str_value[:-1]) + if str_value.startswith("%"): + return StringFilter(ends_with=str_value[1:]) + return StringFilter(equals=str_value) + case "ilike": + # Parse ILIKE pattern (case-insensitive) + if str_value.startswith("%") and str_value.endswith("%"): + return StringFilter(i_contains=str_value[1:-1]) + if str_value.endswith("%"): + return StringFilter(i_starts_with=str_value[:-1]) + if str_value.startswith("%"): + return StringFilter(i_ends_with=str_value[1:]) + return StringFilter(i_equals=str_value) + case _: + raise UnsupportedOperation(f"Unsupported string operator: {operator}") + + @classmethod + @abstractmethod + def _create_from_condition(cls, field: str, operator: str, value: Any) -> Self: + """ + Create Filter instance with single condition from field, operator, and value. + + Args: + field: Field name (e.g., "id", "status", "region") + operator: Operator string (e.g., "==", "!=", "ilike") + value: Filter value + + Returns: + Filter instance of the calling class type + + Raises: + ValueError: If field name is not supported + + Note: + Subclasses MUST implement this method to map field names + to their specific filter structure. + """ + raise NotImplementedError() + + @classmethod + def from_minilang(cls, expr: str) -> Self: + """ + Convert minilang expression string to Filter object. + """ + try: + ast = parser.parse(expr) + except UnexpectedToken as e: + raise ASTParsingFailed(f"Failed to parse minilang expression: {e}") + except UnexpectedCharacters as e: + raise UnsupportedOperation(f"Failed to parse minilang expression: {e}") + return cls._from_ast(ast) + + @classmethod + def _from_ast(cls, ast: Tree) -> Self: + """ + Convert Lark AST to Filter object recursively. + + This is a common implementation that all domain filters can use. + Subclasses must implement _create_from_condition(field, operator, value) + and have AND/OR/NOT fields for logical operations. + + Args: + ast: Parsed AST tree from minilang parser + + Returns: + Filter object of the calling class type + + Note: + The calling class must have: + - _create_from_condition(field: str, operator: str, value: Any) classmethod (abstract) + - AND: Optional[list[Self]] field + - OR: Optional[list[Self]] field + - NOT: Optional[list[Self]] field (optional) + """ + match ast.data: + case "binary_expr": + # Single filter expression: field operator value + field_token = ast.children[0] + operator_token = ast.children[1] + if not isinstance(field_token, Token) or not isinstance(operator_token, Token): + raise ASTParsingFailed("Invalid AST structure for binary expression") + + field_name = str(field_token.value) + operator = str(operator_token.value) + value = cls._parse_value(ast.children[2]) + + # Call domain-specific field mapping + return cls._create_from_condition(field_name, operator, value) + + case "combine_expr": + # Combined expression: expr & expr or expr | expr + left_tree = ast.children[0] + combine_op_token = ast.children[1] + right_tree = ast.children[2] + + if ( + not isinstance(left_tree, Tree) + or not isinstance(combine_op_token, Token) + or not isinstance(right_tree, Tree) + ): + raise ASTParsingFailed("Invalid AST structure for combine expression") + + left_filter = cls._from_ast(left_tree) + right_filter = cls._from_ast(right_tree) + combine_op = str(combine_op_token.value) + + match combine_op: + case "&": + return cls(AND=[left_filter, right_filter]) # type: ignore[call-arg] + case "|": + return cls(OR=[left_filter, right_filter]) # type: ignore[call-arg] + case _: + raise ASTParsingFailed(f"Unknown combine operator: {combine_op}") + + case "unary_expr": + raise ASTParsingFailed("NOT operator (!) is not supported in this implementation") + + case "paren_expr": + # Parenthesized expression: ( expr ) + return cls._from_ast(ast.children[0]) + + case _: + raise ASTParsingFailed(f"Unknown AST node type: {ast.data}") + + +TOrderField = TypeVar("TOrderField", bound=StrEnum) +TOrderingOptions = TypeVar("TOrderingOptions") + + +class BaseMinilangOrderConverter(ABC, Generic[TOrderField, TOrderingOptions]): + """ + Abstract base class for parsing order expressions into OrderingOption objects. + + This class converts string-based order expressions (e.g., "+id,-status") into + OrderingOption objects by parsing the field names, directions, and mapping + to domain-specific field enums. + + Conversion flow: + 1. Parse order string expression → dict of field names to ascending bool + 2. Convert field names → domain-specific field enums (e.g., AgentOrderField) + 3. Create OrderingOption with list of (field, desc) tuples + + Type Parameters: + TOrderField: Enum type for order field names (domain-specific) + TOrderingOptions: Domain-specific OrderingOption type (e.g., AgentOrderingOptions) + + Example: + class AgentOrderParser(BaseOrderParser[AgentOrderField, AgentOrderingOptions]): + + def _convert_field(self, parsed_expr: dict[str, bool]) -> dict[AgentOrderField, bool]: + return {AgentOrderField(name): asc for name, asc in parsed_expr.items()} + + def _create_ordering_option(self, order_by: dict[AgentOrderField, bool]) -> AgentOrderingOptions: + # Convert to list of (field, desc) tuples where desc = not ascending + order_list = [(field, not asc) for field, asc in order_by.items()] + return AgentOrderingOptions(order_by=order_list) + """ + + def _parse_raw_expr(self, order_expr: str) -> dict[str, bool]: + """ + Parse raw order expression string to dict of field name to ascending bool. + Args: + order_expr: Order expression string (e.g., "+id,-status") + Returns: + Dict mapping raw field names to ascending bool (True=asc, False=desc) + """ + + expr_list = [expr.strip() for expr in order_expr.split(",") if expr.strip()] + if not expr_list: + raise InvalidParameter("Order expression cannot be empty") + result = {} + for expr in expr_list: + # Check for prefix + if expr and expr[0] in ("+", "-"): + ascending = expr[0] == "+" + field_name = expr[1:].strip() + else: + # Default to ascending if no prefix + ascending = True + field_name = expr + if field_name == "": + raise InvalidParameter("Field name in order expression cannot be empty") + result[field_name] = ascending + return result + + @abstractmethod + def _convert_field(self, parsed_expr: dict[str, bool]) -> dict[TOrderField, bool]: + """ + Convert raw field names to domain-specific field enums. + + Args: + parsed_expr: Dict mapping raw field names to ascending bool + + Returns: + Dict mapping domain-specific field enums to ascending bool + """ + raise NotImplementedError() + + @abstractmethod + def _create_ordering_option(self, order_by: dict[TOrderField, bool]) -> TOrderingOptions: + """ + Create domain-specific OrderingOption object. + Args: + order_by: Dict mapping domain-specific field enums to ascending bool + Returns: + OrderingOption object with list of (field, desc) tuples + """ + raise NotImplementedError() + + def from_minilang(self, order_expr: str) -> TOrderingOptions: + """ + Parse order expression string from minilang to OrderingOption object. + + Format: "+field" or "-field" + - '+' or no prefix means ascending + - '-' means descending + + Args: + order_expr: Order expression string (e.g., "+id", "-status") + + Returns: + OrderingOption object + + Raises: + ValueError: If expression is invalid or empty + """ + raw_orders = self._parse_raw_expr(order_expr) + converted_orders = self._convert_field(raw_orders) + + return self._create_ordering_option(order_by=converted_orders) diff --git a/tests/manager/api/test_types.py b/tests/manager/api/test_types.py new file mode 100644 index 00000000000..90f665b4e1b --- /dev/null +++ b/tests/manager/api/test_types.py @@ -0,0 +1,386 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum +from typing import Any, Optional, Self, override + +import pytest + +from ai.backend.common.exception import ( + ASTParsingFailed, + InvalidParameter, + UnsupportedFieldType, + UnsupportedOperation, +) +from ai.backend.manager.api.gql.base import StringFilter +from ai.backend.manager.api.types import BaseMinilangFilterConverter, BaseMinilangOrderConverter + + +class TestBaseMinilangFilterConverter: + @pytest.fixture + def sample_filter_class(self) -> Any: + """Create a concrete implementation of BaseMinilangFilterConverter for testing.""" + + @dataclass + class SampleFilter(BaseMinilangFilterConverter): + id: Optional[StringFilter] = None + name: Optional[StringFilter] = None + email: Optional[StringFilter] = None + status: Optional[StringFilter] = None + active: Optional[StringFilter] = None + role: Optional[StringFilter] = None + department: Optional[StringFilter] = None + + AND: Optional[list[Self]] = None + OR: Optional[list[Self]] = None + NOT: Optional[list[Self]] = None + + @override + @classmethod + def _create_from_condition(cls, field: str, operator: str, value: Any) -> Self: + match field.lower(): + case "id": + return cls(id=cls._create_string_filter(operator, value)) + case "name": + return cls(name=cls._create_string_filter(operator, value)) + case "email": + return cls(email=cls._create_string_filter(operator, value)) + case "status": + return cls(status=cls._create_string_filter(operator, value)) + case "active": + return cls(active=cls._create_string_filter(operator, value)) + case "role": + return cls(role=cls._create_string_filter(operator, value)) + case "department": + return cls(department=cls._create_string_filter(operator, value)) + case _: + raise UnsupportedFieldType(f"Unsupported filter field: {field}") + + return SampleFilter + + @pytest.mark.parametrize( + ("query", "field", "attribute", "expected"), + [ + ('id == "user123"', "id", "equals", "user123"), + ('status != "inactive"', "status", "not_equals", "inactive"), + ('email == "user+test@example.com"', "email", "equals", "user+test@example.com"), + ('name ilike "%john%"', "name", "i_contains", "john"), + ('name like "%Test%"', "name", "contains", "Test"), + ('email like "%@Company.com"', "email", "ends_with", "@Company.com"), + ('name ilike "%래블업%"', "name", "i_contains", "래블업"), + ('name like "백엔드닷에이아이%"', "name", "starts_with", "백엔드닷에이아이"), + ("active == true", "active", "equals", "True"), + ("active == false", "active", "equals", "False"), + ("status == null", "status", "equals", "None"), + ('name like "test"', "name", "equals", "test"), # No wildcards + ('name like "%%"', "name", "contains", ""), # Both wildcards with empty string + ('name == ""', "name", "equals", ""), # Empty string + ('name == "test\\"name"', "name", "equals", 'test"name'), # Escaped quotes + ], + ) + def test_basic_operators( + self, sample_filter_class: Any, query: str, field: str, attribute: str, expected: str + ) -> None: + result = sample_filter_class.from_minilang(query) + + # Verify the specific field filter + field_filter = getattr(result, field) + assert field_filter is not None + + # Verify the expected attribute and value is set correctly + assert getattr(field_filter, attribute) == expected + + @pytest.mark.parametrize( + ("query", "expected_conditions"), + [ + ( + '(role == "admin") & (department == "Engineering")', + [("role", "equals", "admin"), ("department", "equals", "Engineering")], + ), + ( + '(email ilike "%@company.com%") & (status == "active") & (name ilike "%admin%")', + [ + ("email", "i_contains", "@company.com"), + ("status", "equals", "active"), + ("name", "i_contains", "admin"), + ], + ), + ( + '(name like "Admin%") & (role == "superuser") & (status != "active") & (department == "devops")', + [ + ("name", "starts_with", "Admin"), + ("role", "equals", "superuser"), + ("status", "not_equals", "active"), + ("department", "equals", "devops"), + ], + ), + ], + ) + def test_multiple_filters_with_and( + self, + sample_filter_class: Any, + query: str, + expected_conditions: list[tuple[str, str, str]], + ) -> None: + result = sample_filter_class.from_minilang(query) + + # Verify that result has AND structure + assert result.AND is not None + + # Recursively check all AND filters + def check_filter(f: Any, field: str, attribute: str, expected: str) -> bool: + # Check direct field + field_filter = getattr(f, field, None) + if (field_filter is not None) and (getattr(field_filter, attribute, None) == expected): + return True + + # Check nested AND recursively + if f.AND is not None: + for nested in f.AND: + if check_filter(nested, field, attribute, expected): + return True + + return False + + # Verify each expected condition + for field, attribute, expected in expected_conditions: + assert check_filter(result, field, attribute, expected), ( + f"Expected {field}.{attribute}={expected} not found in AND filters" + ) + + @pytest.mark.parametrize( + ("query", "expected_conditions"), + [ + ( + '(status == "invited") | (status != "registered")', + [("status", "equals", "invited"), ("status", "not_equals", "registered")], + ), + ( + '(department == "Sales") | (department ilike "Market%")', + [("department", "equals", "Sales"), ("department", "i_starts_with", "Market")], + ), + ( + '(status == "active") | (status != "pending") | (status == "invited")', + [ + ("status", "equals", "active"), + ("status", "not_equals", "pending"), + ("status", "equals", "invited"), + ], + ), + ( + '(role == "admin") | (role == "superadmin") | (role == "user") | (role == "monitor")', + [ + ("role", "equals", "admin"), + ("role", "equals", "superadmin"), + ("role", "equals", "user"), + ("role", "equals", "monitor"), + ], + ), + ], + ) + def test_multiple_filters_with_or( + self, + sample_filter_class: Any, + query: str, + expected_conditions: list[tuple[str, str, str]], + ) -> None: + result = sample_filter_class.from_minilang(query) + + # Verify that result has OR structure + assert result.OR is not None + + # Recursively check all OR filters + def check_filter(f: Any, field: str, attribute: str, expected: str) -> bool: + # Check direct field + field_filter = getattr(f, field, None) + if field_filter is not None and getattr(field_filter, attribute, None) == expected: + return True + + # Check nested OR recursively + if f.OR is not None: + for nested in f.OR: + if check_filter(nested, field, attribute, expected): + return True + + return False + + # Verify each expected condition + for field, attribute, expected in expected_conditions: + assert check_filter(result, field, attribute, expected), ( + f"Expected {field}.{attribute}={expected} not found in OR filters" + ) + + @pytest.mark.parametrize( + ("query", "and_conditions", "or_conditions"), + [ + ( + '(name ilike "%admin%") & ((status == "active") | (status != "pending"))', + [("name", "i_contains", "admin")], + [("status", "equals", "active"), ("status", "not_equals", "pending")], + ), + ( + '(email ilike "%@company.com") & ((role == "admin") | (role != "manager"))', + [("email", "i_ends_with", "@company.com")], + [("role", "equals", "admin"), ("role", "not_equals", "manager")], + ), + ( + '(active == true) & ((department ilike "%Sales%") | (department == "Marketing"))', + [("active", "equals", "True")], + [("department", "i_contains", "Sales"), ("department", "equals", "Marketing")], + ), + ( + '(id == "user123") & ((status == "active") | (status == "invited"))', + [("id", "equals", "user123")], + [("status", "equals", "active"), ("status", "equals", "invited")], + ), + ], + ) + def test_mixed_and_or_filters( + self, + sample_filter_class: Any, + query: str, + and_conditions: list[tuple[str, str, str]], + or_conditions: list[tuple[str, str, str]], + ) -> None: + result = sample_filter_class.from_minilang(query) + + # Recursively check all filters (both AND and OR chains) + def check_filter(f: Any, field: str, attribute: str, expected: str) -> bool: + # Check direct field + field_filter = getattr(f, field, None) + if field_filter is not None and getattr(field_filter, attribute, None) == expected: + return True + + # Check nested AND + if f.AND is not None: + for nested in f.AND: + if check_filter(nested, field, attribute, expected): + return True + + # Check nested OR + if f.OR is not None: + for nested in f.OR: + if check_filter(nested, field, attribute, expected): + return True + + return False + + # Verify AND conditions + for field, attribute, expected in and_conditions: + assert check_filter(result, field, attribute, expected), ( + f"AND condition {field}.{attribute}={expected} not found" + ) + + # Verify OR conditions + for field, attribute, expected in or_conditions: + assert check_filter(result, field, attribute, expected), ( + f"OR condition {field}.{attribute}={expected} not found" + ) + + @pytest.mark.parametrize( + ("query", "expected_exception"), + [ + ('invalid_field == "test"', UnsupportedFieldType), + ("", ASTParsingFailed), # Empty string + (" ", ASTParsingFailed), # Whitespace only + ("name == ", ASTParsingFailed), # Missing value + ('name "test"', ASTParsingFailed), # Missing operator + ('== "test"', ASTParsingFailed), # Missing field + ('(name == "test"', ASTParsingFailed), # Unmatched left parenthesis + ('name == "test")', ASTParsingFailed), # Unmatched right parenthesis + ('name === "test"', UnsupportedOperation), # Invalid operator + ( + 'name == "test" & & status == "active"', + ASTParsingFailed, + ), # Consecutive logical operators + ('name == "test" &', ASTParsingFailed), # Trailing logical operator + ('& name == "test"', ASTParsingFailed), # Leading logical operator + ], + ) + def test_malformed_expressions_raise_error( + self, sample_filter_class: Any, query: str, expected_exception: type[Exception] + ) -> None: + with pytest.raises(expected_exception): + sample_filter_class.from_minilang(query) + + +class TestBaseMinilangOrderConverter: + @pytest.fixture + def sample_order_converter(self) -> Any: + class SampleOrderField(StrEnum): + ID = "id" + NAME = "name" + STATUS = "status" + CREATED_AT = "created_at" + UPDATED_AT = "updated_at" + EMAIL = "email" + + @dataclass + class SampleOrderingOptions: + order_by: list[tuple[SampleOrderField, bool]] + + class SampleOrderConverter( + BaseMinilangOrderConverter[SampleOrderField, SampleOrderingOptions] + ): + @override + def _convert_field(self, parsed_expr: dict[str, bool]) -> dict[SampleOrderField, bool]: + result = {} + for name, asc in parsed_expr.items(): + try: + result[SampleOrderField(name)] = asc + except ValueError: + raise InvalidParameter(f"Invalid field name: {name}") + return result + + @override + def _create_ordering_option( + self, order_by: dict[SampleOrderField, bool] + ) -> SampleOrderingOptions: + order_list = [(field, not asc) for field, asc in order_by.items()] + return SampleOrderingOptions(order_by=order_list) + + return SampleOrderConverter() + + @pytest.mark.parametrize( + ("query", "expected_fields", "expected_order_condition_count"), + [ + ("+id", {"id": False}, 1), + ("+name,+status", {"name": False, "status": False}, 2), + (" +name,-status ", {"name": False, "status": True}, 2), # With whitespace + ("+name,-created_at,status", {"name": False, "created_at": True, "status": False}, 3), + ( + "+id , -name , status", + {"id": False, "name": True, "status": False}, + 3, + ), # With whitespace + ], + ) + def test_multiple_fields_ordering( + self, + sample_order_converter: Any, + query: str, + expected_fields: dict[str, bool], + expected_order_condition_count: int, + ) -> None: + result = sample_order_converter.from_minilang(query) + assert len(result.order_by) == expected_order_condition_count + + field_to_desc = {field.value: desc for field, desc in result.order_by} + assert field_to_desc == expected_fields + + @pytest.mark.parametrize( + ("query", "error_match"), + [ + ("", "Order expression cannot be empty"), + (" ", "Order expression cannot be empty"), + ("+", "Field name in order expression cannot be empty"), + ("-", "Field name in order expression cannot be empty"), + ("+invalid_field", "Invalid field name"), + ("+name,-invalid,-status", "Invalid field name"), + ], + ) + def test_order_error_cases( + self, sample_order_converter: Any, query: str, error_match: str + ) -> None: + with pytest.raises(InvalidParameter, match=error_match): + sample_order_converter.from_minilang(query)