|
| 1 | +from http import HTTPStatus |
| 2 | + |
| 3 | +from fastapi import HTTPException |
| 4 | +from pydantic import BaseModel, ConfigDict |
1 | 5 | from pytest import fixture |
2 | 6 | from sqlalchemy import select, text |
| 7 | +from sqlalchemy.exc import IntegrityError |
| 8 | +from sqlalchemy.orm import Mapped, mapped_column |
| 9 | + |
| 10 | + |
| 11 | +@fixture(autouse=True) |
| 12 | +async def setup_tear_down(engine): |
| 13 | + async with engine.connect() as conn: |
| 14 | + await conn.execute( |
| 15 | + text(""" |
| 16 | + CREATE TABLE user ( |
| 17 | + id INTEGER PRIMARY KEY AUTOINCREMENT, |
| 18 | + email TEXT UNIQUE NOT NULL, |
| 19 | + name TEXT NOT NULL |
| 20 | + ) |
| 21 | + """) |
| 22 | + ) |
3 | 23 |
|
4 | 24 |
|
5 | 25 | @fixture |
6 | | -def app(app): |
7 | | - from fastapi_async_sqla import Session |
| 26 | +def app(setup_tear_down, app): |
| 27 | + from fastapi_async_sqla import Base, Item, Session |
| 28 | + |
| 29 | + class User(Base): |
| 30 | + __tablename__ = "user" |
| 31 | + id: Mapped[int] = mapped_column(primary_key=True) |
| 32 | + email: Mapped[str] = mapped_column(unique=True) |
| 33 | + name: Mapped[str] |
| 34 | + |
| 35 | + class UserIn(BaseModel): |
| 36 | + email: str |
| 37 | + name: str |
| 38 | + |
| 39 | + class UserModel(UserIn): |
| 40 | + model_config = ConfigDict(from_attributes=True) |
| 41 | + id: int |
8 | 42 |
|
9 | 43 | @app.get("/session-dependency") |
10 | 44 | async def get_session(session: Session): |
11 | 45 | res = await session.execute(select(text("'OK'"))) |
12 | 46 | return {"data": res.scalar()} |
13 | 47 |
|
| 48 | + @app.post("/users", response_model=Item[UserModel], status_code=HTTPStatus.CREATED) |
| 49 | + async def create_user(user_in: UserIn, session: Session): |
| 50 | + user = User(**user_in.model_dump()) |
| 51 | + user_in.model_dump |
| 52 | + session.add(user) |
| 53 | + try: |
| 54 | + await session.flush() |
| 55 | + except IntegrityError: |
| 56 | + raise HTTPException(status_code=400) |
| 57 | + return {"data": user} |
| 58 | + |
14 | 59 | return app |
15 | 60 |
|
16 | 61 |
|
17 | 62 | async def test_it(client): |
18 | | - response = await client.get("/session-dependency") |
19 | | - assert response.status_code == 200 |
20 | | - assert response.json() == {"data": "OK"} |
| 63 | + res = await client.get("/session-dependency") |
| 64 | + assert res.status_code == HTTPStatus.OK, (res.status_code, res.content) |
| 65 | + assert res.json() == {"data": "OK"} |
| 66 | + |
| 67 | + |
| 68 | +async def test_session_is_commited(client, session): |
| 69 | + payload = { "email": "[email protected]", "name": "Bobby"} |
| 70 | + res = await client.post("/users", json=payload) |
| 71 | + |
| 72 | + assert res.status_code == HTTPStatus.CREATED, (res.status_code, res.content) |
| 73 | + |
| 74 | + all_users = (await session.execute(text("SELECT * FROM user"))).mappings().all() |
| 75 | + assert all_users == [{"id": 1, **payload}] |
| 76 | + |
| 77 | + |
| 78 | +@fixture |
| 79 | +async def bob_exists(session): |
| 80 | + await session.execute( |
| 81 | + text( "INSERT INTO user (email, name) VALUES ('[email protected]', 'Bobby')") |
| 82 | + ) |
| 83 | + await session.commit() |
| 84 | + yield |
| 85 | + |
| 86 | + |
| 87 | +async def test_with_an_integrity_error(client, bob_exists): |
| 88 | + res = await client. post( "/users", json={ "email": "[email protected]", "name": "Bobby"}) |
| 89 | + assert res.status_code == HTTPStatus.BAD_REQUEST, (res.status_code, res.content) |
0 commit comments