diff --git a/Cargo.lock b/Cargo.lock index b5a8548..3774a93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -73,6 +73,12 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + [[package]] name = "getrandom" version = "0.3.3" @@ -85,6 +91,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "hashbrown" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" + [[package]] name = "heck" version = "0.5.0" @@ -115,6 +127,16 @@ dependencies = [ "cc", ] +[[package]] +name = "indexmap" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6717a8d2a5a929a1a2eb43a12812498ed141a0bcfb7e8f7844fbdbe4303bba9f" +dependencies = [ + "equivalent", + "hashbrown", +] + [[package]] name = "indoc" version = "2.0.6" @@ -376,9 +398,10 @@ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "rapidquery" -version = "0.1.0-alpha1" +version = "0.1.0-alpha5" dependencies = [ "chrono", + "indexmap", "ipnetwork", "mac_address", "once_cell", diff --git a/Cargo.toml b/Cargo.toml index 8b63083..a574fdb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rapidquery" -version = "0.1.0-alpha1" +version = "0.1.0-alpha5" edition = "2021" description = "RapiQuery is the fastest, full-feature, and easy-to-use Python SQL query builder written in Rust." readme = "README.md" @@ -13,9 +13,6 @@ authors = ["awolverp"] name = "rapidquery" crate-type = ["cdylib"] -[features] -optimize = [] - [profile.dev] lto = true panic = "unwind" @@ -38,6 +35,7 @@ chrono = { version = "0.4.27", default-features = false, features = ["clock"] } serde_json = { version = "1", default-features = false, features = ["std"] } rust_decimal = { version = "1.38.0", default-features = false } once_cell = { version = "1.21.3", default-features = false, features = ["parking_lot"]} +indexmap = { version = "2.12.0", default-features = false, features = ["std"]} [dependencies.parking_lot] version = "0.12.4" diff --git a/Makefile b/Makefile index c51a544..0f074cf 100644 --- a/Makefile +++ b/Makefile @@ -19,7 +19,6 @@ build-prod: $(BUILD_CMD) --uv --release test: - cargo clippy $(BUILD_CMD) --uv pytest -s -vv -rm -rf .pytest_cache diff --git a/README.md b/README.md index 8d53df1..bb64986 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,7 @@ stmt.to_sql("postgres") 5. Advanced Usage 1. [**ORM-like**](#orm-like) 2. [**Table Alias**](#table-alias) + 2. [**Complex Examples**](#complex-examples) 6. Performance 1. [**Benchmarks**](#benchmarks) 2. 
[**Performance Tips**](#performance-tips) @@ -260,8 +261,8 @@ sql, params = query.build("postgresql") query = ( rq.Select( - rq.SelectExpr(rq.FunctionCall.count(rq.ASTERISK), alias="total_customers"), - rq.SelectExpr(rq.FunctionCall.avg(rq.Expr.col("age")), alias="average_age"), + rq.SelectCol(rq.FunctionCall.count(rq.ASTERISK), alias="total_customers"), + rq.SelectCol(rq.FunctionCall.avg(rq.Expr.col("age")), alias="average_age"), ) .from_table("customers") ) @@ -269,48 +270,6 @@ sql, params = query.build("postgresql") # -> SELECT COUNT(*) AS "total_customers", AVG("age") AS "average_age" FROM "customers" ``` -**Complex** -```python -# This query would be easier to create by using `AliasedTable` class, -# which introduced in "Advanced" part of this page -query = ( - rq.Select( - rq.Expr.col("c.customer_name"), - rq.SelectExpr( - rq.FunctionCall.count(rq.Expr.col("o.order_id")), - "total_orders" - ), - rq.SelectExpr( - rq.FunctionCall.sum(rq.Expr.col("oi.quantity") * rq.Expr.col("oi.unit_price")), - "total_spent" - ), - ) - .from_table(rq.TableName("customers", alias="c")) - .join( - rq.TableName("orders", alias="o"), - rq.Expr.col("c.customer_id") == rq.Expr.col("o.customer_id"), - type="left" - ) - .join( - rq.TableName("order_items", alias="oi"), - rq.Expr.col("o.order_id") == rq.Expr.col("oi.order_id"), - type="left" - ) - .where( - rq.Expr.col("o.order_date") >= (datetime.datetime.now() - datetime.timedelta(days=360)) - ) -) -sql, params = query.build("postgresql") -# SELECT -# "c"."customer_name", -# COUNT("o"."order_id") AS "total_orders", -# SUM("oi"."quantity" * "oi"."unit_price") AS "total_spent" -# FROM "customers" AS "c" -# LEFT JOIN "orders" AS "o" ON "c"."customer_id" = "o"."customer_id" -# LEFT JOIN "order_items" AS "oi" ON "o"."order_id" = "oi"."order_id" -# WHERE "o"."order_date" >= $1 -``` - #### Query Insert `Insert` provides a chainable API for constructing INSERT queries with support for: - Single or multiple row insertion @@ -364,7 +323,7 @@ query = ( ) sql, params = query.build("postgresql") # INSERT INTO "users" ("username", "role") VALUES ($1, $2) -# ON CONFLICT ("id") DO UPDATE SET "author" = $3 +# ON CONFLICT ("id") DO UPDATE SET "role" = $3 ``` #### Query Update @@ -660,17 +619,17 @@ employees = rq.Table( query = ( rq.Select( employees.c.id.to_column_ref().copy_with(table="emp"), - rq.SelectExpr( + rq.SelectCol( employees.c.name.to_column_ref().copy_with(table="emp"), "employee_name", ), employees.c.job_title.to_column_ref().copy_with(table="emp"), - rq.SelectExpr(employees.c.id.to_column_ref().copy_with(table="mgr"), "manager_id"), - rq.SelectExpr( + rq.SelectCol(employees.c.id.to_column_ref().copy_with(table="mgr"), "manager_id"), + rq.SelectCol( employees.c.name.to_column_ref().copy_with(table="mgr"), "employee_name", ), - rq.SelectExpr( + rq.SelectCol( employees.c.job_title.to_column_ref().copy_with(table="mgr"), "manager_title" ), ) @@ -703,11 +662,11 @@ mgr = rq.AliasedTable(employees, "mgr") query = ( rq.Select( emp.c.id, - rq.SelectExpr(emp.c.name, "employee_name"), + rq.SelectCol(emp.c.name, "employee_name"), emp.c.job_title, - rq.SelectExpr(emp.c.id, "manager_id"), - rq.SelectExpr(emp.c.name, "employee_name"), - rq.SelectExpr(emp.c.job_title, "manager_title"), + rq.SelectCol(emp.c.id, "manager_id"), + rq.SelectCol(emp.c.name, "employee_name"), + rq.SelectCol(emp.c.job_title, "manager_title"), ) .from_table(emp) .join( @@ -727,100 +686,132 @@ sql, params = query.build("postgresql") As you saw, it's much simpler. 
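Alongside the `SelectExpr` → `SelectCol` rename, alpha5 also exports a `Case` builder from `rapidquery._lib`. The following is a minimal sketch of the API as declared in `_lib.pyi` (`when` / `else_`), using an illustrative `transactions` table and column names that are not part of this change; the Complex Examples section below combines `Case` with the new window-function support.

```python
import rapidquery as rq

# Illustrative table/column names; Case maps a condition to a result expression.
direction = (
    rq.Case()
    .when(rq.Expr.col("amount") < 0, "debit")
    .else_("credit")
)

query = (
    rq.Select(rq.Expr.col("id"), rq.SelectCol(direction, alias="direction"))
    .from_table("transactions")
)
sql, params = query.build("postgresql")
# -> roughly: SELECT "id", (CASE WHEN ("amount" < ...) THEN ... ELSE ... END) AS "direction"
#    FROM "transactions"
```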
-### Performance -#### Benchmarks - -> [!NOTE] -> Benchmarks run on *Linux-6.15.11-2-MANJARO-x86_64-with-glibc2.42* with CPython 3.13. Your results may vary. - -**Generating Select Query 100,000x times** -```python -# RapidQuery -query = rq.Select(rq.Expr.asterisk()).from_table("users").where(rq.Expr.col("name").like(r"%linus%")) \ - .offset(20).limit(20) - -query.to_sql('postgresql') -# PyPika -query = pypika.Query.from_("users").where(pypika.Field("name").like(r"%linus%")) \ - .offset(20).limit(20).select("*") - -str(query) +#### Complex Examples +There are some complex examples for `SELECT` query. + +RapidQuery: +```python +rq.Select( + rq.Expr.col("account_number"), + rq.Expr.col("transaction_date"), + rq.Expr.col("transaction_type"), + rq.Expr.col("amount"), + rq.SelectCol( + rq.Case() + .when(rq.Expr.col("transaction_type") == "DEBIT", -rq.Expr.col("amount")) + .else_(rq.Expr.col("amount")), + alias="signed_amount", + ), + rq.SelectCol( + rq.FunctionCall.sum( + rq.Case() + .when(rq.Expr.col("transaction_type") == "DEBIT", -rq.Expr.col("amount")) + .else_(rq.Expr.col("amount")) + ), + alias="running_balance", + window="account_window", + ), + rq.SelectCol( + rq.FunctionCall.avg(rq.Expr.col("amount")), + alias="avg_transaction_by_type", + window=rq.Window(rq.Expr.col("account_number"), rq.Expr.col("transaction_type")), + ), + rq.SelectCol( + rq.FunctionCall.percent_rank(), + alias="amount_percentile", + window=rq.Window(rq.Expr.col("account_number")).order_by(rq.Expr.col("amount"), "desc"), + ), +) +.from_table("bank_transactions") +.where( + rq.Expr.col("transaction_date").between( + rq.Expr.custom("'2024-01-01'"), rq.Expr.custom("'2024-12-31'") + ) +) +.window( + "account_window", + rq.Window(rq.Expr.col("account_number")) + .order_by(rq.Expr.col("transaction_date"), "desc") + .order_by(rq.Expr.col("transaction_id"), "desc") + .frame("rows", rq.WindowFrame.unbounded_preceding(), rq.WindowFrame.current_row()), +) ``` -``` -RapidQuery: 254ms -PyPika: 3983ms +SQL: +```sql +SELECT + "account_number", + "transaction_date", + "transaction_type", + "amount", + (CASE WHEN ("transaction_type" = 'DEBIT') THEN "amount" * -1 ELSE "amount" END) AS "signed_amount", + SUM( + (CASE WHEN ("transaction_type" = 'DEBIT') THEN "amount" * -1 ELSE "amount" END) + ) OVER "account_window" AS "running_balance", + AVG("amount") OVER (PARTITION BY "account_number", "transaction_type") AS "avg_transaction_by_type", + PERCENT_RANK() OVER ( + PARTITION BY "account_number" ORDER BY "amount" DESC + ) AS "amount_percentile" +FROM "bank_transactions" +WHERE + "transaction_date" BETWEEN ('2024-01-01') AND ('2024-12-31') +WINDOW + "account_window" AS ( + PARTITION BY "account_number" + ORDER BY "transaction_date" DESC, "transaction_id" DESC + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + ) ``` -**Generating Insert Query 100,000x times** -```python -# RapidQuery -query = rq.Insert().into("glyph").columns("aspect", "image") \ - .values(5.15, "12A") \ - .values(16, "14A") \ - .returning("id") +### Performance +#### Benchmarks -query.to_sql('postgresql') +> [!NOTE] +> Benchmarks run on *Linux-6.15.11-2-MANJARO-x86_64-with-glibc2.42* with CPython 3.13. Your results may vary. 
-# PyPika -query = pypika.Query.into("glyph").columns("aspect", "image") \ - .insert(5.15, "12A") \ - .insert(16, "14A") +**Iterations per test:** 100,000 +**Python version:** 3.13.7 -str(query) -``` +--- -``` -RapidQuery: 267ms -PyPika: 4299ms -``` +**šŸ“Š SELECT Query Benchmark** -**Generating Update Query 100,000x times** -```python -# RapidQuery -query = rq.Update().table("wallets").values(amount=rq.Expr.col("amount") + 10).where(rq.Expr.col("id").between(10, 30)) +| Library | Time (ms) | vs Fastest | Status | +|---------|-----------|------------|--------| +| RapidQuery | 247.79 | 1.00x (FASTEST) | šŸ† | +| PyPika | 4030.62 | 16.27x slower | | +| SQLAlchemy | 9238.36 | 37.28x slower | | -query.to_sql('postgresql') +--- -# PyPika -query = pypika.Query.update("wallets").set("amount", pypika.Field("amount") + 10) \ - .where(pypika.Field("id").between(10, 30)) +**šŸ“Š INSERT Query Benchmark** -str(query) -``` +| Library | Time (ms) | vs Fastest | Status | +|---------|-----------|------------|--------| +| RapidQuery | 275.13 | 1.00x (FASTEST) | šŸ† | +| PyPika | 4268.81 | 15.52x slower | | +| SQLAlchemy | 6849.45 | 24.90x slower | | -``` -RapidQuery: 252ms -PyPika: 4412ms -``` +--- -**Generating Delete Query 100,000x times** -```python -# RapidQuery -query = rq.Delete().from_table("users") \ - .where( - rq.all( - rq.Expr.col("id") > 10, - rq.Expr.col("id") < 30, - ) - ) \ - .limit(10) +**šŸ“Š UPDATE Query Benchmark** -query.to_sql('postgresql') +| Library | Time (ms) | vs Fastest | Status | +|---------|-----------|------------|--------| +| RapidQuery | 270.03 | 1.00x (FASTEST) | šŸ† | +| PyPika | 4450.08 | 16.48x slower | | +| SQLAlchemy | 11637.68 | 43.10x slower | | -# PyPika -query = pypika.Query.from_("users") \ - .where((pypika.Field("id") > 10) & (pypika.Field("id") < 30)) \ - .limit(10).delete() +--- -str(query) -``` +**šŸ“Š DELETE Query Benchmark** -``` -RapidQuery: 240ms -PyPika: 4556ms -``` +| Library | Time (ms) | vs Fastest | Status | +|---------|-----------|------------|--------| +| RapidQuery | 242.09 | 1.00x (FASTEST) | šŸ† | +| PyPika | 4154.24 | 17.16x slower | | +| SQLAlchemy | 7873.16 | 32.52x slower | | #### Performance Tips - Using [`ORM-like`](#orm-like) is always slower than using `Expr.col` and literal `str` diff --git a/benchmarks.py b/benchmarks.py new file mode 100644 index 0000000..08a343a --- /dev/null +++ b/benchmarks.py @@ -0,0 +1,232 @@ +import rapidquery as rq +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import dialect +import pypika +import typing +import time +import sys + + +# Postgres dialect is faster than other dialects (according to benchmarks) +# and also providing dialect in SQLALchemy can make it faster +SA_DIALECT = dialect() + +# Benchmark configuration +ITERATIONS = 100_000 +WARMUP_ITERATIONS = 1000 + + +def benchmark(func: typing.Callable, number=ITERATIONS) -> float: + for _ in range(min(WARMUP_ITERATIONS, number // 10)): + func() + + perf = time.perf_counter_ns() + for _ in range(number): + func() + perf = time.perf_counter_ns() - perf + + return perf / 1000000 + + +def format_results(results: typing.Dict[str, float]) -> str: + if not results: + return "No results to display" + + # Find fastest time + fastest = min(results.values()) + + lines = [] + lines.append("-" * 70) + lines.append(f"{'Library':<20} {'Time (ms)':<15} {'vs Fastest':<15} {'Status':<20}") + lines.append("-" * 70) + + for lib, time_ms in sorted(results.items(), key=lambda x: x[1]): + if time_ms == fastest: + ratio = "1.00x (FASTEST)" + status = "šŸ†" + 
else: + ratio = f"{time_ms / fastest:.2f}x slower" + status = "" + + lines.append(f"{lib:<20} {time_ms:>10.2f} {ratio:<15} {status}") + + lines.append("-" * 70) + return "\n".join(lines) + + +# SELECT Query Benchmarks + + +def bench_select_rapidquery(): + query = ( + rq.Select(rq.Expr.asterisk()) + .from_table("users") + .where(rq.Expr.col("name").like(r"%linus%")) + .offset(20) + .limit(20) + ) + query.to_sql("postgresql") + + +def bench_select_sqlalchemy(): + query = ( + sa.select(sa.text("*")) + .select_from(sa.table("users")) + .where(sa.column("name").like(r"%linus%")) + .offset(20) + .limit(20) + ) + str(query.compile(dialect=SA_DIALECT, compile_kwargs={"literal_binds": True})) + + +def bench_select_pypika(): + query = ( + pypika.Query.from_("users") + .where(pypika.Field("name").like(r"%linus%")) + .offset(20) + .limit(20) + .select("*") + ) + str(query) + + +# INSERT Query Benchmarks + + +def bench_insert_rapidquery(): + query = ( + rq.Insert() + .into("glyph") + .columns("aspect", "image") + .values(5.15, "12A") + .values(16, "14A") + .returning("id") + ) + query.to_sql("postgresql") + + +sa_glyph = sa.table("glyph", sa.column("aspect", sa.Float), sa.column("image", sa.String)) + + +def bench_insert_sqlalchemy(): + query = sa.insert(sa_glyph).values( + [{"aspect": 5.15, "image": "12A"}, {"aspect": 16, "image": "14A"}] + ) + str(query.compile(dialect=SA_DIALECT, compile_kwargs={"literal_binds": True})) + + +def bench_insert_pypika(): + query = ( + pypika.Query.into("glyph").columns("aspect", "image").insert(5.15, "12A").insert(16, "14A") + ) + str(query) + + +# UPDATE Query Benchmarks + + +def bench_update_rapidquery(): + query = ( + rq.Update() + .table("wallets") + .values(amount=rq.Expr.col("amount") + 10) + .where(rq.Expr.col("id").between(10, 30)) + ) + query.to_sql("postgresql") + + +sa_wallets = sa.table("wallets", sa.column("amount", sa.Integer), sa.column("id", sa.Integer)) + + +def bench_update_sqlalchemy(): + query = ( + sa.update(sa_wallets) + .values(amount=sa_wallets.c.amount + 10) + .where(sa.between(sa_wallets.c.id, 10, 30)) + ) + str(query.compile(dialect=SA_DIALECT, compile_kwargs={"literal_binds": True})) + + +def bench_update_pypika(): + query = ( + pypika.Query.update("wallets") + .set("amount", pypika.Field("amount") + 10) + .where(pypika.Field("id").between(10, 30)) + ) + str(query) + + +# DELETE Query Benchmarks + + +def bench_delete_rapidquery(): + query = ( + rq.Delete() + .from_table("users") + .where( + rq.all( + rq.Expr.col("id") > 10, + rq.Expr.col("id") < 30, + ) + ) + ) + query.to_sql("postgresql") + + +sa_users = sa.table("users", sa.column("id", sa.Integer)) + + +def bench_delete_sqlalchemy(): + query = sa.delete(sa_users).where(sa.and_(sa_users.c.id > 10, sa_users.c.id < 30)) + str(query.compile(dialect=SA_DIALECT, compile_kwargs={"literal_binds": True})) + + +def bench_delete_pypika(): + query = ( + pypika.Query.from_("users") + .where((pypika.Field("id") > 10) & (pypika.Field("id") < 30)) + .delete() + ) + str(query) + + +def run_benchmarks(): + print(f"Iterations per test: {ITERATIONS:,}") + print(f"Python version: {sys.version.split()[0]}") + print() + + print("\nšŸ“Š SELECT Query Benchmark") + results = { + "RapidQuery": benchmark(bench_select_rapidquery), + "SQLAlchemy": benchmark(bench_select_sqlalchemy), + "PyPika": benchmark(bench_select_pypika), + } + print(format_results(results)) + + print("\nšŸ“Š INSERT Query Benchmark") + results = { + "RapidQuery": benchmark(bench_insert_rapidquery), + "SQLAlchemy": 
benchmark(bench_insert_sqlalchemy), + "PyPika": benchmark(bench_insert_pypika), + } + print(format_results(results)) + + print("\nšŸ“Š UPDATE Query Benchmark") + results = { + "RapidQuery": benchmark(bench_update_rapidquery), + "SQLAlchemy": benchmark(bench_update_sqlalchemy), + "PyPika": benchmark(bench_update_pypika), + } + print(format_results(results)) + + print("\nšŸ“Š DELETE Query Benchmark") + results = { + "RapidQuery": benchmark(bench_delete_rapidquery), + "SQLAlchemy": benchmark(bench_delete_sqlalchemy), + "PyPika": benchmark(bench_delete_pypika), + } + print(format_results(results)) + + +if __name__ == "__main__": + run_benchmarks() diff --git a/pyproject.toml b/pyproject.toml index b431f61..3c34094 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,6 @@ dynamic = [ 'readme', 'version' ] -dependencies = [] [project.urls] Homepage = 'https://github.com/awolverp/rapidquery' diff --git a/rapidquery/__init__.py b/rapidquery/__init__.py index d2337a2..16d98ef 100644 --- a/rapidquery/__init__.py +++ b/rapidquery/__init__.py @@ -1,12 +1,3 @@ -""" -**RapidQuery: High-Performance SQL Query Builder for Python** - -RapidQuery is a powerful SQL query builder library designed for Python, combining -the simplicity of Python with the raw speed and safety of **Rust**. Build complex -SQL queries effortlessly and efficiently, with a library that prioritizes both -performance and ease of use. -""" - from ._lib import ASTERISK as ASTERISK from ._lib import AdaptedValue as AdaptedValue from ._lib import AliasedTable as AliasedTable @@ -25,6 +16,7 @@ from ._lib import BitType as BitType from ._lib import BlobType as BlobType from ._lib import BooleanType as BooleanType +from ._lib import Case as Case from ._lib import CharType as CharType from ._lib import CidrType as CidrType from ._lib import Column as Column @@ -71,7 +63,7 @@ from ._lib import RenameTable as RenameTable from ._lib import SchemaStatement as SchemaStatement from ._lib import Select as Select -from ._lib import SelectExpr as SelectExpr +from ._lib import SelectCol as SelectCol from ._lib import SmallIntegerType as SmallIntegerType from ._lib import SmallUnsignedType as SmallUnsignedType from ._lib import StringType as StringType @@ -90,10 +82,9 @@ from ._lib import VarBinaryType as VarBinaryType from ._lib import VarBitType as VarBitType from ._lib import VectorType as VectorType +from ._lib import Window as Window +from ._lib import WindowFrame as WindowFrame from ._lib import YearType as YearType -from ._lib import _AliasedTableColumnsSequence as _AliasedTableColumnsSequence -from ._lib import _AsteriskType as _AsteriskType -from ._lib import _TableColumnsSequence as _TableColumnsSequence from ._lib import all as all from ._lib import any as any from ._lib import not_ as not_ diff --git a/rapidquery/_lib.pyi b/rapidquery/_lib.pyi index ca04096..78ffde3 100644 --- a/rapidquery/_lib.pyi +++ b/rapidquery/_lib.pyi @@ -72,25 +72,12 @@ class ColumnTypeMeta(typing.Generic[T]): def __repr__(self) -> str: ... class _LengthColumnType(ColumnTypeMeta[T]): - """ - Base class for column types that have a length parameter. - - This is an internal base class for column types like CHAR, VARCHAR, - BINARY, and VARBINARY that specify a maximum length constraint. - """ - length: typing.Optional[int] """The maximum length constraint for this column type.""" def __new__(cls, length: typing.Optional[int] = ...) -> Self: ... 
class _PrecisionScaleColumnType(ColumnTypeMeta[T]): - """ - Base class for numeric column types with precision and scale parameters. - - This is an internal base class for numeric types like DECIMAL and NUMERIC - that require both precision (total digits) and scale (decimal places) specification. - """ def __new__(cls, precision_scale: typing.Optional[typing.Tuple[int, int]] = ...) -> Self: ... @property def precision_scale(self) -> typing.Optional[typing.Tuple[int, int]]: @@ -577,6 +564,8 @@ class AdaptedValue(typing.Generic[T]): AdaptedValue("127.0.0.1", Inet()) # -> INET SQL type (network address) AdaptedValue([4.3, 5.6], Vector()) # -> VECTOR SQL type (for AI embeddings) + Possible exceptions are `TypeError`, `ValueError`, and `OverflowError`. + NOTE: this class is immutable and frozen. """ ... @@ -664,16 +653,6 @@ class ColumnRef: This class is used to uniquely identify columns in SQL queries, supporting schema-qualified and table-qualified column references. - - Attributes: - name: The column name - table: The table name containing the column, if specified - schema: The schema name containing the table, if specified - - Example: - >>> ColumnRef("id") - >>> ColumnRef("id", table="users") - >>> ColumnRef("id", table="users", schema="public") """ def __new__( @@ -756,6 +735,8 @@ _ExprValue = typing.Union[ _AsteriskType, typing.Any, Select, + Case, + FunctionCall, ] class Expr: @@ -770,6 +751,7 @@ class Expr: when building the final SQL statement. Example:: + # Basic comparison e = Expr(1) > 2 e.to_sql("mysql") @@ -815,19 +797,6 @@ class Expr: """ ... - @classmethod - def func(cls, value: FunctionCall) -> Self: - """ - Create an expression from a FunctionCall. - - Args: - value: The function call to convert to an expression - - Returns: - An Expr representing the function call - """ - ... - @classmethod def col(cls, name: typing.Union[str, ColumnRef]) -> Self: """ @@ -1015,6 +984,7 @@ class Expr: """ ... + def __neq__(self) -> Self: ... def __sub__(self, other: _ExprValue) -> Self: """ Create a subtraction expression. @@ -1582,6 +1552,27 @@ class FunctionCall: """ ... + @classmethod + def rank(cls) -> Self: + """ + Create a RANK function call. + """ + ... + + @classmethod + def dense_rank(cls) -> Self: + """ + Create a DENSE_RANK function call. + """ + ... + + @classmethod + def percent_rank(cls) -> Self: + """ + Create a PERCENT_RANK function call. + """ + ... + @classmethod def round(cls, expr: _ExprValue) -> Self: """ @@ -1653,6 +1644,7 @@ def all(arg1: Expr, *args: Expr) -> Expr: An Expr representing the logical AND of all input expressions Example: + >>> all(Expr.col("age") > 18, Expr.col("status") == "active") # Equivalent to: age > 18 AND status = 'active' """ @@ -1672,6 +1664,7 @@ def any(arg1: Expr, *args: Expr) -> Expr: An Expr representing the logical OR of all input expressions Example: + >>> any(Expr.col("status") == "pending", Expr.col("status") == "approved") # Equivalent to: status = 'pending' OR status = 'approved' """ @@ -1682,7 +1675,8 @@ def not_(arg1: Expr) -> Expr: Create a logical NOT. Example: - >>> not_(Expr.col("status") == "pending", Expr.col("status")) + + >>> not_(Expr.col("status") == "pending") # Equivalent to: NOT status = 'pending' """ ... @@ -1702,7 +1696,8 @@ class Column(typing.Generic[T]): of table columns. It encapsulates all the properties that define how a column behaves and what data it can store. 
- Example: + Example:: + >>> Column("id", Integer(), primary_key=True, auto_increment=True) >>> Column("name", String(255), nullable=False, default="unknown") >>> Column("created_at", Timestamp(), default=Expr.current_timestamp()) @@ -1838,7 +1833,8 @@ class TableName: The class provides parsing capabilities for string representations and supports comparison operations. - Examples: + Examples:: + >>> TableName("users") # Simple table name >>> TableName("users", schema="public") # Schema-qualified table >>> TableName("users", schema="hr", database="company") # Fully qualified @@ -2277,7 +2273,7 @@ class _TableColumnsSequence: def __getattr__(self, name: str) -> Column: ... def get(self, name: str) -> Column: ... def append(self, col: Column) -> None: ... - def remove(self, name: str) -> None: ... + def remove(self, name: str) -> Column: ... def to_list(self) -> typing.Sequence[Column]: ... def clear(self) -> None: ... def __len__(self) -> int: ... @@ -3250,7 +3246,35 @@ class Update(QueryStatement): def __repr__(self) -> str: ... -class SelectExpr: +class WindowFrame: + @classmethod + def unbounded_preceding(cls) -> Self: ... + @classmethod + def unbounded_following(cls) -> Self: ... + @classmethod + def current_row(cls) -> Self: ... + @classmethod + def preceding(cls, val: int) -> Self: ... + @classmethod + def following(cls, val: int) -> Self: ... + +class Window: + def __new__(cls, *partition_by: Expr) -> Self: ... + def partition(self, *partition_by: Expr) -> Self: ... + def order_by( + self, + target: _ExprValue, + order: typing.Literal["asc", "desc"], + null_order: typing.Optional[typing.Literal["first", "last"]] = ..., + ) -> Self: ... + def frame( + self, + type: typing.Literal["rows", "range"], + start: WindowFrame, + end: typing.Optional[WindowFrame] = None, + ) -> Self: ... + +class SelectCol: """ Represents a column expression with an optional alias in a SELECT clause. @@ -3258,20 +3282,26 @@ class SelectExpr: for the result column. Example: - >>> SelectExpr(Expr.col("price") * 1.1, "price_with_tax") - >>> SelectExpr(Expr.count(), "total_count") + + >>> SelectCol(Expr.col("price") * 1.1, "price_with_tax") + >>> SelectCol(Expr.count(), "total_count") """ - def __new__(cls, expr: _ExprValue, alias: typing.Optional[str] = ...): + def __new__( + cls, + expr: _ExprValue, + alias: typing.Optional[str] = ..., + window: typing.Union[str, Window, None] = ..., + ): """ - Create a new SelectExpr. + Create a new SelectCol. Args: expr: The expression to select alias: Optional alias name for the result column Returns: - A new SelectExpr instance + A new SelectCol instance """ ... @@ -3285,6 +3315,8 @@ class SelectExpr: """The alias name for the result column, if any.""" ... + @property + def window(self) -> typing.Union[str, Window, None]: ... def __repr__(self) -> str: ... class Select(QueryStatement): @@ -3303,6 +3335,7 @@ class Select(QueryStatement): - DISTINCT queries Example: + >>> Select(Expr.col("name"), Expr.col("email")).from_table("users") \\ ... .where(Expr.col("active") == True) \\ ... .order_by(Order(Expr.col("created_at"), ORDER_DESC)) \\ @@ -3312,12 +3345,12 @@ class Select(QueryStatement): ... .where(Expr.col("published") == True) """ - def __new__(cls, *cols: typing.Union[SelectExpr, _ExprValue]) -> Self: + def __new__(cls, *cols: typing.Union[SelectCol, _ExprValue]) -> Self: """ Create a new SELECT statement builder. 
Args: - *cols: Optional initial columns to select (expressions or SelectExpr objects) + *cols: Optional initial columns to select (expressions or SelectCol objects) Returns: A new Select instance @@ -3336,12 +3369,12 @@ class Select(QueryStatement): """ ... - def columns(self, *cols: typing.Union[SelectExpr, _ExprValue]) -> Self: + def columns(self, *cols: typing.Union[SelectCol, _ExprValue]) -> Self: """ Specify or add columns to select. Args: - *cols: Column names, expressions, or SelectExpr objects to select + *cols: Column names, expressions, or SelectCol objects to select Returns: Self for method chaining @@ -3556,4 +3589,12 @@ class Select(QueryStatement): """ ... + def window(self, name: str, statement: Window) -> Self: ... + def __repr__(self) -> str: ... + +class Case: + def __new__(cls) -> Self: ... + def when(self, cond: _ExprValue, then: _ExprValue) -> Self: ... + def else_(self, expr: _ExprValue) -> Self: ... + def to_expr(self) -> Expr: ... def __repr__(self) -> str: ... diff --git a/src/adaptation/deserialize.rs b/src/adaptation/deserialize.rs index d9a6d88..64599f0 100644 --- a/src/adaptation/deserialize.rs +++ b/src/adaptation/deserialize.rs @@ -1,7 +1,7 @@ use std::ptr::NonNull; #[derive(Debug, Default)] -pub enum DeserializedValue { +pub enum PythonValue { #[default] Null, Bool(bool), @@ -19,11 +19,11 @@ pub enum DeserializedValue { ), Uuid(NonNull), Decimal(NonNull), - Array(Vec), + Array(Vec), Vector(NonNull), } -impl Clone for DeserializedValue { +impl Clone for PythonValue { fn clone(&self) -> Self { unsafe { match self { @@ -74,7 +74,7 @@ impl Clone for DeserializedValue { } } -impl Drop for DeserializedValue { +impl Drop for PythonValue { fn drop(&mut self) { unsafe { match self { @@ -98,7 +98,7 @@ impl Drop for DeserializedValue { } } -impl DeserializedValue { +impl PythonValue { pub unsafe fn as_pyobject(&self) -> *mut pyo3::ffi::PyObject { match self { Self::Null => pyo3::ffi::Py_None(), @@ -174,16 +174,16 @@ impl DeserializedValue { } } - pub fn serialize(&self, py: pyo3::Python<'_>) -> pyo3::PyResult { + pub fn serialize(&self, py: pyo3::Python<'_>) -> pyo3::PyResult { use pyo3::types::PyAnyMethods; unsafe { match self { - Self::Null => Ok(super::serialize::SerializedValue::Null), - Self::Bool(op) => Ok(super::serialize::SerializedValue::Bool(*op)), - Self::BigInt(op) => Ok(super::serialize::SerializedValue::BigInt(*op)), - Self::BigUnsigned(op) => Ok(super::serialize::SerializedValue::BigUnsigned(*op)), - Self::Double(op) => Ok(super::serialize::SerializedValue::Double(*op)), + Self::Null => Ok(super::serialize::RustValue::Null), + Self::Bool(op) => Ok(super::serialize::RustValue::Bool(*op)), + Self::BigInt(op) => Ok(super::serialize::RustValue::BigInt(*op)), + Self::BigUnsigned(op) => Ok(super::serialize::RustValue::BigUnsigned(*op)), + Self::Double(op) => Ok(super::serialize::RustValue::Double(*op)), Self::String(op) => { let mut size: pyo3::ffi::Py_ssize_t = 0; let c_str = pyo3::ffi::PyUnicode_AsUTF8AndSize(op.as_ptr(), &mut size); @@ -192,14 +192,14 @@ impl DeserializedValue { Err(pyo3::PyErr::fetch(py)) } else { let val = std::ffi::CStr::from_ptr(c_str); - Ok(super::serialize::SerializedValue::String(val.to_bytes().to_vec())) + Ok(super::serialize::RustValue::String(val.to_bytes().to_vec())) } } Self::Bytes(op) => { let bytes = pyo3::Py::::from_borrowed_ptr(py, op.as_ptr()).extract::>(py)?; - Ok(super::serialize::SerializedValue::Bytes(bytes)) + Ok(super::serialize::RustValue::Bytes(bytes)) } Self::Json(op) => { let serialized = 
super::common::_serialize_object_with_pyjson(py, op.as_ptr())?; @@ -220,20 +220,20 @@ impl DeserializedValue { pyo3::PyErr::new::(x.to_string()) })?; - Ok(super::serialize::SerializedValue::Json(val)) + Ok(super::serialize::RustValue::Json(val)) } } Self::ChronoDate(op) => { let val: pyo3::Bound<'_, pyo3::types::PyDate> = pyo3::Bound::from_borrowed_ptr(py, op.as_ptr()).cast_into()?; - Ok(super::serialize::SerializedValue::ChronoDate(val.extract()?)) + Ok(super::serialize::RustValue::ChronoDate(val.extract()?)) } Self::ChronoTime(op) => { let val: pyo3::Bound<'_, pyo3::types::PyTime> = pyo3::Bound::from_borrowed_ptr(py, op.as_ptr()).cast_into()?; - Ok(super::serialize::SerializedValue::ChronoTime(val.extract()?)) + Ok(super::serialize::RustValue::ChronoTime(val.extract()?)) } Self::ChronoDateTime(op) => { let val: pyo3::Bound<'_, pyo3::types::PyDateTime> = @@ -245,9 +245,9 @@ impl DeserializedValue { debug_assert!(!tzinfo.is_null()); if pyo3::ffi::Py_IsNone(tzinfo) == 1 { - Ok(super::serialize::SerializedValue::ChronoDateTime(val.extract()?)) + Ok(super::serialize::RustValue::ChronoDateTime(val.extract()?)) } else { - Ok(super::serialize::SerializedValue::ChronoDateTimeWithTimeZone( + Ok(super::serialize::RustValue::ChronoDateTimeWithTimeZone( val.extract()?, )) } @@ -255,23 +255,23 @@ impl DeserializedValue { Self::Uuid(op) => { let val: uuid::Uuid = pyo3::Bound::from_borrowed_ptr(py, op.as_ptr()).extract()?; - Ok(super::serialize::SerializedValue::Uuid(val)) + Ok(super::serialize::RustValue::Uuid(val)) } Self::Decimal(op) => { let val: rust_decimal::Decimal = pyo3::Bound::from_borrowed_ptr(py, op.as_ptr()).extract()?; - Ok(super::serialize::SerializedValue::Decimal(val)) + Ok(super::serialize::RustValue::Decimal(val)) } Self::Array(op) => { - let mut values: Vec = Vec::with_capacity(op.len()); + let mut values: Vec = Vec::with_capacity(op.len()); for item in op { let item = item.serialize(py)?; values.push(item); } - Ok(super::serialize::SerializedValue::Array(values)) + Ok(super::serialize::RustValue::Array(values)) } Self::Vector(op) => { let mut values: Vec = Vec::new(); @@ -282,7 +282,7 @@ impl DeserializedValue { values.push(item.extract::()?); } - Ok(super::serialize::SerializedValue::Vector(values)) + Ok(super::serialize::RustValue::Vector(values)) } } } diff --git a/src/adaptation/mod.rs b/src/adaptation/mod.rs index e93f0f2..fee9a69 100644 --- a/src/adaptation/mod.rs +++ b/src/adaptation/mod.rs @@ -5,21 +5,21 @@ mod common; mod deserialize; mod serialize; -pub use deserialize::DeserializedValue; -pub use serialize::SerializedValue; +pub use deserialize::PythonValue; +pub use serialize::RustValue; /// A bridge between Python & [`sea_query::Value`] #[derive(Debug, Clone)] pub struct ReturnableValue { - deserialized: Option, - serialized: Option, + deserialized: Option, + serialized: Option, } unsafe impl Send for ReturnableValue {} -impl From for ReturnableValue { +impl From for ReturnableValue { #[inline] - fn from(value: DeserializedValue) -> Self { + fn from(value: PythonValue) -> Self { Self { deserialized: Some(value), serialized: None, @@ -27,9 +27,9 @@ impl From for ReturnableValue { } } -impl From for ReturnableValue { +impl From for ReturnableValue { #[inline] - fn from(value: SerializedValue) -> Self { + fn from(value: RustValue) -> Self { Self { deserialized: None, serialized: Some(value), @@ -52,7 +52,7 @@ impl ReturnableValue { return Err(typeerror!("expected bool, got {}", object.py(), object.as_ptr())); } - Ok(Self::from(DeserializedValue::Bool( + 
Ok(Self::from(PythonValue::Bool( pyo3::ffi::Py_True() == object.as_ptr(), ))) }, @@ -66,7 +66,7 @@ impl ReturnableValue { return Err(pyo3::PyErr::fetch(object.py())); } - Ok(Self::from(DeserializedValue::BigInt(val))) + Ok(Self::from(PythonValue::BigInt(val))) }, sea_query::ColumnType::TinyUnsigned | sea_query::ColumnType::SmallUnsigned @@ -77,7 +77,7 @@ impl ReturnableValue { return Err(pyo3::PyErr::fetch(object.py())); } - Ok(Self::from(DeserializedValue::BigUnsigned(val))) + Ok(Self::from(PythonValue::BigUnsigned(val))) }, sea_query::ColumnType::Char(_) | sea_query::ColumnType::String(_) @@ -91,7 +91,7 @@ impl ReturnableValue { return Err(typeerror!("expected str, got {}", object.py(), object.as_ptr())); } - Ok(Self::from(DeserializedValue::String(NonNull::new_unchecked( + Ok(Self::from(PythonValue::String(NonNull::new_unchecked( object.into_ptr(), )))) }, @@ -104,19 +104,30 @@ impl ReturnableValue { return Err(typeerror!("expected bytes, got {}", object.py(), object.as_ptr())); } - Ok(Self::from(DeserializedValue::Bytes(NonNull::new_unchecked( + Ok(Self::from(PythonValue::Bytes(NonNull::new_unchecked( object.into_ptr(), )))) }, sea_query::ColumnType::Float | sea_query::ColumnType::Double => unsafe { + if pyo3::ffi::PyFloat_CheckExact(object.as_ptr()) == 0 + && pyo3::ffi::PyLong_CheckExact(object.as_ptr()) == 0 + { + return Err(typeerror!( + "expected float or int, got {}", + object.py(), + object.as_ptr() + )); + } + let val = pyo3::ffi::PyFloat_AsDouble(object.as_ptr()); if val == -1.0 && !pyo3::ffi::PyErr_Occurred().is_null() { return Err(pyo3::PyErr::fetch(object.py())); } - Ok(Self::from(DeserializedValue::Double(val))) + Ok(Self::from(PythonValue::Double(val))) }, sea_query::ColumnType::Decimal(_) | sea_query::ColumnType::Money(_) => unsafe { + // TODO: Support float if pyo3::ffi::Py_IS_TYPE(object.as_ptr(), crate::typeref::STD_DECIMAL_TYPE) == 0 { return Err(typeerror!( "expected decimal.Decimal, got {}", @@ -125,7 +136,7 @@ impl ReturnableValue { )); } - Ok(Self::from(DeserializedValue::Decimal(NonNull::new_unchecked( + Ok(Self::from(PythonValue::Decimal(NonNull::new_unchecked( object.into_ptr(), )))) }, @@ -138,9 +149,9 @@ impl ReturnableValue { )); } - Ok(Self::from(DeserializedValue::ChronoDateTime( - NonNull::new_unchecked(object.into_ptr()), - ))) + Ok(Self::from(PythonValue::ChronoDateTime(NonNull::new_unchecked( + object.into_ptr(), + )))) }, sea_query::ColumnType::TimestampWithTimeZone => unsafe { if pyo3::ffi::Py_IS_TYPE(object.as_ptr(), crate::typeref::STD_DATETIME_TYPE) == 0 { @@ -151,9 +162,9 @@ impl ReturnableValue { )); } - Ok(Self::from(DeserializedValue::ChronoDateTime( - NonNull::new_unchecked(object.into_ptr()), - ))) + Ok(Self::from(PythonValue::ChronoDateTime(NonNull::new_unchecked( + object.into_ptr(), + )))) }, sea_query::ColumnType::Time => unsafe { if pyo3::ffi::Py_IS_TYPE(object.as_ptr(), crate::typeref::STD_TIME_TYPE) == 0 { @@ -164,7 +175,7 @@ impl ReturnableValue { )); } - Ok(Self::from(DeserializedValue::ChronoTime(NonNull::new_unchecked( + Ok(Self::from(PythonValue::ChronoTime(NonNull::new_unchecked( object.into_ptr(), )))) }, @@ -177,14 +188,14 @@ impl ReturnableValue { )); } - Ok(Self::from(DeserializedValue::ChronoDate(NonNull::new_unchecked( + Ok(Self::from(PythonValue::ChronoDate(NonNull::new_unchecked( object.into_ptr(), )))) }, sea_query::ColumnType::Json | sea_query::ColumnType::JsonBinary => unsafe { common::_validate_json_object(object.py(), object.as_ptr())?; - Ok(Self::from(DeserializedValue::Json(NonNull::new_unchecked( + 
Ok(Self::from(PythonValue::Json(NonNull::new_unchecked( object.into_ptr(), )))) }, @@ -197,7 +208,7 @@ impl ReturnableValue { )); } - Ok(Self::from(DeserializedValue::Uuid(NonNull::new_unchecked( + Ok(Self::from(PythonValue::Uuid(NonNull::new_unchecked( object.into_ptr(), )))) }, @@ -208,7 +219,7 @@ impl ReturnableValue { return Err(typeerror!("expected str, got {}", object.py(), object.as_ptr())); } - Ok(Self::from(DeserializedValue::String(NonNull::new_unchecked( + Ok(Self::from(PythonValue::String(NonNull::new_unchecked( object.into_ptr(), )))) }, @@ -227,7 +238,7 @@ impl ReturnableValue { values.push(x.deserialized.unwrap()); } - Ok(Self::from(DeserializedValue::Array(values))) + Ok(Self::from(PythonValue::Array(values))) }, sea_query::ColumnType::Vector(_) => unsafe { use pyo3::types::PyListMethods; @@ -243,7 +254,9 @@ impl ReturnableValue { let list = object.cast_into_unchecked::(); for item in list.iter() { - if pyo3::ffi::PyFloat_CheckExact(item.as_ptr()) == 0 { + if pyo3::ffi::PyFloat_CheckExact(item.as_ptr()) == 0 + && pyo3::ffi::PyLong_CheckExact(item.as_ptr()) == 0 + { return Err(typeerror!( "expected list of floats, found an {:?} element", item.py(), @@ -252,7 +265,7 @@ impl ReturnableValue { } } - Ok(Self::from(DeserializedValue::Vector(NonNull::new_unchecked( + Ok(Self::from(PythonValue::Vector(NonNull::new_unchecked( list.into_ptr(), )))) }, @@ -275,19 +288,19 @@ impl ReturnableValue { } if pyo3::ffi::PyUnicode_CheckExact(object.as_ptr()) == 1 { - return Ok(Self::from(DeserializedValue::String(NonNull::new_unchecked( + return Ok(Self::from(PythonValue::String(NonNull::new_unchecked( object.into_ptr(), )))); } if pyo3::ffi::PyBool_Check(object.as_ptr()) == 1 { - return Ok(Self::from(DeserializedValue::Bool( + return Ok(Self::from(PythonValue::Bool( pyo3::ffi::Py_True() == object.as_ptr(), ))); } if pyo3::ffi::PyBytes_CheckExact(object.as_ptr()) == 1 { - return Ok(Self::from(DeserializedValue::Bytes(NonNull::new_unchecked( + return Ok(Self::from(PythonValue::Bytes(NonNull::new_unchecked( object.into_ptr(), )))); } @@ -297,37 +310,37 @@ impl ReturnableValue { { common::_validate_json_object(object.py(), object.as_ptr())?; - return Ok(Self::from(DeserializedValue::Json(NonNull::new_unchecked( + return Ok(Self::from(PythonValue::Json(NonNull::new_unchecked( object.into_ptr(), )))); } if pyo3::ffi::Py_TYPE(object.as_ptr()) == crate::typeref::STD_DECIMAL_TYPE { - return Ok(Self::from(DeserializedValue::Decimal(NonNull::new_unchecked( + return Ok(Self::from(PythonValue::Decimal(NonNull::new_unchecked( object.into_ptr(), )))); } if pyo3::ffi::Py_TYPE(object.as_ptr()) == crate::typeref::STD_DATETIME_TYPE { - return Ok(Self::from(DeserializedValue::ChronoDateTime( - NonNull::new_unchecked(object.into_ptr()), - ))); + return Ok(Self::from(PythonValue::ChronoDateTime(NonNull::new_unchecked( + object.into_ptr(), + )))); } if pyo3::ffi::Py_TYPE(object.as_ptr()) == crate::typeref::STD_DATE_TYPE { - return Ok(Self::from(DeserializedValue::ChronoDate(NonNull::new_unchecked( + return Ok(Self::from(PythonValue::ChronoDate(NonNull::new_unchecked( object.into_ptr(), )))); } if pyo3::ffi::Py_TYPE(object.as_ptr()) == crate::typeref::STD_TIME_TYPE { - return Ok(Self::from(DeserializedValue::ChronoTime(NonNull::new_unchecked( + return Ok(Self::from(PythonValue::ChronoTime(NonNull::new_unchecked( object.into_ptr(), )))); } if pyo3::ffi::Py_TYPE(object.as_ptr()) == crate::typeref::STD_UUID_TYPE { - return Ok(Self::from(DeserializedValue::Uuid(NonNull::new_unchecked( + return 
Ok(Self::from(PythonValue::Uuid(NonNull::new_unchecked( object.into_ptr(), )))); } @@ -347,8 +360,8 @@ impl ReturnableValue { unsafe { if pyo3::ffi::Py_IsNone(object.as_ptr()) == 1 { return Ok(Self { - deserialized: Some(DeserializedValue::Null), - serialized: Some(SerializedValue::Null), + deserialized: Some(PythonValue::Null), + serialized: Some(RustValue::Null), }); } } @@ -369,7 +382,7 @@ impl ReturnableValue { } #[inline] - pub fn serialize(&mut self, py: pyo3::Python<'_>) -> &SerializedValue { + pub fn serialize(&mut self, py: pyo3::Python<'_>) -> &RustValue { unsafe { if self.serialized.is_none() { self.serialized = Some( @@ -386,7 +399,7 @@ impl ReturnableValue { } #[inline] - pub fn deserialize(&mut self, py: pyo3::Python<'_>) -> &DeserializedValue { + pub fn deserialize(&mut self, py: pyo3::Python<'_>) -> &PythonValue { unsafe { if self.deserialized.is_none() { self.deserialized = Some( @@ -453,8 +466,8 @@ impl PyAdaptedValue { fn is_null(&self) -> bool { let lock = self.inner.lock(); - matches!(lock.deserialized.as_ref(), Some(DeserializedValue::Null)) - || matches!(lock.serialized.as_ref(), Some(SerializedValue::Null)) + matches!(lock.deserialized.as_ref(), Some(PythonValue::Null)) + || matches!(lock.serialized.as_ref(), Some(RustValue::Null)) } #[getter] @@ -463,10 +476,10 @@ impl PyAdaptedValue { matches!( lock.deserialized.as_ref(), - Some(DeserializedValue::BigInt(_)) | Some(DeserializedValue::BigUnsigned(_)) + Some(PythonValue::BigInt(_)) | Some(PythonValue::BigUnsigned(_)) ) || matches!( lock.serialized.as_ref(), - Some(SerializedValue::BigInt(_)) | Some(SerializedValue::BigUnsigned(_)) + Some(RustValue::BigInt(_)) | Some(RustValue::BigUnsigned(_)) ) } @@ -474,101 +487,99 @@ impl PyAdaptedValue { fn is_float(&self) -> bool { let lock = self.inner.lock(); - matches!(lock.deserialized.as_ref(), Some(DeserializedValue::Double(_))) - || matches!(lock.serialized.as_ref(), Some(SerializedValue::Double(_))) + matches!(lock.deserialized.as_ref(), Some(PythonValue::Double(_))) + || matches!(lock.serialized.as_ref(), Some(RustValue::Double(_))) } #[getter] fn is_boolean(&self) -> bool { let lock = self.inner.lock(); - matches!(lock.deserialized.as_ref(), Some(DeserializedValue::Bool(_))) - || matches!(lock.serialized.as_ref(), Some(SerializedValue::Bool(_))) + matches!(lock.deserialized.as_ref(), Some(PythonValue::Bool(_))) + || matches!(lock.serialized.as_ref(), Some(RustValue::Bool(_))) } #[getter] fn is_string(&self) -> bool { let lock = self.inner.lock(); - matches!(lock.deserialized.as_ref(), Some(DeserializedValue::String(_))) - || matches!(lock.serialized.as_ref(), Some(SerializedValue::String(_))) + matches!(lock.deserialized.as_ref(), Some(PythonValue::String(_))) + || matches!(lock.serialized.as_ref(), Some(RustValue::String(_))) } #[getter] fn is_date(&self) -> bool { let lock = self.inner.lock(); - matches!(lock.deserialized.as_ref(), Some(DeserializedValue::ChronoDate(_))) - || matches!(lock.serialized.as_ref(), Some(SerializedValue::ChronoDate(_))) + matches!(lock.deserialized.as_ref(), Some(PythonValue::ChronoDate(_))) + || matches!(lock.serialized.as_ref(), Some(RustValue::ChronoDate(_))) } #[getter] fn is_datetime(&self) -> bool { let lock = self.inner.lock(); - matches!( - lock.deserialized.as_ref(), - Some(DeserializedValue::ChronoDateTime(_)) - ) || matches!( - lock.serialized.as_ref(), - Some(SerializedValue::ChronoDateTime(_)) | Some(SerializedValue::ChronoDateTimeWithTimeZone(_)) - ) + matches!(lock.deserialized.as_ref(), Some(PythonValue::ChronoDateTime(_))) 
+ || matches!( + lock.serialized.as_ref(), + Some(RustValue::ChronoDateTime(_)) | Some(RustValue::ChronoDateTimeWithTimeZone(_)) + ) } #[getter] fn is_time(&self) -> bool { let lock = self.inner.lock(); - matches!(lock.deserialized.as_ref(), Some(DeserializedValue::ChronoTime(_))) - || matches!(lock.serialized.as_ref(), Some(SerializedValue::ChronoTime(_))) + matches!(lock.deserialized.as_ref(), Some(PythonValue::ChronoTime(_))) + || matches!(lock.serialized.as_ref(), Some(RustValue::ChronoTime(_))) } #[getter] fn is_uuid(&self) -> bool { let lock = self.inner.lock(); - matches!(lock.deserialized.as_ref(), Some(DeserializedValue::Uuid(_))) - || matches!(lock.serialized.as_ref(), Some(SerializedValue::Uuid(_))) + matches!(lock.deserialized.as_ref(), Some(PythonValue::Uuid(_))) + || matches!(lock.serialized.as_ref(), Some(RustValue::Uuid(_))) } #[getter] fn is_bytes(&self) -> bool { let lock = self.inner.lock(); - matches!(lock.deserialized.as_ref(), Some(DeserializedValue::Bytes(_))) - || matches!(lock.serialized.as_ref(), Some(SerializedValue::Bytes(_))) + matches!(lock.deserialized.as_ref(), Some(PythonValue::Bytes(_))) + || matches!(lock.serialized.as_ref(), Some(RustValue::Bytes(_))) } #[getter] fn is_json(&self) -> bool { let lock = self.inner.lock(); - matches!(lock.deserialized.as_ref(), Some(DeserializedValue::Json(_))) - || matches!(lock.serialized.as_ref(), Some(SerializedValue::Json(_))) + matches!(lock.deserialized.as_ref(), Some(PythonValue::Json(_))) + || matches!(lock.serialized.as_ref(), Some(RustValue::Json(_))) } #[getter] fn is_decimal(&self) -> bool { let lock = self.inner.lock(); - matches!(lock.deserialized.as_ref(), Some(DeserializedValue::Decimal(_))) - || matches!(lock.serialized.as_ref(), Some(SerializedValue::Decimal(_))) + matches!(lock.deserialized.as_ref(), Some(PythonValue::Decimal(_))) + || matches!(lock.serialized.as_ref(), Some(RustValue::Decimal(_))) } #[getter] fn is_array(&self) -> bool { let lock = self.inner.lock(); - matches!(lock.deserialized.as_ref(), Some(DeserializedValue::Array(_))) - || matches!(lock.serialized.as_ref(), Some(SerializedValue::Array(_))) + matches!(lock.deserialized.as_ref(), Some(PythonValue::Array(_))) + || matches!(lock.serialized.as_ref(), Some(RustValue::Array(_))) } #[getter] fn is_vector(&self) -> bool { let lock = self.inner.lock(); - matches!(lock.deserialized.as_ref(), Some(DeserializedValue::Vector(_))) - || matches!(lock.serialized.as_ref(), Some(SerializedValue::Vector(_))) + matches!(lock.deserialized.as_ref(), Some(PythonValue::Vector(_))) + || matches!(lock.serialized.as_ref(), Some(RustValue::Vector(_))) } #[getter] diff --git a/src/adaptation/serialize.rs b/src/adaptation/serialize.rs index b7c1dfb..99a48e8 100644 --- a/src/adaptation/serialize.rs +++ b/src/adaptation/serialize.rs @@ -1,7 +1,7 @@ use std::ptr::NonNull; #[derive(Clone, Debug, PartialEq)] -pub enum SerializedValue { +pub enum RustValue { Null, Bool(bool), BigInt(i64), @@ -16,35 +16,35 @@ pub enum SerializedValue { ChronoDateTimeWithTimeZone(chrono::DateTime), Uuid(uuid::Uuid), Decimal(rust_decimal::Decimal), - Array(Vec), + Array(Vec), Vector(Vec), } -impl SerializedValue { - pub fn deserialize(&self, py: pyo3::Python<'_>) -> pyo3::PyResult { +impl RustValue { + pub fn deserialize(&self, py: pyo3::Python<'_>) -> pyo3::PyResult { use chrono::{Datelike, Timelike}; use pyo3::IntoPyObject; unsafe { match self { - Self::Null => Ok(super::deserialize::DeserializedValue::Null), - Self::Bool(x) => Ok(super::deserialize::DeserializedValue::Bool(*x)), - 
Self::BigInt(x) => Ok(super::deserialize::DeserializedValue::BigInt(*x)), - Self::BigUnsigned(x) => Ok(super::deserialize::DeserializedValue::BigUnsigned(*x)), - Self::Double(x) => Ok(super::deserialize::DeserializedValue::Double(*x)), + Self::Null => Ok(super::deserialize::PythonValue::Null), + Self::Bool(x) => Ok(super::deserialize::PythonValue::Bool(*x)), + Self::BigInt(x) => Ok(super::deserialize::PythonValue::BigInt(*x)), + Self::BigUnsigned(x) => Ok(super::deserialize::PythonValue::BigUnsigned(*x)), + Self::Double(x) => Ok(super::deserialize::PythonValue::Double(*x)), Self::String(x) => { let val = pyo3::types::PyString::intern(py, std::str::from_utf8_unchecked(x)); - Ok(super::deserialize::DeserializedValue::String( - NonNull::new_unchecked(val.into_ptr()), - )) + Ok(super::deserialize::PythonValue::String(NonNull::new_unchecked( + val.into_ptr(), + ))) } Self::Bytes(x) => { let val = pyo3::types::PyBytes::new(py, x); - Ok(super::deserialize::DeserializedValue::Bytes( - NonNull::new_unchecked(val.into_ptr()), - )) + Ok(super::deserialize::PythonValue::Bytes(NonNull::new_unchecked( + val.into_ptr(), + ))) } Self::Json(x) => { let encoded = serde_json::to_vec(x) @@ -52,15 +52,13 @@ impl SerializedValue { let val = pyo3::types::PyString::intern(py, std::str::from_utf8_unchecked(&encoded)); let val = super::common::_deserialize_object_with_pyjson(py, val.as_ptr())?; - Ok(super::deserialize::DeserializedValue::Json( - NonNull::new_unchecked(val), - )) + Ok(super::deserialize::PythonValue::Json(NonNull::new_unchecked(val))) } Self::ChronoDate(x) => { let val = pyo3::types::PyDate::new(py, x.year(), (x.month0() + 1) as u8, (x.day0() + 1) as u8)?; - Ok(super::deserialize::DeserializedValue::ChronoDate( + Ok(super::deserialize::PythonValue::ChronoDate( NonNull::new_unchecked(val.into_ptr()), )) } @@ -75,88 +73,84 @@ impl SerializedValue { ) .unwrap(); - Ok(super::deserialize::DeserializedValue::Bytes( - NonNull::new_unchecked(val.into_ptr()), - )) + Ok(super::deserialize::PythonValue::Bytes(NonNull::new_unchecked( + val.into_ptr(), + ))) } Self::ChronoDateTime(x) => { let val = x.into_pyobject(py)?; - Ok(super::deserialize::DeserializedValue::ChronoDateTime( + Ok(super::deserialize::PythonValue::ChronoDateTime( NonNull::new_unchecked(val.into_ptr()), )) } Self::ChronoDateTimeWithTimeZone(x) => { let val = x.into_pyobject(py)?; - Ok(super::deserialize::DeserializedValue::ChronoDateTime( + Ok(super::deserialize::PythonValue::ChronoDateTime( NonNull::new_unchecked(val.into_ptr()), )) } Self::Uuid(x) => { let val = x.into_pyobject(py)?; - Ok(super::deserialize::DeserializedValue::Uuid( - NonNull::new_unchecked(val.into_ptr()), - )) + Ok(super::deserialize::PythonValue::Uuid(NonNull::new_unchecked( + val.into_ptr(), + ))) } Self::Decimal(x) => { let val = x.into_pyobject(py)?; - Ok(super::deserialize::DeserializedValue::Decimal( - NonNull::new_unchecked(val.into_ptr()), - )) + Ok(super::deserialize::PythonValue::Decimal(NonNull::new_unchecked( + val.into_ptr(), + ))) } - Self::Array(x) => Ok(super::deserialize::DeserializedValue::Array( + Self::Array(x) => Ok(super::deserialize::PythonValue::Array( x.iter().map(|x| x.deserialize(py).unwrap()).collect(), )), Self::Vector(x) => { let val = x.into_pyobject(py)?; - Ok(super::deserialize::DeserializedValue::Vector( - NonNull::new_unchecked(val.into_ptr()), - )) + Ok(super::deserialize::PythonValue::Vector(NonNull::new_unchecked( + val.into_ptr(), + ))) } } } } } -impl From for sea_query::Value { +impl From for sea_query::Value { #[inline] - fn 
from(value: SerializedValue) -> Self { + fn from(value: RustValue) -> Self { match value { - SerializedValue::Null => Self::BigInt(None), - SerializedValue::Bool(x) => Self::Bool(Some(x)), - SerializedValue::BigInt(x) => Self::BigInt(Some(x)), - SerializedValue::BigUnsigned(x) => Self::BigUnsigned(Some(x)), - SerializedValue::Double(x) => Self::Double(Some(x)), - SerializedValue::String(x) => { - Self::String(Some(Box::new(unsafe { String::from_utf8_unchecked(x) }))) - } - SerializedValue::Bytes(x) => Self::Bytes(Some(Box::new(x.to_vec()))), - SerializedValue::Json(x) => Self::Json(Some(Box::new(x))), - SerializedValue::ChronoDate(x) => Self::ChronoDate(Some(Box::new(x))), - SerializedValue::ChronoTime(x) => Self::ChronoTime(Some(Box::new(x))), - SerializedValue::ChronoDateTime(x) => Self::ChronoDateTime(Some(Box::new(x))), - SerializedValue::ChronoDateTimeWithTimeZone(x) => { - Self::ChronoDateTimeWithTimeZone(Some(Box::new(x))) - } - SerializedValue::Uuid(x) => Self::Uuid(Some(Box::new(x))), - SerializedValue::Decimal(x) => Self::Decimal(Some(Box::new(x))), - SerializedValue::Array(x) => { + RustValue::Null => Self::BigInt(None), + RustValue::Bool(x) => Self::Bool(Some(x)), + RustValue::BigInt(x) => Self::BigInt(Some(x)), + RustValue::BigUnsigned(x) => Self::BigUnsigned(Some(x)), + RustValue::Double(x) => Self::Double(Some(x)), + RustValue::String(x) => Self::String(Some(Box::new(unsafe { String::from_utf8_unchecked(x) }))), + RustValue::Bytes(x) => Self::Bytes(Some(Box::new(x.to_vec()))), + RustValue::Json(x) => Self::Json(Some(Box::new(x))), + RustValue::ChronoDate(x) => Self::ChronoDate(Some(Box::new(x))), + RustValue::ChronoTime(x) => Self::ChronoTime(Some(Box::new(x))), + RustValue::ChronoDateTime(x) => Self::ChronoDateTime(Some(Box::new(x))), + RustValue::ChronoDateTimeWithTimeZone(x) => Self::ChronoDateTimeWithTimeZone(Some(Box::new(x))), + RustValue::Uuid(x) => Self::Uuid(Some(Box::new(x))), + RustValue::Decimal(x) => Self::Decimal(Some(Box::new(x))), + RustValue::Array(x) => { Self::Array( /* this parameter is unusable and not important */ sea_query::ArrayType::BigInt, Some(Box::new(x.into_iter().map(|x| x.into()).collect())), ) } - SerializedValue::Vector(x) => Self::Vector(Some(Box::new(pgvector::Vector::from(x)))), + RustValue::Vector(x) => Self::Vector(Some(Box::new(pgvector::Vector::from(x)))), } } } -impl From for SerializedValue { +impl From for RustValue { #[inline] fn from(value: sea_query::Value) -> Self { match value { diff --git a/src/expression/expr.rs b/src/expression/expr.rs index 12310bf..443e4b0 100644 --- a/src/expression/expr.rs +++ b/src/expression/expr.rs @@ -5,7 +5,6 @@ use pyo3::types::PyAnyMethods; /// A bridge between Python & [`sea_query::SimpleExpr`] #[pyo3::pyclass(module = "rapidquery._lib", name = "Expr", frozen)] pub struct PyExpr { - // TOD: support subquery and case pub(crate) inner: sea_query::SimpleExpr, } @@ -120,6 +119,12 @@ impl PyExpr { None, Box::new(stmt.into_sub_query_statement()), ))) + } else if type_ptr == crate::typeref::CASE_STATEMENT_TYPE { + let value = value.cast_into_unchecked::(); + let stmt = value.get().inner.lock(); + let stmt = Box::new(stmt.as_statement(value.py())); + + Ok(Self::from_simple_expr(sea_query::SimpleExpr::Case(stmt))) } else if pyo3::ffi::PyTuple_Check(value.as_ptr()) == 1 { let value = value.cast_into_unchecked::(); let mut arr: Vec = Vec::new(); @@ -179,25 +184,6 @@ impl PyExpr { } } - #[classmethod] - fn func( - _cls: &pyo3::Bound<'_, pyo3::types::PyType>, - value: &pyo3::Bound<'_, pyo3::PyAny>, - ) -> 
pyo3::PyResult { - unsafe { - if pyo3::ffi::Py_TYPE(value.as_ptr()) != crate::typeref::FUNCTION_CALL_TYPE { - return Err(typeerror!( - "expected FunctionCall, got {}", - value.py(), - value.as_ptr() - )); - } - - let x = value.cast_unchecked::(); - Ok(Self::from_function_call(x.get())) - } - } - #[classmethod] fn col( _cls: &pyo3::Bound<'_, pyo3::types::PyType>, @@ -460,6 +446,10 @@ impl PyExpr { Ok(sea_query::ExprTrait::and(slf.inner.clone(), other.inner).into()) } + fn __neg__<'a>(slf: pyo3::PyRef<'a, Self>) -> pyo3::PyResult { + Ok(sea_query::ExprTrait::mul(slf.inner.clone(), -1).into()) + } + fn __or__<'a>(slf: pyo3::PyRef<'a, Self>, other: &pyo3::Bound<'a, pyo3::PyAny>) -> pyo3::PyResult { let other = Self::try_from(other.clone())?; Ok(sea_query::ExprTrait::or(slf.inner.clone(), other.inner).into()) diff --git a/src/expression/function.rs b/src/expression/function.rs index 980754b..57e8f17 100644 --- a/src/expression/function.rs +++ b/src/expression/function.rs @@ -258,6 +258,27 @@ impl PyFunctionCall { } } + #[classmethod] + fn rank(_cls: &pyo3::Bound<'_, pyo3::types::PyType>) -> Self { + Self { + inner: parking_lot::Mutex::new(sea_query::Func::cust("RANK")), + } + } + + #[classmethod] + fn dense_rank(_cls: &pyo3::Bound<'_, pyo3::types::PyType>) -> Self { + Self { + inner: parking_lot::Mutex::new(sea_query::Func::cust("DENSE_RANK")), + } + } + + #[classmethod] + fn percent_rank(_cls: &pyo3::Bound<'_, pyo3::types::PyType>) -> Self { + Self { + inner: parking_lot::Mutex::new(sea_query::Func::cust("PERCENT_RANK")), + } + } + #[classmethod] fn round( _cls: &pyo3::Bound<'_, pyo3::types::PyType>, diff --git a/src/lib.rs b/src/lib.rs index dd50b5a..4c407ab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,9 +7,6 @@ #![feature(likely_unlikely)] #![feature(optimize_attribute)] -// TODO: Use [`pyo3::Bound`] instead of [`Vec`] arguments -// to improve performance - /// Helper macros and some utilitize functions #[macro_use] mod macros; @@ -84,11 +81,17 @@ mod _lib { use super::query::update::PyUpdate; #[pymodule_export] - use super::query::select::{PySelect, PySelectExpr}; + use super::query::select::{PySelect, PySelectCol}; #[pymodule_export] use super::query::on_conflict::PyOnConflict; + #[pymodule_export] + use super::query::case::PyCase; + + #[pymodule_export] + use super::query::window::{PyWindow, PyWindowFrame}; + #[pymodule_init] fn init(m: &pyo3::Bound<'_, pyo3::types::PyModule>) -> pyo3::PyResult<()> { m.add("INTERVAL_YEAR", sea_query::PgInterval::Year as u8)?; diff --git a/src/macros.rs b/src/macros.rs index cea83ab..335e98f 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -36,7 +36,7 @@ pub unsafe fn get_type_name<'a>(py: pyo3::Python<'a>, obj: *mut pyo3::ffi::PyObj #[macro_export] macro_rules! typeerror { ( - $message:expr, + $message:expr $(,)? ) => { pyo3::PyErr::new::($message) }; @@ -100,7 +100,7 @@ macro_rules! 
build_query_parts { let values = { values .into_iter() - .map(|x| $crate::adaptation::SerializedValue::from(x)) + .map(|x| $crate::adaptation::RustValue::from(x)) .map(|x| $crate::adaptation::ReturnableValue::from(x)) .map(|x| $crate::adaptation::PyAdaptedValue::from(x)) }; diff --git a/src/query/case.rs b/src/query/case.rs new file mode 100644 index 0000000..8725be8 --- /dev/null +++ b/src/query/case.rs @@ -0,0 +1,82 @@ +#[derive(Default)] +pub struct CaseInner { + // Always is `Vec<(PyExpr, PyExpr)>` + when: Vec<(pyo3::Py, pyo3::Py)>, + // Always is `Option` + r#else: Option>, +} + +impl CaseInner { + #[inline] + pub fn as_statement(&self, py: pyo3::Python) -> sea_query::CaseStatement { + let mut stmt = sea_query::CaseStatement::new(); + + for (cond, then) in &self.when { + let cond = unsafe { cond.cast_bound_unchecked::(py) }; + let then = unsafe { then.cast_bound_unchecked::(py) }; + + stmt = stmt.case(cond.get().inner.clone(), then.get().inner.clone()); + } + + if let Some(x) = &self.r#else { + let x = unsafe { x.cast_bound_unchecked::(py) }; + stmt = stmt.finally(x.get().inner.clone()); + } + + stmt + } +} + +#[pyo3::pyclass(module = "rapidquery._lib", name = "Case", frozen)] +pub struct PyCase { + pub inner: parking_lot::Mutex, +} + +#[pyo3::pymethods] +impl PyCase { + #[new] + fn new() -> Self { + Self { + inner: parking_lot::Mutex::new(Default::default()), + } + } + + fn when<'a>( + slf: pyo3::PyRef<'a, Self>, + cond: pyo3::Bound<'a, pyo3::PyAny>, + then: pyo3::Bound<'a, pyo3::PyAny>, + ) -> pyo3::PyResult> { + let cond = crate::expression::PyExpr::from_bound_into_any(cond)?; + let then = crate::expression::PyExpr::from_bound_into_any(then)?; + + { + let mut lock = slf.inner.lock(); + lock.when.push((cond, then)); + } + + Ok(slf) + } + + fn else_<'a>( + slf: pyo3::PyRef<'a, Self>, + expr: pyo3::Bound<'a, pyo3::PyAny>, + ) -> pyo3::PyResult> { + let expr = crate::expression::PyExpr::from_bound_into_any(expr)?; + + { + let mut lock = slf.inner.lock(); + lock.r#else = Some(expr); + } + + Ok(slf) + } + + fn to_expr(&self, py: pyo3::Python) -> crate::expression::PyExpr { + let stmt = { + let lock = self.inner.lock(); + lock.as_statement(py) + }; + + crate::expression::PyExpr::from(sea_query::SimpleExpr::Case(Box::new(stmt))) + } +} diff --git a/src/query/delete.rs b/src/query/delete.rs index 8034c48..1f29f29 100644 --- a/src/query/delete.rs +++ b/src/query/delete.rs @@ -7,8 +7,8 @@ pub struct DeleteInner { // Always is `Option` pub table: Option>, - // Always is `Option` - pub r#where: Option>, + // Always is `Vec` + pub r#where: Vec>, pub limit: Option, pub returning_clause: super::returning::ReturningClause, pub orders: Vec, @@ -25,7 +25,7 @@ impl DeleteInner { stmt.from_table(x.get().clone()); } - if let Some(x) = &self.r#where { + for x in &self.r#where { let x = unsafe { x.cast_bound_unchecked::(py) }; stmt.and_where(x.get().inner.clone()); } @@ -159,7 +159,7 @@ impl PyDelete { { let mut lock = slf.inner.lock(); - lock.r#where = Some(condition); + lock.r#where.push(condition); } Ok(slf) @@ -215,8 +215,16 @@ impl PyDelete { if let Some(x) = lock.limit { write!(s, " limit={x}").unwrap(); } - if let Some(x) = &lock.r#where { - write!(s, " where={x}").unwrap(); + + write!(s, " where=[").unwrap(); + + let n = lock.r#where.len(); + for (index, expr) in lock.r#where.iter().enumerate() { + if index + 1 == n { + write!(s, "{expr}]").unwrap(); + } else { + write!(s, "{expr}, ").unwrap(); + } } write!(s, " orders=[").unwrap(); diff --git a/src/query/mod.rs b/src/query/mod.rs index 
f43eedd..d4582da 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -1,3 +1,4 @@ +pub mod case; pub mod delete; pub mod insert; pub mod on_conflict; @@ -5,3 +6,4 @@ pub mod order; pub mod returning; pub mod select; pub mod update; +pub mod window; diff --git a/src/query/select.rs b/src/query/select.rs index a5a2e7e..9e50340 100644 --- a/src/query/select.rs +++ b/src/query/select.rs @@ -1,34 +1,58 @@ use crate::backend::PyQueryStatement; use pyo3::types::PyTupleMethods; +use pyo3::PyTypeInfo; use sea_query::IntoIden; -#[pyo3::pyclass(module = "rapidquery._lib", name = "SelectExpr", frozen)] -pub struct PySelectExpr { +#[pyo3::pyclass(module = "rapidquery._lib", name = "SelectCol", frozen)] +pub struct PySelectCol { // Always is `PyExpr` pub expr: pyo3::Py, // Always is `PyExpr` pub alias: Option, - // TODO - // pub window: pyo3::Py, + + // Always is `PyWindow | PyString` + pub window: Option>, } -impl PySelectExpr { +impl PySelectCol { pub fn clone_ref(&self, py: pyo3::Python) -> Self { Self { expr: self.expr.clone_ref(py), + window: self.window.as_ref().map(|x| x.clone_ref(py)), alias: self.alias.clone(), } } - pub fn as_select_expr(&self, py: pyo3::Python) -> sea_query::SelectExpr { + pub fn as_statement(&self, py: pyo3::Python) -> sea_query::SelectExpr { let expr = unsafe { self.expr.cast_bound_unchecked::(py) }; let expr = expr.get().inner.clone(); - sea_query::SelectExpr { - expr, - alias: self.alias.as_ref().map(|x| sea_query::Alias::new(x).into_iden()), - window: None, + if let Some(window_ref) = &self.window { + let window_type = unsafe { + if pyo3::ffi::PyUnicode_CheckExact(window_ref.as_ptr()) == 1 { + sea_query::WindowSelectType::Name( + sea_query::Alias::new(window_ref.extract::(py).unwrap_unchecked()) + .into_iden(), + ) + } else { + let py_window = window_ref.cast_bound_unchecked::(py); + let stmt = py_window.get().inner.lock(); + sea_query::WindowSelectType::Query(stmt.as_statement(py)) + } + }; + + sea_query::SelectExpr { + expr, + alias: self.alias.as_ref().map(|x| sea_query::Alias::new(x).into_iden()), + window: Some(window_type), + } + } else { + sea_query::SelectExpr { + expr, + alias: self.alias.as_ref().map(|x| sea_query::Alias::new(x).into_iden()), + window: None, + } } } @@ -44,17 +68,22 @@ impl PySelectExpr { let slf = Self { expr: bound.clone().unbind(), alias: None, + window: None, }; return pyo3::Py::new(bound.py(), slf).map(|x| x.into_any()); } - if PySelectExpr::is_exact_type_of(bound) { + if PySelectCol::is_exact_type_of(bound) { return Ok(bound.clone().unbind()); } let expr = crate::expression::PyExpr::from_bound_into_any(bound.clone())?; - let slf = Self { expr, alias: None }; + let slf = Self { + expr, + alias: None, + window: None, + }; pyo3::Py::new(bound.py(), slf).map(|x| x.into_any()) } @@ -62,19 +91,28 @@ impl PySelectExpr { } #[pyo3::pymethods] -impl PySelectExpr { +impl PySelectCol { #[new] - #[pyo3(signature=(expr, alias=None))] - fn new(expr: &pyo3::Bound<'_, pyo3::PyAny>, alias: Option) -> pyo3::PyResult> { + #[pyo3(signature=(expr, alias=None, window=None))] + fn new( + expr: &pyo3::Bound<'_, pyo3::PyAny>, + alias: Option, + window: Option>, + ) -> pyo3::PyResult> { use pyo3::PyTypeInfo; - if PySelectExpr::is_exact_type_of(expr) { + if PySelectCol::is_exact_type_of(expr) { let slf = unsafe { expr.clone().cast_into_unchecked::() }; if let Some(x) = alias { let expr = slf.get().expr.clone_ref(slf.py()); + let window = slf.get().window.as_ref().map(|x| x.clone_ref(slf.py())); - let new_slf = Self { expr, alias: Some(x) }; + let new_slf = Self 
{ + expr, + alias: Some(x), + window, + }; Ok(pyo3::Py::new(slf.py(), new_slf).unwrap()) } else { Ok(slf.unbind()) @@ -82,15 +120,48 @@ impl PySelectExpr { } else { let py = expr.py(); let expr = crate::expression::PyExpr::from_bound_into_any(expr.clone())?; - let slf = Self { expr, alias }; + if let Some(window_ref) = &window { + unsafe { + if pyo3::ffi::PyUnicode_CheckExact(window_ref.as_ptr()) == 0 + && !super::window::PyWindow::is_exact_type_of(window_ref) + { + return Err(typeerror!( + "expected Window or str, got {:?}", + py, + window_ref.as_ptr() + )); + } + } + } + + let slf = Self { + expr, + alias, + window: window.map(|x| x.unbind()), + }; Ok(pyo3::Py::new(py, slf).unwrap()) } } + + #[getter] + fn expr(&self, py: pyo3::Python) -> pyo3::Py { + self.expr.clone_ref(py) + } + + #[getter] + fn alias(&self) -> Option { + self.alias.clone() + } + + #[getter] + fn window(&self, py: pyo3::Python) -> Option> { + self.window.as_ref().map(|x| x.clone_ref(py)) + } } #[derive(Debug, Default)] -pub enum SelectDistinct { +pub enum DistinctMode { #[default] None, Distinct, @@ -100,7 +171,7 @@ pub enum SelectDistinct { ), } -pub struct SelectLock { +pub struct LockOptions { pub r#type: sea_query::LockType, pub behavior: Option, @@ -108,7 +179,7 @@ pub struct SelectLock { pub tables: Vec>, } -pub struct SelectJoin { +pub struct JoinOptions { pub r#type: sea_query::JoinType, // Always is `TableName | PySelect` @@ -119,7 +190,7 @@ pub struct SelectJoin { pub lateral: Option, } -pub enum SelectTable { +pub enum SelectReference { SubQuery( // Always is `PySelect` pyo3::Py, @@ -139,14 +210,13 @@ pub enum SelectTable { #[derive(Default)] pub struct SelectInner { // TODO: support from_values - pub tables: Vec, + pub tables: Vec, - // TODO: support subqueries // Always is `Option` pub cols: Vec>, - // Always is `Option` - pub r#where: Option>, + // Always is `Vec` + pub r#where: Vec>, // Always is `Vec` pub groups: Vec>, @@ -158,14 +228,14 @@ pub struct SelectInner { pub having: Option>, pub orders: Vec, - - pub distinct: SelectDistinct, - pub join: Vec, - pub lock: Option, + pub distinct: DistinctMode, + pub join: Vec, + pub lock: Option, pub limit: Option, pub offset: Option, + pub window: Option<(String, pyo3::Py)>, + // TODO - // pub window: Option>, // pub with: Option>, // pub table_sample: Option>, // pub index_hint: Option>, @@ -177,11 +247,11 @@ impl SelectInner { let mut stmt = sea_query::SelectStatement::new(); match &self.distinct { - SelectDistinct::None => (), - SelectDistinct::Distinct => { + DistinctMode::None => (), + DistinctMode::Distinct => { stmt.distinct(); } - SelectDistinct::DistinctOn(cols) => { + DistinctMode::DistinctOn(cols) => { use sea_query::IntoColumnRef; stmt.distinct_on(cols.iter().map(|col| unsafe { @@ -198,15 +268,15 @@ impl SelectInner { for table in self.tables.iter() { match table { - SelectTable::TableName(x) => unsafe { + SelectReference::TableName(x) => unsafe { let x = unsafe { x.cast_bound_unchecked::(py) }; stmt.from(x.get().clone()); }, - SelectTable::FunctionCall(x, alias) => unsafe { + SelectReference::FunctionCall(x, alias) => unsafe { let x = unsafe { x.cast_bound_unchecked::(py) }; stmt.from_function(x.get().inner.lock().clone(), sea_query::Alias::new(alias)); }, - SelectTable::SubQuery(x, alias) => unsafe { + SelectReference::SubQuery(x, alias) => unsafe { let x = unsafe { x.cast_bound_unchecked::(py) }; let inner = x.get().inner.lock(); @@ -217,8 +287,8 @@ impl SelectInner { if !self.cols.is_empty() { stmt.exprs(self.cols.iter().map(|x| unsafe { - let 
expr = x.cast_bound_unchecked::(py); - expr.get().as_select_expr(py) + let expr = x.cast_bound_unchecked::(py); + expr.get().as_statement(py) })); } @@ -229,7 +299,7 @@ impl SelectInner { })); } - if let Some(x) = &self.r#where { + for x in &self.r#where { let x = unsafe { x.cast_bound_unchecked::(py) }; stmt.and_where(x.get().inner.clone()); } @@ -315,6 +385,13 @@ impl SelectInner { } } + if let Some((window_name, window)) = &self.window { + let window = unsafe { window.cast_bound_unchecked::(py) }; + let lock = window.get().inner.lock(); + + stmt.window(sea_query::Alias::new(window_name), lock.as_statement(py)); + } + stmt } } @@ -332,7 +409,7 @@ impl PySelect { let mut exprs = Vec::with_capacity(PyTupleMethods::len(cols)); for expr in PyTupleMethods::iter(cols) { - exprs.push(PySelectExpr::from_bound_into_any(&expr)?); + exprs.push(PySelectCol::from_bound_into_any(&expr)?); } let slf = Self { @@ -352,7 +429,7 @@ impl PySelect { ) -> pyo3::PyResult> { if PyTupleMethods::is_empty(on) { let mut lock = slf.inner.lock(); - lock.distinct = SelectDistinct::Distinct; + lock.distinct = DistinctMode::Distinct; } else { let mut cols = Vec::with_capacity(PyTupleMethods::len(on)); @@ -383,7 +460,7 @@ impl PySelect { } let mut lock = slf.inner.lock(); - lock.distinct = SelectDistinct::DistinctOn(cols); + lock.distinct = DistinctMode::DistinctOn(cols); } Ok(slf) @@ -397,7 +474,7 @@ impl PySelect { let mut exprs = Vec::with_capacity(PyTupleMethods::len(cols)); for expr in PyTupleMethods::iter(cols) { - exprs.push(PySelectExpr::from_bound_into_any(&expr)?); + exprs.push(PySelectCol::from_bound_into_any(&expr)?); } { @@ -426,7 +503,7 @@ impl PySelect { { let mut lock = slf.inner.lock(); - lock.tables.push(SelectTable::TableName(table)); + lock.tables.push(SelectReference::TableName(table)); } Ok(slf) @@ -460,7 +537,7 @@ impl PySelect { { let mut lock = slf.inner.lock(); - lock.tables.push(SelectTable::SubQuery(subquery, alias)); + lock.tables.push(SelectReference::SubQuery(subquery, alias)); } Ok(slf) @@ -487,7 +564,7 @@ impl PySelect { { let mut lock = slf.inner.lock(); - lock.tables.push(SelectTable::FunctionCall(function, alias)); + lock.tables.push(SelectReference::FunctionCall(function, alias)); } Ok(slf) @@ -519,7 +596,7 @@ impl PySelect { { let mut lock = slf.inner.lock(); - lock.r#where = Some(condition); + lock.r#where.push(condition); } Ok(slf) @@ -609,7 +686,7 @@ impl PySelect { { let mut lock = slf.inner.lock(); - lock.lock = Some(SelectLock { + lock.lock = Some(LockOptions { r#type, behavior, tables: tbs, @@ -728,7 +805,7 @@ impl PySelect { let expr = crate::expression::PyExpr::from_bound_into_any(on.clone())?; - let join_expr = SelectJoin { + let join_expr = JoinOptions { r#type, table, on: expr, @@ -785,7 +862,7 @@ impl PySelect { let expr = crate::expression::PyExpr::from_bound_into_any(on.clone())?; - let join_expr = SelectJoin { + let join_expr = JoinOptions { r#type, table: query.clone().unbind(), on: expr, @@ -800,6 +877,28 @@ impl PySelect { Ok(slf) } + #[pyo3(signature=(name, statement))] + fn window<'a>( + slf: pyo3::PyRef<'a, Self>, + name: String, + statement: &'a pyo3::Bound<'a, pyo3::PyAny>, + ) -> pyo3::PyResult> { + if !super::window::PyWindow::is_exact_type_of(statement) { + return Err(typeerror!( + "expected Window, got {:?}", + slf.py(), + statement.as_ptr() + )); + } + + { + let mut lock = slf.inner.lock(); + lock.window = Some((name, statement.clone().unbind())); + } + + Ok(slf) + } + fn build( &self, backend: &pyo3::Bound<'_, pyo3::PyAny>, diff --git 
a/src/query/update.rs b/src/query/update.rs index 8282ad3..5a0ea19 100644 --- a/src/query/update.rs +++ b/src/query/update.rs @@ -13,8 +13,8 @@ pub struct UpdateInner { // Always is `Vec` pub values: Vec<(String, pyo3::Py)>, - // Always is `Option` - pub r#where: Option>, + // Always is `Vec` + pub r#where: Vec>, pub limit: Option, pub orders: Vec, pub returning_clause: super::returning::ReturningClause, @@ -36,7 +36,7 @@ impl UpdateInner { stmt.from(x.get().clone()); } - if let Some(x) = &self.r#where { + for x in &self.r#where { let x = unsafe { x.cast_bound_unchecked::(py) }; stmt.and_where(x.get().inner.clone()); } @@ -197,7 +197,7 @@ impl PyUpdate { { let mut lock = slf.inner.lock(); - lock.r#where = Some(condition); + lock.r#where.push(condition); } Ok(slf) @@ -283,8 +283,16 @@ impl PyUpdate { if let Some(x) = lock.limit { write!(s, " limit={x}").unwrap(); } - if let Some(x) = &lock.r#where { - write!(s, " where={x}").unwrap(); + + write!(s, " where=[").unwrap(); + + let n = lock.r#where.len(); + for (index, expr) in lock.r#where.iter().enumerate() { + if index + 1 == n { + write!(s, "{expr}]").unwrap(); + } else { + write!(s, "{expr}, ").unwrap(); + } } write!(s, " orders=[").unwrap(); diff --git a/src/query/window.rs b/src/query/window.rs new file mode 100644 index 0000000..349085f --- /dev/null +++ b/src/query/window.rs @@ -0,0 +1,198 @@ +use pyo3::types::PyTupleMethods; +use sea_query::OverStatement; + +#[derive(Debug, Clone, PartialEq)] +pub struct FrameClause { + pub r#type: sea_query::FrameType, + pub start: sea_query::Frame, + pub end: Option, +} + +#[pyo3::pyclass(module = "rapidquery._lib", name = "WindowFrame", frozen)] +pub struct PyWindowFrame(sea_query::Frame); + +#[pyo3::pymethods] +impl PyWindowFrame { + #[classmethod] + fn unbounded_preceding(_cls: &pyo3::Bound<'_, pyo3::types::PyType>) -> Self { + Self(sea_query::Frame::UnboundedPreceding) + } + + #[classmethod] + fn current_row(_cls: &pyo3::Bound<'_, pyo3::types::PyType>) -> Self { + Self(sea_query::Frame::CurrentRow) + } + + #[classmethod] + fn unbounded_following(_cls: &pyo3::Bound<'_, pyo3::types::PyType>) -> Self { + Self(sea_query::Frame::UnboundedFollowing) + } + + #[classmethod] + fn following(_cls: &pyo3::Bound<'_, pyo3::types::PyType>, val: u32) -> Self { + Self(sea_query::Frame::Following(val)) + } + + #[classmethod] + fn preceding(_cls: &pyo3::Bound<'_, pyo3::types::PyType>, val: u32) -> Self { + Self(sea_query::Frame::Preceding(val)) + } +} + +#[derive(Default)] +pub struct WindowInner { + // Always is `Vec` + pub partition_by: Vec>, + pub orders: Vec, + pub frame: Option, +} + +impl WindowInner { + #[inline] + pub fn as_statement(&self, py: pyo3::Python) -> sea_query::WindowStatement { + let mut stmt = sea_query::WindowStatement::new(); + + for expr in &self.partition_by { + let expr = unsafe { expr.cast_bound_unchecked::(py) }; + stmt.add_partition_by(expr.get().inner.clone()); + } + + for order in self.orders.iter() { + let target = unsafe { order.target.cast_bound_unchecked::(py) }; + let target = target.get().inner.clone(); + + if let Some(x) = order.null_order { + stmt.order_by_expr_with_nulls(target, order.order.clone(), x); + } else { + stmt.order_by_expr(target, order.order.clone()); + } + } + + if let Some(x) = &self.frame { + stmt.frame(x.r#type.clone(), x.start.clone(), x.end.clone()); + } + + stmt + } +} + +#[pyo3::pyclass(module = "rapidquery._lib", name = "Window", frozen)] +pub struct PyWindow { + pub inner: parking_lot::Mutex, +} + +#[pyo3::pymethods] +impl PyWindow { + #[new] + 
#[pyo3(signature=(*partition_by))] + fn new(partition_by: &pyo3::Bound<'_, pyo3::types::PyTuple>) -> pyo3::PyResult { + let mut partitions = Vec::>::new(); + + unsafe { + for value in PyTupleMethods::iter(partition_by) { + partitions.push(crate::expression::PyExpr::from_bound_into_any(value)?); + } + } + + Ok(Self { + inner: parking_lot::Mutex::new(WindowInner { + partition_by: partitions, + orders: Vec::new(), + frame: None, + }), + }) + } + + #[pyo3(signature=(*partition_by))] + fn partition<'a>( + slf: pyo3::PyRef<'a, Self>, + partition_by: &pyo3::Bound<'a, pyo3::types::PyTuple>, + ) -> pyo3::PyResult> { + let mut partitions = Vec::>::new(); + + unsafe { + for value in PyTupleMethods::iter(partition_by) { + partitions.push(crate::expression::PyExpr::from_bound_into_any(value)?); + } + } + + { + let mut lock = slf.inner.lock(); + lock.partition_by.append(&mut partitions); + } + + Ok(slf) + } + + #[pyo3(signature=(target, order, null_order=None))] + fn order_by<'a>( + slf: pyo3::PyRef<'a, Self>, + target: pyo3::Bound<'_, pyo3::PyAny>, + order: String, + null_order: Option, + ) -> pyo3::PyResult> { + let order = super::order::OrderClause::from_parameters(target, order, null_order)?; + + { + let mut lock = slf.inner.lock(); + lock.orders.push(order); + } + + Ok(slf) + } + + #[pyo3(signature=(r#type, start, end=None))] + fn frame<'a>( + slf: pyo3::PyRef<'a, Self>, + mut r#type: String, + start: pyo3::Bound<'_, pyo3::PyAny>, + end: Option>, + ) -> pyo3::PyResult> { + let r#type = { + r#type.make_ascii_lowercase(); + + if r#type == "rows" { + sea_query::FrameType::Rows + } else if r#type == "range" { + sea_query::FrameType::Range + } else { + return Err(pyo3::PyErr::new::(format!( + "invalid frame type, expected 'rows' or 'range'; got {:?}", + r#type + ))); + } + }; + + let start = { + match start.cast::() { + Ok(x) => x.get().0.clone(), + Err(_) => { + return Err(typeerror!( + "expected WindowFrame, got {}", + slf.py(), + start.as_ptr() + )); + } + } + }; + + let end = { + match end { + None => None, + Some(y) => match y.cast::() { + Ok(x) => Some(x.get().0.clone()), + Err(_) => { + return Err(typeerror!("expected WindowFrame, got {}", slf.py(), y.as_ptr())); + } + }, + } + }; + + { + let mut lock = slf.inner.lock(); + lock.frame = Some(FrameClause { r#type, start, end }); + } + + Ok(slf) + } +} diff --git a/src/table/aliased.rs b/src/table/aliased.rs index 360c5ee..98bed7c 100644 --- a/src/table/aliased.rs +++ b/src/table/aliased.rs @@ -13,12 +13,11 @@ impl Py_AliasedTableColumnsSequence { let col_name = unsafe { let lock = slf.inner.lock(); - let (name, _) = lock - .columns - .iter() - .find(|(x, _)| x.eq(&name)) - .ok_or_else(|| pyo3::PyErr::new::(name.to_owned()))?; - + if !lock.columns.contains_key(&name) { + return Err(pyo3::PyErr::new::( + name.to_owned(), + )); + } sea_query::Alias::new(name) }; @@ -34,12 +33,11 @@ impl Py_AliasedTableColumnsSequence { let col_name = unsafe { let lock = slf.inner.lock(); - let (name, _) = lock - .columns - .iter() - .find(|(x, _)| x.eq(&name)) - .ok_or_else(|| pyo3::PyErr::new::(name.to_owned()))?; - + if !lock.columns.contains_key(&name) { + return Err(pyo3::PyErr::new::( + name.to_owned(), + )); + } sea_query::Alias::new(name) }; diff --git a/src/table/table.rs b/src/table/table.rs index 84dceb0..2d531af 100644 --- a/src/table/table.rs +++ b/src/table/table.rs @@ -1,13 +1,12 @@ use crate::backend::PySchemaStatement; use pyo3::types::PyAnyMethods; -type ColumnsSequence = Vec<(String, pyo3::Py)>; +type ColumnsSequence = indexmap::IndexMap>; pub struct 
TableInner { // Always is `TableName` pub name: pyo3::Py, - // TODO: use `indexmap` crate to optimize lookup from `O(n)` into `O(1)` // Always is `ColumnsSequence` pub columns: ColumnsSequence, @@ -127,9 +126,8 @@ impl Py_TableColumnsSequence { let lock = slf.inner.lock(); lock.columns - .iter() - .find(|(x, _)| x.eq(&name)) - .map(|(_, x)| x.clone_ref(slf.py())) + .get(&name) + .map(|x| x.clone_ref(slf.py())) .ok_or_else(|| pyo3::PyErr::new::(name.to_owned())) } @@ -137,9 +135,8 @@ impl Py_TableColumnsSequence { let lock = slf.inner.lock(); lock.columns - .iter() - .find(|(x, _)| x.eq(&name)) - .map(|(_, x)| x.clone_ref(slf.py())) + .get(&name) + .map(|x| x.clone_ref(slf.py())) .ok_or_else(|| pyo3::PyErr::new::(name.to_owned())) } @@ -159,7 +156,7 @@ impl Py_TableColumnsSequence { let name = colobj.name.clone(); drop(colobj); - lock.columns.push((name, col.unbind())); + lock.columns.insert(name, col.unbind()); } Ok(()) @@ -168,13 +165,11 @@ impl Py_TableColumnsSequence { fn remove(slf: pyo3::PyRef<'_, Self>, name: String) -> pyo3::PyResult> { let mut lock = slf.inner.lock(); - let position = lock + let x = lock .columns - .iter() - .position(|(x, _)| x.eq(&name)) + .shift_remove(&name) .ok_or_else(|| pyo3::PyErr::new::(name.to_owned()))?; - let (_, x) = lock.columns.remove(position); Ok(x.clone_ref(slf.py())) } @@ -252,7 +247,7 @@ impl PyTable { let name = colobj.name.clone(); drop(colobj); - cols.push((name, col)); + cols.insert(name, col); } } diff --git a/src/typeref.rs b/src/typeref.rs index 52e895a..c31b80a 100644 --- a/src/typeref.rs +++ b/src/typeref.rs @@ -51,6 +51,7 @@ pub(crate) static mut TABLE_NAME_TYPE: *mut pyo3::ffi::PyTypeObject = std::ptr:: pub(crate) static mut COLUMN_TYPE: *mut pyo3::ffi::PyTypeObject = std::ptr::null_mut(); pub(crate) static mut INDEX_COLUMN_TYPE: *mut pyo3::ffi::PyTypeObject = std::ptr::null_mut(); pub(crate) static mut SELECT_STATEMENT_TYPE: *mut pyo3::ffi::PyTypeObject = std::ptr::null_mut(); +pub(crate) static mut CASE_STATEMENT_TYPE: *mut pyo3::ffi::PyTypeObject = std::ptr::null_mut(); // Python standard libraries types pub(crate) static mut STD_DECIMAL_TYPE: *mut pyo3::ffi::PyTypeObject = std::ptr::null_mut(); @@ -130,6 +131,7 @@ fn _initialize_typeref(py: pyo3::Python) -> bool { COLUMN_TYPE = get_type_object_for::(py); INDEX_COLUMN_TYPE = get_type_object_for::(py); SELECT_STATEMENT_TYPE = get_type_object_for::(py); + CASE_STATEMENT_TYPE = get_type_object_for::(py); STD_DECIMAL_TYPE = look_up_type_object(c"decimal", c"Decimal"); STD_UUID_TYPE = look_up_type_object(c"uuid", c"UUID"); diff --git a/tests/test_adaptedvalue.py b/tests/test_adaptedvalue.py deleted file mode 100644 index 4848b5d..0000000 --- a/tests/test_adaptedvalue.py +++ /dev/null @@ -1,50 +0,0 @@ -from rapidquery import _lib -import pytest -import decimal -import datetime -import uuid - - -inferdata = [ - (12, "is_integer"), - (12.4, "is_float"), - ("string", "is_string"), - (b"bytes", "is_bytes"), - (True, "is_boolean"), - (list(), "is_json"), - (dict(), "is_json"), - ({2: 3}, "is_json"), - (decimal.Decimal(3.4), "is_decimal"), - (datetime.datetime.now(), "is_datetime"), - (datetime.datetime.now(tz=datetime.timezone.utc), "is_datetime"), - (datetime.datetime.now().date(), "is_date"), - (datetime.datetime.now().time(), "is_time"), - (uuid.uuid4(), "is_uuid"), - (None, "is_null"), -] -specificdata = [ - (12, "is_integer", _lib.TinyUnsignedType()), - (12.4, "is_float", _lib.FloatType()), - ("string", "is_json", _lib.JsonBinaryType()), - ("string", "is_string", _lib.EnumType("name", 
["var1"])), - (b"bytes", "is_bytes", _lib.BlobType()), - (True, "is_json", _lib.JsonType()), - (list(), "is_array", _lib.ArrayType(_lib.IntegerType())), - (uuid.uuid4(), "is_uuid", _lib.UuidType()), -] - - -@pytest.mark.parametrize("value,attribute", inferdata) -def test_infer(value, attribute): - adapted = _lib.AdaptedValue(value) - assert getattr(adapted, attribute) is True - - _lib.Expr(adapted) # Force AdaptedValue to adapt - - -@pytest.mark.parametrize("value,attribute,typ", specificdata) -def test_specific_type(value, attribute, typ): - adapted = _lib.AdaptedValue(value, type=typ) - assert getattr(adapted, attribute) is True - - _lib.Expr(adapted) # Force AdaptedValue to adapt diff --git a/tests/test_adapting.py b/tests/test_adapting.py new file mode 100644 index 0000000..80e8ff8 --- /dev/null +++ b/tests/test_adapting.py @@ -0,0 +1,76 @@ +from collections import namedtuple +from datetime import datetime, timezone +import decimal +import pytest +import uuid + +import rapidquery as rq + + +NamedCase = namedtuple("NamedCase", ["data", "attribute", "type", "error"]) + + +TEST_CASES = [ + NamedCase(None, "is_null", None, False), + NamedCase(None, "is_null", rq.FloatType(), False), + NamedCase(True, "is_boolean", None, False), + NamedCase(False, "is_boolean", rq.FloatType(), True), + NamedCase(False, "is_boolean", rq.BooleanType(), False), + NamedCase(1, "is_integer", None, False), + NamedCase(-4, "is_integer", rq.IntegerType(), False), + NamedCase(3e-3, "is_integer", rq.IntegerType(), True), + NamedCase(3, "is_integer", rq.UnsignedType(), False), + NamedCase(-1, "is_integer", rq.UnsignedType(), True), + NamedCase(5e-3, "is_float", None, False), + NamedCase(5e-3, "is_float", rq.DoubleType(), False), + NamedCase(-4.5, "is_float", rq.FloatType(), False), + NamedCase("data", "is_string", None, False), + NamedCase("data", "is_string", rq.StringType(), False), + NamedCase("data", "is_string", rq.EnumType("a", ["a"]), False), + NamedCase("data", "is_string", rq.IntervalType(), False), + NamedCase("data", "is_string", rq.InetType(), False), + NamedCase("data", "is_string", rq.MacAddressType(), False), + NamedCase("data", "is_string", rq.CidrType(), False), + NamedCase("data", "is_string", rq.CharType(), False), + NamedCase(b"data", "is_bytes", None, False), + NamedCase(b"data", "is_bytes", rq.BitType(), False), + NamedCase({"name": "rq"}, "is_json", None, False), + NamedCase([], "is_json", None, False), + NamedCase(6, "is_json", rq.JsonBinaryType(), False), + NamedCase("data", "is_json", rq.JsonBinaryType(), False), + NamedCase(4.5, "is_json", rq.JsonBinaryType(), False), + NamedCase({1: "rq"}, "is_json", None, True), + NamedCase({1: "rq"}, "is_json", rq.JsonBinaryType(), True), + NamedCase([1, 2, 3], "is_array", rq.ArrayType(rq.TinyIntegerType()), False), + NamedCase([3, "b"], "is_array", rq.ArrayType(rq.TinyIntegerType()), True), + NamedCase(datetime.now(), "is_datetime", None, False), + NamedCase(datetime.now(tz=timezone.utc), "is_datetime", None, False), + NamedCase(datetime.now(), "is_datetime", rq.DateTimeType(), False), + NamedCase(datetime.now(tz=timezone.utc), "is_datetime", rq.TimestampType(), False), + NamedCase(datetime.now(tz=timezone.utc), "is_datetime", rq.TimestampWithTimeZoneType(), False), + NamedCase(datetime.now(), "is_datetime", rq.TimestampWithTimeZoneType(), False), + NamedCase(uuid.uuid4(), "is_uuid", None, False), + NamedCase(uuid.uuid4(), "is_uuid", rq.UuidType(), False), + NamedCase(uuid.uuid4().hex, "is_uuid", rq.UuidType(), True), + NamedCase(decimal.Decimal("1.2"), 
"is_decimal", None, False), + NamedCase(decimal.Decimal("1.2"), "is_decimal", rq.DecimalType(), False), + NamedCase(decimal.Decimal("1.2"), "is_decimal", rq.FloatType(), True), + NamedCase(1.2, "is_decimal", rq.DecimalType(), True), + NamedCase([1.3, 2.1, 3], "is_vector", rq.VectorType(), False), + NamedCase([3, "b"], "is_vector", rq.VectorType(), True), +] + + +@pytest.mark.parametrize("case", TEST_CASES) +def test_adaptedvalue(case: NamedCase): + try: + val = rq.AdaptedValue(case.data, case.type) + except (ValueError, TypeError, OverflowError): + if case.error: + return + + raise + + assert getattr(val, case.attribute) + + rq.Expr(val) # Force AdaptedValue to adapt diff --git a/tests/test_column.py b/tests/test_column.py index 6dcdb4c..bf25de0 100644 --- a/tests/test_column.py +++ b/tests/test_column.py @@ -1,8 +1,9 @@ -from rapidquery import _lib import dataclasses import typing import pytest +import rapidquery as rq + @dataclasses.dataclass class ColumnTestCase: @@ -16,88 +17,88 @@ class ColumnTestCase: comment: str | None = None default_expr: str = "" stored_generated: bool = False - column_ref: typing.Optional[_lib.ColumnRef] = None + column_ref: typing.Optional[rq.ColumnRef] = None def test_different_types(): # Simple - ty = _lib.IntegerType() - assert ty == _lib.IntegerType() + ty = rq.IntegerType() + assert ty == rq.IntegerType() assert repr(ty) == "" # Length - ty = _lib.StringType(None) - assert ty == _lib.StringType() + ty = rq.StringType(None) + assert ty == rq.StringType() assert ty.length is None assert repr(ty) == "" - ty = _lib.StringType(20) - assert ty != _lib.StringType(30) - assert ty != _lib.StringType(None) - assert ty == _lib.StringType(20) + ty = rq.StringType(20) + assert ty != rq.StringType(30) + assert ty != rq.StringType(None) + assert ty == rq.StringType(20) assert ty.length == 20 assert repr(ty) == "" # Percision Scale - ty = _lib.MoneyType() - assert ty == _lib.MoneyType() + ty = rq.MoneyType() + assert ty == rq.MoneyType() assert ty.precision_scale is None assert repr(ty) == "" - ty = _lib.MoneyType((10, 8)) - assert ty != _lib.MoneyType((4, 6)) - assert ty != _lib.MoneyType(None) - assert ty == _lib.MoneyType((10, 8)) + ty = rq.MoneyType((10, 8)) + assert ty != rq.MoneyType((4, 6)) + assert ty != rq.MoneyType(None) + assert ty == rq.MoneyType((10, 8)) assert ty.precision_scale == (10, 8) assert repr(ty) == "" # Enum - ty = _lib.EnumType("priority", ["low", "medium"]) + ty = rq.EnumType("priority", ["low", "medium"]) assert ty.name == "priority" assert ty.variants == ["low", "medium"] - assert ty == _lib.EnumType("priority", ["low", "medium"]) - assert ty != _lib.EnumType("priority", ["low", "medium", "high"]) + assert ty == rq.EnumType("priority", ["low", "medium"]) + assert ty != rq.EnumType("priority", ["low", "medium", "high"]) # Array try: - ty = _lib.ArrayType(str) + ty = rq.ArrayType(str) except Exception: pass else: pytest.fail() - ty = _lib.ArrayType(_lib.TextType()) - assert ty.element == _lib.TextType() + ty = rq.ArrayType(rq.TextType()) + assert ty.element == rq.TextType() # Interval try: - ty = _lib.IntervalType(5983) + ty = rq.IntervalType(5983) except Exception: pass else: pytest.fail() - ty = _lib.IntervalType(_lib.INTERVAL_DAY_TO_MINUTE) - assert ty.fields == _lib.INTERVAL_DAY_TO_MINUTE + ty = rq.IntervalType(rq.INTERVAL_DAY_TO_MINUTE) + assert ty.fields == rq.INTERVAL_DAY_TO_MINUTE assert ty.precision is None - ty = _lib.IntervalType(_lib.INTERVAL_HOUR, 5) - assert ty.fields == _lib.INTERVAL_HOUR + ty = rq.IntervalType(rq.INTERVAL_HOUR, 5) 
+ assert ty.fields == rq.INTERVAL_HOUR assert ty.precision == 5 -_metadata_column = _lib.Column( - "metadata", _lib.ArrayType(_lib.IntegerType()), nullable=True, default=[1, 2, 3] +_metadata_column = rq.Column( + "metadata", rq.ArrayType(rq.IntegerType()), nullable=True, default=[1, 2, 3] ) -_lib.Table("users", [_metadata_column]) +rq.Table("users", [_metadata_column]) columndata = [ ( - _lib.Column( + rq.Column( "id", - _lib.BigIntegerType(), + rq.BigIntegerType(), primary_key=True, nullable=False, auto_increment=True, @@ -105,29 +106,29 @@ def test_different_types(): ), ColumnTestCase( "id", - _lib.BigIntegerType, + rq.BigIntegerType, primary_key=True, nullable=False, auto_increment=True, default_expr="1", - column_ref=_lib.ColumnRef("id"), + column_ref=rq.ColumnRef("id"), ), ), ( _metadata_column, ColumnTestCase( "metadata", - _lib.ArrayType, + rq.ArrayType, nullable=True, default_expr="ARRAY [1,2,3]", - column_ref=_lib.ColumnRef("metadata", table="users"), + column_ref=rq.ColumnRef("metadata", table="users"), ), ), ] @pytest.mark.parametrize("val,case", columndata) -def test_column(val: _lib.Column, case: ColumnTestCase): +def test_column(val: rq.Column, case: ColumnTestCase): assert val.name == case.name assert val.primary_key == case.primary_key assert val.unique == case.unique diff --git a/tests/test_edgecases.py b/tests/test_edgecases.py index c19f7d9..cd62be9 100644 --- a/tests/test_edgecases.py +++ b/tests/test_edgecases.py @@ -541,3 +541,46 @@ def test_interval_with_negative_precision(self): """Interval with negative precision (invalid).""" with pytest.raises(OverflowError): _lib.IntervalType(precision=-1) + + +class TestWhereChaining: + def test_select(self): + query = ( + _lib.Select(_lib.ASTERISK) + .from_table("users") + .where(_lib.Expr.col("id") > 10) + .where(_lib.Expr.col("id") < 20) + ) + + assert "AND" in query.to_sql("postgresql") + + def test_delete(self): + query = ( + _lib.Delete() + .from_table("users") + .where(_lib.Expr.col("id") > 10) + .where(_lib.Expr.col("id") < 20) + ) + + assert "AND" in query.to_sql("postgresql") + + def test_update(self): + query = ( + _lib.Update() + .table("users") + .where(_lib.Expr.col("id") > 10) + .where(_lib.Expr.col("id") < 20) + ) + + assert "AND" in query.to_sql("postgresql") + + +class TestCase: + def test_to_expr(self): + query = ( + _lib.Case() + .when(_lib.Case().when(_lib.Expr.col("id") == 1, True).else_(False), True) + .else_(False) + ) + + assert query.to_expr().to_sql("postgresql").count("CASE") == 2 diff --git a/tests/test_expression.py b/tests/test_expression.py index fc268f4..6727ac2 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -1,59 +1,142 @@ -from rapidquery import _lib +from dataclasses import dataclass import pytest +import typing +import rapidquery as rq -exprdata = [ - (_lib.Expr(3) == 3, "3 = 3", "postgres"), - ( - _lib.Expr.col("name").cast_as("VARCHAR(1000)").cast_as("hierarchy_path"), + +@dataclass +class SQLCase: + expr: rq.Expr + expected: str + backend: str + + +@dataclass +class DifferentInputCase: + value: typing.Any + sqlcontain: str + error: bool + + +sqlcases = [ + SQLCase(rq.Expr(3) == 3, "3 = 3", "postgres"), + SQLCase( + rq.Expr.col("name").cast_as("VARCHAR(1000)").cast_as("hierarchy_path"), 'CAST(CAST("name" AS VARCHAR(1000)) AS hierarchy_path)', "postgres", ), - ( - (_lib.Expr.col("oh.level") + 1).between(24, 26), + SQLCase( + (rq.Expr.col("oh.level") + 1).between(24, 26), '"oh"."level" + 1 BETWEEN 24 AND 26', "postgres", ), - ( - (_lib.Expr.col("oh.level") + 
1).between(24, 26), + SQLCase( + (rq.Expr.col("oh.level") + 1).between(24, 26), '"oh"."level" + 1 BETWEEN 24 AND 26', "postgres", ), - ( - _lib.FunctionCall.max(_lib.Expr(_lib.ColumnRef("id"))).to_expr() == 9, + SQLCase( + rq.FunctionCall.max(rq.Expr(rq.ColumnRef("id"))).to_expr() == 9, 'MAX("id") = 9', "postgres", ), - ( - _lib.all(_lib.Expr(_lib.ASTERISK).is_null(), _lib.Expr(None).is_null()), + SQLCase( + rq.all(rq.Expr(rq.ASTERISK).is_null(), rq.Expr(None).is_null()), "* IS NULL AND NULL IS NULL", "postgres", ), - ( - _lib.any(_lib.Expr.current_date(), _lib.Expr.current_time()), + SQLCase( + rq.any(rq.Expr.current_date(), rq.Expr.current_time()), "CURRENT_DATE OR CURRENT_TIME", "postgres", ), - ( - _lib.not_(_lib.FunctionCall.count(_lib.Expr(_lib.ASTERISK)).to_expr() == 1), + SQLCase( + rq.not_(rq.FunctionCall.count(rq.Expr(rq.ASTERISK)).to_expr() == 1), "NOT COUNT(*) = 1", "postgres", ), ] -@pytest.mark.parametrize("val,expected,backend", exprdata) -def test_expr_build(val: _lib.Expr, expected: str, backend: str): - expr = val.to_sql(backend) - assert expr == expected +@pytest.mark.parametrize("case", sqlcases) +def test_expr_build(case: SQLCase): + expr = case.expr.to_sql(case.backend) + assert expr == case.expected -class Unknown: - pass +inputcases = [ + DifferentInputCase( + rq.Expr.custom("CUSTOM"), + "CUSTOM", + False, + ), + DifferentInputCase( + rq.AdaptedValue(1), + "1", + False, + ), + DifferentInputCase( + rq.ColumnRef("id"), + '"id"', + False, + ), + DifferentInputCase( + rq.Column("id", rq.IntegerType()), + '"id"', + False, + ), + DifferentInputCase( + (1, "rapidquery", 3), + "(1, 'rapidquery', 3)", + False, + ), + DifferentInputCase( + rq.ASTERISK, + "*", + False, + ), + DifferentInputCase( + rq.Select(1), + "SELECT", + False, + ), + DifferentInputCase( + rq.Case().when(rq.Expr.col("id") == 1, True), + "CASE WHEN", + False, + ), + DifferentInputCase( + rq.FunctionCall.avg(rq.Expr.asterisk()), + "AVG(*)", + False, + ), + DifferentInputCase( + rq.IntegerType(), + "", + True, + ), +] + + +@pytest.mark.parametrize("case", inputcases) +def test_input_value(case: DifferentInputCase): + try: + expr = rq.Expr(case.value) + except (TypeError, ValueError, OverflowError): + if case.error: + return + + raise + + assert expr.to_sql("postgresql").find(case.sqlcontain) > -1 def test_invalid_expr(): + class Unknown: + pass + try: - _lib.Expr(Unknown()) + rq.Expr(Unknown()) except ValueError: pass
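The window-function plumbing added above (`src/query/window.rs`, the `window=` argument on `SelectCol`, `Select.window()`, and the new `FunctionCall.rank()` / `dense_rank()` / `percent_rank()` helpers) is not exercised by the tests in this diff, so the following is only a usage sketch pieced together from the new signatures. The `"desc"` order keyword, the table and column names, and the SQL described in comments are assumptions, not verified output.

```python
import rapidquery as rq

# A named window: partition by "department", order by "salary" (descending),
# with a ROWS frame from UNBOUNDED PRECEDING to CURRENT ROW.
w = (
    rq.Window(rq.Expr.col("department"))
    .order_by(rq.Expr.col("salary"), "desc")  # order keyword assumed; parsed by OrderClause::from_parameters
    .frame(
        "rows",
        rq.WindowFrame.unbounded_preceding(),
        rq.WindowFrame.current_row(),
    )
)

query = (
    rq.Select(
        rq.Expr.col("name"),
        # RANK() OVER "w": the window can be passed inline as a Window object,
        # or referenced by the name registered with .window() below.
        rq.SelectCol(rq.FunctionCall.rank(), alias="salary_rank", window="w"),
    )
    .from_table("employees")
    .window("w", w)
    .where(rq.Expr.col("salary") > 0)
    .where(rq.Expr.col("tenure_years") >= 1)  # chained .where() calls are ANDed, per TestWhereChaining
)

sql, params = query.build("postgresql")
```

Passing a `Window` object directly as `window=` skips the named-window registration; both paths end up as a `WindowSelectType` in `PySelectCol::as_statement`.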
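The new `Case` builder and the `Expr.__neg__` operator do appear in the tests (`TestCase`, `test_input_value`), but only in isolation; here is a hedged sketch of how they might combine inside a query. The table and column names are made up, and the inline SQL comment is an approximation of what sea_query renders (unary minus is implemented as multiplication by `-1`).

```python
import rapidquery as rq

# Roughly: CASE WHEN "status" = $1 THEN "balance" * -1 ELSE "balance" END
effective_balance = (
    rq.Case()
    .when(rq.Expr.col("status") == "frozen", -rq.Expr.col("balance"))
    .else_(rq.Expr.col("balance"))
)

query = (
    rq.Select(
        rq.Expr.col("id"),
        # A Case can be wrapped explicitly with .to_expr(), or passed straight
        # to Expr(...) as the test_input_value cases do.
        rq.SelectCol(effective_balance.to_expr(), alias="effective_balance"),
    )
    .from_table("accounts")
)

sql, params = query.build("postgresql")
```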