Skip to content
This repository was archived by the owner on Apr 15, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion databases/sync_tests/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from prisma import Prisma
from prisma.models import User, Profile

from ..utils import CURRENT_DATABASE
from ..utils import CURRENT_DATABASE, ISOLATION_LEVELS_MAPPING, RawQueries


def test_model_query(client: Prisma) -> None:
Expand Down Expand Up @@ -201,3 +201,66 @@ def test_transaction_already_closed(client: Prisma) -> None:
transaction.user.delete_many()

assert exc.match('Transaction already closed')


@pytest.mark.parametrize(
('input_level',),
[
pytest.param(
'READ_UNCOMMITTED',
id='read uncommitted',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
'READ_COMMITTED',
id='read committed',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
'REPEATABLE_READ',
id='repeatable read',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
'SNAPSHOT',
id='snapshot',
marks=pytest.mark.skipif(CURRENT_DATABASE != 'sqlserver', reason='Not available'),
),
pytest.param(
'SERIALIZABLE',
id='serializable',
marks=pytest.mark.skipif(
CURRENT_DATABASE == 'sqlite',
reason="SQLite doesn't have the way to query the current transaction isolation level",
),
),
],
)
@pytest.mark.skipif(CURRENT_DATABASE == 'mongodb', reason='Not available')
@pytest.mark.skipif(
CURRENT_DATABASE in ['mysql', 'mariadb'],
reason="""
MySQL 8.0 doesn't have the way to query the current transaction isolation level.
See https://bugs.mysql.com/bug.php?id=53341

Refs:
* https://github.com/prisma/prisma/issues/22890
""",
)
def test_isolation_level(
client: Prisma,
database: str,
raw_queries: RawQueries,
input_level: str,
) -> None:
"""Ensure that transaction isolation level is set correctly"""
with client.tx(isolation_level=getattr(prisma.TransactionIsolationLevel, input_level)) as tx:
results = tx.query_raw(raw_queries.select_tx_isolation)

assert len(results) == 1

row = results[0]
assert any(row)

level = next(iter(row.values()))
assert level == ISOLATION_LEVELS_MAPPING[database][input_level]
66 changes: 65 additions & 1 deletion databases/tests/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from prisma import Prisma
from prisma.models import User, Profile

from ..utils import CURRENT_DATABASE
from ..utils import CURRENT_DATABASE, ISOLATION_LEVELS_MAPPING, RawQueries


@pytest.mark.asyncio
Expand Down Expand Up @@ -212,3 +212,67 @@ async def test_transaction_already_closed(client: Prisma) -> None:
await transaction.user.delete_many()

assert exc.match('Transaction already closed')


@pytest.mark.asyncio
@pytest.mark.parametrize(
('input_level',),
[
pytest.param(
'READ_UNCOMMITTED',
id='read uncommitted',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
'READ_COMMITTED',
id='read committed',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
'REPEATABLE_READ',
id='repeatable read',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
'SNAPSHOT',
id='snapshot',
marks=pytest.mark.skipif(CURRENT_DATABASE != 'sqlserver', reason='Not available'),
),
pytest.param(
'SERIALIZABLE',
id='serializable',
marks=pytest.mark.skipif(
CURRENT_DATABASE == 'sqlite',
reason="SQLite doesn't have the way to query the current transaction isolation level",
),
),
],
)
@pytest.mark.skipif(CURRENT_DATABASE == 'mongodb', reason='Not available')
@pytest.mark.skipif(
CURRENT_DATABASE in ['mysql', 'mariadb'],
reason="""
MySQL 8.0 doesn't have the way to query the current transaction isolation level.
See https://bugs.mysql.com/bug.php?id=53341

Refs:
* https://github.com/prisma/prisma/issues/22890
""",
)
async def test_isolation_level(
client: Prisma,
database: str,
raw_queries: RawQueries,
input_level: str,
) -> None:
"""Ensure that transaction isolation level is set correctly"""
async with client.tx(isolation_level=getattr(prisma.TransactionIsolationLevel, input_level)) as tx:
results = await tx.query_raw(raw_queries.select_tx_isolation)

assert len(results) == 1

row = results[0]
assert any(row)

level = next(iter(row.values()))
assert level == ISOLATION_LEVELS_MAPPING[database][input_level]
63 changes: 61 additions & 2 deletions databases/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import os
from typing import Set
from typing import Set, Optional
from pathlib import Path
from typing_extensions import Literal, get_args, override
from typing_extensions import Literal, TypedDict, get_args, override

from pydantic import BaseModel
from syrupy.location import PyTestLocation
Expand Down Expand Up @@ -86,6 +86,8 @@ class RawQueries(BaseModel):
test_query_raw_no_result: LiteralString
test_execute_raw_no_result: LiteralString

select_tx_isolation: LiteralString


_mysql_queries = RawQueries(
count_posts="""
Expand Down Expand Up @@ -137,8 +139,12 @@ class RawQueries(BaseModel):
SET title = 'updated title'
WHERE id = 'sdldsd'
""",
select_tx_isolation="""
SELECT @@transaction_isolation
""",
)


_postgresql_queries = RawQueries(
count_posts="""
SELECT COUNT(*) as count
Expand Down Expand Up @@ -189,6 +195,9 @@ class RawQueries(BaseModel):
SET title = 'updated title'
WHERE id = 'sdldsd'
""",
select_tx_isolation="""
SHOW transaction_isolation
""",
)

RAW_QUERIES_MAPPING: DatabaseMapping[RawQueries] = {
Expand Down Expand Up @@ -246,5 +255,55 @@ class RawQueries(BaseModel):
SET title = 'updated title'
WHERE id = 'sdldsd'
""",
select_tx_isolation="""
Not avaliable
""",
),
}


class IsolationLevels(TypedDict):
READ_UNCOMMITTED: Optional[LiteralString]
READ_COMMITTED: Optional[LiteralString]
REPEATABLE_READ: Optional[LiteralString]
SNAPSHOT: Optional[LiteralString]
SERIALIZABLE: Optional[LiteralString]


ISOLATION_LEVELS_MAPPING: DatabaseMapping[IsolationLevels] = {
'postgresql': {
'READ_UNCOMMITTED': 'read uncommitted',
'READ_COMMITTED': 'read committed',
'REPEATABLE_READ': 'repeatable read',
'SNAPSHOT': None,
'SERIALIZABLE': 'serializable',
},
'cockroachdb': {
'READ_UNCOMMITTED': None,
'READ_COMMITTED': None,
'REPEATABLE_READ': None,
'SNAPSHOT': None,
'SERIALIZABLE': 'SERIALIZABLE',
},
'mysql': {
'READ_UNCOMMITTED': 'READ-UNCOMMITTED',
'READ_COMMITTED': 'READ-COMMITTED',
'REPEATABLE_READ': 'REPEATABLE-READ',
'SNAPSHOT': None,
'SERIALIZABLE': 'SERIALIZABLE',
},
'mariadb': {
'READ_UNCOMMITTED': 'READ-UNCOMMITTED',
'READ_COMMITTED': 'READ-COMMITTED',
'REPEATABLE_READ': 'REPEATABLE-READ',
'SNAPSHOT': None,
'SERIALIZABLE': 'SERIALIZABLE',
},
'sqlite': {
'READ_UNCOMMITTED': None,
'READ_COMMITTED': None,
'REPEATABLE_READ': None,
'SNAPSHOT': None,
'SERIALIZABLE': None,
},
}
17 changes: 17 additions & 0 deletions docs/reference/transactions.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,23 @@ In the case that this example runs successfully, then both database writes are c
)
```

## Isolation levels

By default, Prisma sets the isolation level to the value currently configured in the database. You can modify this
default with the `isolation_level` argument (see [supported isolation levels](https://www.prisma.io/docs/orm/prisma-client/queries/transactions#supported-isolation-levels)).

!!! note
Prisma Client Python generates `TransactionIsolationLevel` enumeration that includes only the options supported by the current database.

```py
from prisma import Prisma, TransactionIsolationLevel

client = Prisma()
client.tx(
isolation_level=TransactionIsolationLevel.READ_UNCOMMITTED,
)
```

## Timeouts

You can pass the following options to configure how timeouts are applied to your transaction:
Expand Down
51 changes: 32 additions & 19 deletions src/prisma/_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,31 @@
import logging
import warnings
from types import TracebackType
from typing import TYPE_CHECKING, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from datetime import timedelta

from ._types import TransactionId
from .errors import TransactionNotStartedError
from ._compat import StrEnum
from ._builder import dumps

if TYPE_CHECKING:
from ._base_client import SyncBasePrisma, AsyncBasePrisma

log: logging.Logger = logging.getLogger(__name__)

__all__ = (
'AsyncTransactionManager',
'SyncTransactionManager',
)


_SyncPrismaT = TypeVar('_SyncPrismaT', bound='SyncBasePrisma')
_AsyncPrismaT = TypeVar('_AsyncPrismaT', bound='AsyncBasePrisma')
_IsolationLevelT = TypeVar('_IsolationLevelT', bound=StrEnum)


class AsyncTransactionManager(Generic[_AsyncPrismaT]):
class AsyncTransactionManager(Generic[_AsyncPrismaT, _IsolationLevelT]):
"""Context manager for wrapping a Prisma instance within a transaction.

This should never be created manually, instead it should be used
Expand All @@ -33,8 +40,10 @@ def __init__(
client: _AsyncPrismaT,
max_wait: int | timedelta,
timeout: int | timedelta,
isolation_level: _IsolationLevelT | None,
) -> None:
self.__client = client
self._isolation_level = isolation_level

if isinstance(max_wait, int):
message = (
Expand Down Expand Up @@ -71,14 +80,15 @@ async def start(self, *, _from_context: bool = False) -> _AsyncPrismaT:
stacklevel=3 if _from_context else 2,
)

tx_id = await self.__client._engine.start_transaction(
content=dumps(
{
'timeout': int(self._timeout.total_seconds() * 1000),
'max_wait': int(self._max_wait.total_seconds() * 1000),
}
),
)
content_dict: dict[str, Any] = {
'timeout': int(self._timeout.total_seconds() * 1000),
'max_wait': int(self._max_wait.total_seconds() * 1000),
}
if self._isolation_level is not None:
content_dict['isolation_level'] = self._isolation_level.value

tx_id = await self.__client._engine.start_transaction(content=dumps(content_dict))

self._tx_id = tx_id
client = self.__client._copy()
client._tx_id = tx_id
Expand Down Expand Up @@ -122,7 +132,7 @@ async def __aexit__(
)


class SyncTransactionManager(Generic[_SyncPrismaT]):
class SyncTransactionManager(Generic[_SyncPrismaT, _IsolationLevelT]):
"""Context manager for wrapping a Prisma instance within a transaction.

This should never be created manually, instead it should be used
Expand All @@ -135,8 +145,10 @@ def __init__(
client: _SyncPrismaT,
max_wait: int | timedelta,
timeout: int | timedelta,
isolation_level: _IsolationLevelT | None,
) -> None:
self.__client = client
self._isolation_level = isolation_level

if isinstance(max_wait, int):
message = (
Expand Down Expand Up @@ -173,14 +185,15 @@ def start(self, *, _from_context: bool = False) -> _SyncPrismaT:
stacklevel=3 if _from_context else 2,
)

tx_id = self.__client._engine.start_transaction(
content=dumps(
{
'timeout': int(self._timeout.total_seconds() * 1000),
'max_wait': int(self._max_wait.total_seconds() * 1000),
}
),
)
content_dict: dict[str, Any] = {
'timeout': int(self._timeout.total_seconds() * 1000),
'max_wait': int(self._max_wait.total_seconds() * 1000),
}
if self._isolation_level is not None:
content_dict['isolation_level'] = self._isolation_level.value

tx_id = self.__client._engine.start_transaction(content=dumps(content_dict))

self._tx_id = tx_id
client = self.__client._copy()
client._tx_id = tx_id
Expand Down
Loading