diff --git a/api/requirements.txt b/api/requirements.txt index dba2e1357..fd9dc636f 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -1,7 +1,7 @@ # Runtime requirements flask flask_cors -sqlalchemy>=1.4.9,<2.0 +sqlalchemy~=2.0.0 gunicorn rocky>=1,<2 PyMySQL diff --git a/api/src/core/auth.py b/api/src/core/auth.py index a04fcdf2f..d1f3a1234 100644 --- a/api/src/core/auth.py +++ b/api/src/core/auth.py @@ -131,9 +131,8 @@ def password_reset(reset_token, unhashed_password): except ValueError as e: raise BadRequest(str(e)) - try: - member = db_session.query(Member).get(password_reset_token.member_id) - except NoResultFound: + member = db_session.get(Member, password_reset_token.member_id) + if member is None: raise InternalServerError(log=f"No member with id {password_reset_token.member_id} found, this is a bug.") member.password = hashed_password diff --git a/api/src/core/models.py b/api/src/core/models.py index cfb2cfc3c..dd7e7d3ca 100644 --- a/api/src/core/models.py +++ b/api/src/core/models.py @@ -1,7 +1,6 @@ from service.db import db_session from sqlalchemy import Column, DateTime, Integer, String, Text, func, text -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import configure_mappers +from sqlalchemy.orm import configure_mappers, declarative_base Base = declarative_base() @@ -46,19 +45,21 @@ class Login: @staticmethod def register_login_failed(ip): - db_session.execute("INSERT INTO login (success, ip) VALUES (0, :ip)", {"ip": ip}) + db_session.execute(text("INSERT INTO login (success, ip) VALUES (0, :ip)"), {"ip": ip}) @staticmethod def register_login_success(ip, user_id): db_session.execute( - "INSERT INTO login (success, user_id, ip) VALUES (1, :user_id, :ip)", {"user_id": user_id, "ip": ip} + text("INSERT INTO login (success, user_id, ip) VALUES (1, :user_id, :ip)"), {"user_id": user_id, "ip": ip} ) @staticmethod def get_failed_login_count(ip): (count,) = db_session.execute( - "SELECT count(1) FROM login" - " WHERE ip = :ip AND NOT success AND date >= DATE_SUB(NOW(), INTERVAL 1 HOUR)", + text( + "SELECT count(1) FROM login" + " WHERE ip = :ip AND NOT success AND date >= DATE_SUB(NOW(), INTERVAL 1 HOUR)" + ), {"ip": ip}, ).fetchone() return count diff --git a/api/src/firstrun.py b/api/src/firstrun.py index b67e57c53..ab71bc881 100644 --- a/api/src/firstrun.py +++ b/api/src/firstrun.py @@ -134,7 +134,7 @@ def get_password(): member_id = member["member_id"] logger.info(f"Adding new member {member_id} to admin group.") - admins.members.append(db_session.query(Member).get(member_id)) + admins.members.append(db_session.get(Member, member_id)) db_session.commit() break except Exception as e: diff --git a/api/src/init_db.py b/api/src/init_db.py index f8fd5ab5a..05ddce1a8 100755 --- a/api/src/init_db.py +++ b/api/src/init_db.py @@ -11,6 +11,7 @@ from rocky.process import log_exception from service.config import get_mysql_config from service.db import create_mysql_engine +from sqlalchemy import text from sqlalchemy.orm import sessionmaker from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound @@ -18,7 +19,7 @@ def clear_permission_cache(session_factory): """Clear permisssion cache as a part of every db_init/restart.""" with closing(session_factory()) as session: - session.execute("UPDATE access_tokens SET permissions = NULL") + session.execute(text("UPDATE access_tokens SET permissions = NULL")) session.commit() diff --git a/api/src/member/member.py b/api/src/member/member.py index c6382c71b..7ee42fa84 100644 --- a/api/src/member/member.py +++ b/api/src/member/member.py @@ -37,7 +37,7 @@ def send_access_token_email(redirect, user_identification, ip, browser): def send_updated_member_info_email(member_id: int, msg_swe: str, msg_en: str): - member = db_session.query(Member).get(member_id) + member = db_session.get(Member, member_id) logger.info( f"sending email about updated personal information to member_id {member.member_id} with message {msg_en=}," @@ -56,7 +56,7 @@ def send_updated_member_info_email(member_id: int, msg_swe: str, msg_en: str): def set_pin_code(member_id: int, pin_code: str): - member: Member = db_session.query(Member).get(member_id) + member: Member = db_session.get(Member, member_id) member.pin_code = pin_code try: diff --git a/api/src/member/views.py b/api/src/member/views.py index 98339778d..ad80d5dfb 100644 --- a/api/src/member/views.py +++ b/api/src/member/views.py @@ -40,7 +40,7 @@ def current_member(): # Expose if the member has a password set, but not what the password is (not even the hash) assert m is not None - m2 = db_session.query(Member).get(g.user_id) + m2 = db_session.get(Member, g.user_id) assert m2 is not None m["has_password"] = m2.password is not None diff --git a/api/src/membership/member_entity.py b/api/src/membership/member_entity.py index 85d410a74..3b62faebb 100644 --- a/api/src/membership/member_entity.py +++ b/api/src/membership/member_entity.py @@ -5,6 +5,7 @@ from pymysql.constants.ER import DUP_ENTRY from service.db import db_session from service.entity import Entity +from sqlalchemy import text from sqlalchemy.exc import IntegrityError from membership.member_auth import check_and_hash_password @@ -46,7 +47,7 @@ def create(self, data=None, commit=True): with db_session.begin_nested(): data = data.copy() (max_member_number,) = db_session.execute( - "SELECT COALESCE(MAX(member_number), 999) FROM membership_members" + text("SELECT COALESCE(MAX(member_number), 999) FROM membership_members") ).fetchone() data["member_number"] = max_member_number + offset # We must not commit here, because that will end our transaction, and the to_obj call will fail diff --git a/api/src/membership/models.py b/api/src/membership/models.py index 476f9d330..2a01c1035 100644 --- a/api/src/membership/models.py +++ b/api/src/membership/models.py @@ -17,10 +17,8 @@ Text, func, select, - text, ) -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import column_property, configure_mappers, relationship, validates +from sqlalchemy.orm import column_property, configure_mappers, declarative_base, relationship, validates Base = declarative_base() @@ -73,7 +71,7 @@ class Member(Base): def validate_phone(self, key: Any, value: Optional[str]) -> Optional[str]: return normalise_phone_number(value) - groups = relationship("Group", secondary=member_group, back_populates="members") + groups = relationship("Group", secondary=member_group, back_populates="members", cascade_backrefs=False) def __repr__(self) -> str: return f"Member(member_id={self.member_id}, member_number={self.member_number}, email={self.email})" @@ -99,9 +97,13 @@ class Group(Base): updated_at = Column(DateTime, server_default=func.now()) deleted_at = Column(DateTime) - members = relationship("Member", secondary=member_group, lazy="dynamic", back_populates="groups") + members = relationship( + "Member", secondary=member_group, lazy="dynamic", back_populates="groups", cascade_backrefs=False + ) - permissions = relationship("Permission", secondary=group_permission, back_populates="groups") + permissions = relationship( + "Permission", secondary=group_permission, back_populates="groups", cascade_backrefs=False + ) def __repr__(self) -> str: return f"Group(group_id={self.group_id}, name={self.name})" @@ -110,7 +112,7 @@ def __repr__(self) -> str: # Calculated property will be executed as a sub select for each groups, since it is not that many groups this will be # fine. Group.num_members = column_property( - select([func.count(member_group.columns.member_id)]) + select(func.count(member_group.columns.member_id)) .where(Group.group_id == member_group.columns.group_id) .scalar_subquery() ) @@ -125,7 +127,7 @@ class Permission(Base): updated_at = Column(DateTime, server_default=func.now()) deleted_at = Column(DateTime) - groups = relationship("Group", secondary=group_permission, back_populates="permissions") + groups = relationship("Group", secondary=group_permission, back_populates="permissions", cascade_backrefs=False) class Key(Base): @@ -139,7 +141,7 @@ class Key(Base): updated_at = Column(DateTime, server_default=func.now()) deleted_at = Column(DateTime) - member = relationship(Member, backref="keys") + member = relationship(Member, backref="keys", cascade_backrefs=False) def __repr__(self) -> str: return f"Key(key_id={self.key_id}, tagid={self.tagid})" @@ -162,7 +164,7 @@ class Span(Base): deleted_at = Column(DateTime) deletion_reason = Column(String(255)) - member = relationship(Member, backref="spans") + member = relationship(Member, backref="spans", cascade_backrefs=False) def __repr__(self) -> str: return f"Span(span_id={self.span_id}, type={self.type}, enddate={self.enddate})" @@ -188,7 +190,7 @@ class Box(Base): # last nag date for that member. last_nag_at = Column(DateTime, nullable=False) - member = relationship(Member, backref="boxes") + member = relationship(Member, backref="boxes", cascade_backrefs=False) def __repr__(self) -> str: return ( @@ -213,7 +215,7 @@ class PhoneNumberChangeRequest(Base): # When the request was made. timestamp = Column(DateTime, nullable=False) - member = relationship(Member, backref="change_phone_number_requests") + member = relationship(Member, backref="change_phone_number_requests", cascade_backrefs=False) @validates("phone") def validate_phone(self, key: Any, value: Optional[str]) -> Optional[str]: diff --git a/api/src/messages/models.py b/api/src/messages/models.py index 0a18b91a6..bf511991d 100644 --- a/api/src/messages/models.py +++ b/api/src/messages/models.py @@ -2,8 +2,7 @@ from membership.models import Member from sqlalchemy import Column, Date, DateTime, Enum, ForeignKey, Integer, String, Text, func -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import configure_mappers, relationship +from sqlalchemy.orm import configure_mappers, declarative_base, relationship Base = declarative_base() diff --git a/api/src/migrate.py b/api/src/migrate.py index 87fbca125..0ce6ec09c 100644 --- a/api/src/migrate.py +++ b/api/src/migrate.py @@ -7,7 +7,7 @@ from os.path import dirname, exists, isdir, join from service.logging import logger -from sqlalchemy import inspect +from sqlalchemy import inspect, text Migration = namedtuple("Migration", "id,name") @@ -26,53 +26,71 @@ def ensure_migrations_table(engine, session_factory): if "migrations" not in table_names: with closing(session_factory()) as session: logger.info("creating migrations table") - session.execute("ALTER DATABASE makeradmin CHARACTER SET = utf8mb4 COLLATE = utf8mb4_0900_ai_ci") + session.execute(text("ALTER DATABASE makeradmin CHARACTER SET = utf8mb4 COLLATE = utf8mb4_0900_ai_ci")) session.execute( - "CREATE TABLE migrations (" - " id INTEGER NOT NULL," - " name VARCHAR(255) COLLATE utf8mb4_0900_ai_ci NOT NULL," - " applied_at DATETIME NOT NULL," - " PRIMARY KEY (id)" - ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci" + text( + "CREATE TABLE migrations (" + " id INTEGER NOT NULL," + " name VARCHAR(255) COLLATE utf8mb4_0900_ai_ci NOT NULL," + " applied_at DATETIME NOT NULL," + " PRIMARY KEY (id)" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci" + ) ) session.commit() elif "service" in {c["name"] for c in engine_inspect.get_columns("migrations")}: with closing(session_factory()) as session: logger.info("updating existing migrations table") session.execute( - "UPDATE migrations SET id=1, name='0001_initial_core'" - " WHERE id=1 AND service='core' AND name='0001_initial'" + text( + "UPDATE migrations SET id=1, name='0001_initial_core'" + " WHERE id=1 AND service='core' AND name='0001_initial'" + ) ) session.execute( - "UPDATE migrations SET id=5, name='0005_remove_excessive_permissions'" - " WHERE id=2 AND service='membership' AND name='0002_remove_excessive_permissions'" + text( + "UPDATE migrations SET id=5, name='0005_remove_excessive_permissions'" + " WHERE id=2 AND service='membership' AND name='0002_remove_excessive_permissions'" + ) ) session.execute( - "UPDATE migrations SET id=6, name='0006_add_box'" - " WHERE id=3 AND service='membership' AND name='0003_add_box'" + text( + "UPDATE migrations SET id=6, name='0006_add_box'" + " WHERE id=3 AND service='membership' AND name='0003_add_box'" + ) ) session.execute( - "UPDATE migrations SET id=2, name='0002_initial_membership'" - " WHERE id=1 AND service='membership' AND name='0001_initial'" + text( + "UPDATE migrations SET id=2, name='0002_initial_membership'" + " WHERE id=1 AND service='membership' AND name='0001_initial'" + ) ) session.execute( - "UPDATE migrations SET id=4, name='0004_initial_messages'" - " WHERE id=1 AND service='messages' AND name='0001_initial'" + text( + "UPDATE migrations SET id=4, name='0004_initial_messages'" + " WHERE id=1 AND service='messages' AND name='0001_initial'" + ) ) session.execute( - "UPDATE migrations SET id=7, name='0007_rename_everything'" - " WHERE id=2 AND service='messages' AND name='0002_rename_everything'" + text( + "UPDATE migrations SET id=7, name='0007_rename_everything'" + " WHERE id=2 AND service='messages' AND name='0002_rename_everything'" + ) ) session.execute( - "UPDATE migrations SET id=3, name='0003_initial_shop'" - " WHERE id=1 AND service='shop' AND name='0001_initial'" + text( + "UPDATE migrations SET id=3, name='0003_initial_shop'" + " WHERE id=1 AND service='shop' AND name='0001_initial'" + ) ) session.execute( - "UPDATE migrations SET id=8, name='0008_password_reset_token'" - " WHERE id=2 AND service='core' AND name='0002_password_reset_token'" + text( + "UPDATE migrations SET id=8, name='0008_password_reset_token'" + " WHERE id=2 AND service='core' AND name='0002_password_reset_token'" + ) ) - session.execute("ALTER TABLE migrations DROP PRIMARY KEY, ADD PRIMARY KEY(id)") - session.execute("ALTER TABLE migrations DROP COLUMN service") + session.execute(text("ALTER TABLE migrations DROP PRIMARY KEY, ADD PRIMARY KEY(id)")) + session.execute(text("ALTER TABLE migrations DROP COLUMN service")) session.commit() @@ -99,7 +117,7 @@ def run_migrations(session_factory): migrations.sort(key=lambda m: m.id) - applied = {i: Migration(i, n) for i, n in session.execute("SELECT id, name FROM migrations ORDER BY ID")} + applied = {i: Migration(i, n) for i, n in session.execute(text("SELECT id, name FROM migrations ORDER BY ID"))} session.commit() logger.info(f"{len(migrations) - len(applied)} migrations to apply, {len(applied)} migrations already applied") @@ -115,10 +133,10 @@ def run_migrations(session_factory): logger.info(f"migrations, applying {migration.name}") for sql in read_sql(join(migrations_dir, migration.name + ".sql")): - session.execute(sql) + session.execute(text(sql)) session.execute( - "INSERT INTO migrations VALUES (:id, :name, :applied_at)", + text("INSERT INTO migrations VALUES (:id, :name, :applied_at)"), { "id": migration.id, "name": migration.name, diff --git a/api/src/multiaccessy/invite.py b/api/src/multiaccessy/invite.py index 133a97c03..0d6044350 100644 --- a/api/src/multiaccessy/invite.py +++ b/api/src/multiaccessy/invite.py @@ -30,7 +30,7 @@ class LabaccessRequirements(Enum): def check_labaccess_requirements(member_id: int) -> LabaccessRequirements: - member = db_session.query(Member).get(member_id) + member = db_session.get(Member, member_id) if member is None: return LabaccessRequirements.MEMBER_MISSING diff --git a/api/src/multiaccessy/models.py b/api/src/multiaccessy/models.py index 4626db6bf..8e99b7c2d 100644 --- a/api/src/multiaccessy/models.py +++ b/api/src/multiaccessy/models.py @@ -1,7 +1,6 @@ from membership.models import Member from sqlalchemy import Boolean, Column, DateTime, Enum, ForeignKey, Integer, Numeric, String, Text, func -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import configure_mappers, relationship +from sqlalchemy.orm import configure_mappers, declarative_base, relationship Base = declarative_base() diff --git a/api/src/multiaccessy/sync.py b/api/src/multiaccessy/sync.py index 7446af23d..dc1b76d16 100644 --- a/api/src/multiaccessy/sync.py +++ b/api/src/multiaccessy/sync.py @@ -22,7 +22,7 @@ def get_wanted_access(today: date, member_id: Optional[int] = None) -> dict[PHONE, AccessyMember]: if member_id is not None: - member = db_session.query(Member).get(member_id) + member = db_session.get(Member, member_id) if member is None: raise Exception("Member does not exist") members = [member] @@ -109,7 +109,7 @@ def sync(today: Optional[date] = None, member_id: Optional[int] = None) -> None: # If a specific member is given, sync only that member, # otherwise sync all members if member_id is not None: - member = db_session.query(Member).get(member_id) + member = db_session.get(Member, member_id) if member is None: raise Exception("Member does not exist") if member.phone is None: diff --git a/api/src/pytest.ini b/api/src/pytest.ini index 0d06989b5..f95facfe1 100644 --- a/api/src/pytest.ini +++ b/api/src/pytest.ini @@ -4,7 +4,6 @@ filterwarnings = ignore::DeprecationWarning:rocky.config # selenium not closed in some test, not sure why ignore:.*4444.*:ResourceWarning:selenium.webdriver.remote.remote_connection - ignore:.*not compatible with SQLAlchemy 2.0.*:DeprecationWarning: ignore::DeprecationWarning:stripe.*: log_cli = 1 log_cli_level = INFO diff --git a/api/src/quiz/models.py b/api/src/quiz/models.py index ff93724c5..35f15cd0c 100644 --- a/api/src/quiz/models.py +++ b/api/src/quiz/models.py @@ -1,7 +1,6 @@ from membership.models import Member from sqlalchemy import Boolean, Column, DateTime, Enum, ForeignKey, Integer, Numeric, String, Text, func -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import configure_mappers, relationship +from sqlalchemy.orm import configure_mappers, declarative_base, relationship Base = declarative_base() @@ -48,7 +47,7 @@ class QuizQuestionOption(Base): updated_at = Column(DateTime, server_default=func.now()) deleted_at = Column(DateTime) - question = relationship(QuizQuestion, backref="options") + question = relationship(QuizQuestion, backref="options", cascade_backrefs=False) def __repr__(self): return f"QuizQuestionOption(id={self.id}, description={self.description})" @@ -67,7 +66,7 @@ class QuizAnswer(Base): updated_at = Column(DateTime, server_default=func.now()) deleted_at = Column(DateTime) - question = relationship(QuizQuestion, backref="answers") + question = relationship(QuizQuestion, backref="answers", cascade_backrefs=False) def __repr__(self): return f"QuizAnswer(id={self.id})" diff --git a/api/src/quiz/views.py b/api/src/quiz/views.py index b7658a31e..0c64a3315 100644 --- a/api/src/quiz/views.py +++ b/api/src/quiz/views.py @@ -5,7 +5,7 @@ from service.api_definition import GET, POST, PUBLIC, QUIZ_EDIT, USER from service.db import db_session from service.entity import OrmSingeRelation -from sqlalchemy import distinct, exists, func +from sqlalchemy import distinct, exists, func, text from quiz import service from quiz.entities import quiz_entity, quiz_question_entity, quiz_question_option_entity @@ -79,7 +79,7 @@ def answer_question(question_id): ) db_session.flush() - question = db_session.query(QuizQuestion).get(question_id) + question = db_session.get(QuizQuestion, question_id) json = quiz_question_entity.to_obj(question) json["options"] = [] options = ( @@ -284,7 +284,9 @@ def quiz_statistics(quiz_id: int): seconds_to_answer_quiz = list( db_session.execute( - "select TIME_TO_SEC(TIMEDIFF(max(quiz_answers.created_at), min(quiz_answers.created_at))) as t from quiz_answers JOIN quiz_questions ON question_id=quiz_questions.id where quiz_questions.quiz_id=:quiz_id group by member_id order by t asc;", + text( + "select TIME_TO_SEC(TIMEDIFF(max(quiz_answers.created_at), min(quiz_answers.created_at))) as t from quiz_answers JOIN quiz_questions ON question_id=quiz_questions.id where quiz_questions.quiz_id=:quiz_id group by member_id order by t asc;" + ), {"quiz_id": quiz_id}, ) ) diff --git a/api/src/service/auth.py b/api/src/service/auth.py index 1eae7dc22..a6623458b 100644 --- a/api/src/service/auth.py +++ b/api/src/service/auth.py @@ -33,7 +33,7 @@ def authenticate_request() -> None: token = authorization[len(bearer) :].strip() - access_token = db_session.query(AccessToken).get(token) + access_token = db_session.get(AccessToken, token) if not access_token: raise Unauthorized("Unauthorized, invalid access token.", fields="bearer", what=BAD_VALUE) diff --git a/api/src/service/entity.py b/api/src/service/entity.py index a817dc997..1157db69e 100644 --- a/api/src/service/entity.py +++ b/api/src/service/entity.py @@ -304,7 +304,7 @@ def create(self, data=None, commit=True): return self.to_obj(self._create_internal(data, commit=commit)) def read(self, entity_id): - entity = db_session.query(self.model).get(entity_id) + entity = db_session.get(self.model, entity_id) if not entity: raise NotFound("Could not find any entity with specified parameters.") obj = self.to_obj(entity) @@ -316,7 +316,7 @@ def _update_internal(self, entity_id, data, commit=True): self.validate_present(input_data) if not input_data: raise UnprocessableEntity("Can not update using empty data.") - entity = db_session.query(self.model).get(entity_id) + entity = db_session.get(self.model, entity_id) if not entity: raise NotFound("Could not find any entity with specified parameters.") @@ -335,7 +335,7 @@ def update(self, entity_id: int, commit: bool = True) -> Any: return self._update_internal(entity_id, request.json, commit=commit) def delete(self, entity_id: int, commit: bool = True) -> None: - entity = db_session.query(self.model).get(entity_id) + entity = db_session.get(self.model, entity_id) if not entity: raise NotFound("Could not find any entity with specified parameters.") diff --git a/api/src/shell_with_db.py b/api/src/shell_with_db.py index f4c31fe2e..2ce1aeb67 100755 --- a/api/src/shell_with_db.py +++ b/api/src/shell_with_db.py @@ -6,9 +6,8 @@ from IPython import start_ipython from service.config import get_mysql_config from service.db import create_mysql_engine -from sqlalchemy import Column, ForeignKey, Integer, Text, select -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import relationship, sessionmaker +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker def init_db(): diff --git a/api/src/shop/email.py b/api/src/shop/email.py index 75afec5bc..a5f085cda 100644 --- a/api/src/shop/email.py +++ b/api/src/shop/email.py @@ -10,7 +10,7 @@ def send_membership_updated_email(member_id: int, extended_days: int, end_date: date) -> None: - member = db_session.query(Member).get(member_id) + member = db_session.get(Member, member_id) send_message( MessageTemplate.ADD_MEMBERSHIP_TIME, member, extended_days=extended_days, end_date=date_to_str(end_date) @@ -18,7 +18,7 @@ def send_membership_updated_email(member_id: int, extended_days: int, end_date: def send_labaccess_extended_email(member_id: int, extended_days: int, end_date: date) -> None: - member = db_session.query(Member).get(member_id) + member = db_session.get(Member, member_id) send_message( MessageTemplate.ADD_LABACCESS_TIME, member, extended_days=extended_days, end_date=date_to_str(end_date) diff --git a/api/src/shop/filters.py b/api/src/shop/filters.py index 1a2fa87f6..8036c398b 100644 --- a/api/src/shop/filters.py +++ b/api/src/shop/filters.py @@ -31,7 +31,7 @@ def filter_no_subscription_active( sub: SubscriptionType, ) -> Callable[["CartItem", int], None]: def filter(cart_item: "CartItem", member_id: int) -> None: - member: Member = db_session.query(Member).get(member_id) + member: Member = db_session.get(Member, member_id) if sub == SubscriptionType.LAB: if member.stripe_labaccess_subscription_id is not None: raise BadRequest( diff --git a/api/src/shop/models.py b/api/src/shop/models.py index b2767fb6f..0c3ef8c62 100644 --- a/api/src/shop/models.py +++ b/api/src/shop/models.py @@ -18,8 +18,7 @@ Text, func, ) -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import configure_mappers, relationship, validates +from sqlalchemy.orm import configure_mappers, declarative_base, relationship, validates from shop.stripe_constants import MakerspaceMetadataKeys @@ -74,9 +73,11 @@ class Product(Base): show = Column(Boolean, nullable=False, server_default="1") stripe_product_id = Column(String(64)) - category = relationship(ProductCategory, backref="products") + category = relationship(ProductCategory, backref="products", cascade_backrefs=False) actions = relationship("ProductAction") - product_accounting = relationship("ProductAccountsCostCenters", backref="accounts_cost_centers") + product_accounting = relationship( + "ProductAccountsCostCenters", backref="accounts_cost_centers", cascade_backrefs=False + ) image_id = Column(Integer, ForeignKey(ProductImage.id), nullable=True) @@ -147,7 +148,7 @@ class TransactionContent(Base): count = Column(Integer, nullable=False) amount = Column(Numeric(precision="15,2"), nullable=False) - transaction = relationship(Transaction, backref="contents") + transaction = relationship(Transaction, backref="contents", cascade_backrefs=False) product = relationship(Product) def __repr__(self) -> str: @@ -167,7 +168,7 @@ class TransactionAction(Base): status = Column(Enum(PENDING, COMPLETED), nullable=False) completed_at = Column(DateTime) - content = relationship(TransactionContent, backref="actions") + content = relationship(TransactionContent, backref="actions", cascade_backrefs=False) def __repr__(self) -> str: return ( @@ -227,7 +228,7 @@ class ProductGiftCardMapping(Base): product_quantity = Column(Integer, nullable=False) amount = Column(Numeric(precision="15,2"), nullable=False) - gift_card = relationship(GiftCard, backref="products") + gift_card = relationship(GiftCard, backref="products", cascade_backrefs=False) product = relationship(Product) @@ -273,8 +274,8 @@ class ProductAccountsCostCenters(Base): ) # Using integer with the range 0-100 to represent fractions and avoind precision issues type = Column(Enum(*[x.value for x in AccountingEntryType]), nullable=False) - account = relationship(TransactionAccount, backref="accounts_cost_centers") - cost_center = relationship(TransactionCostCenter, backref="accounts_cost_centers") + account = relationship(TransactionAccount, backref="accounts_cost_centers", cascade_backrefs=False) + cost_center = relationship(TransactionCostCenter, backref="accounts_cost_centers", cascade_backrefs=False) def __repr__(self) -> str: return f"ProductAccountsCostCenters(id={self.id}, account_id={self.account_id}, cost_center_id={self.cost_center_id}, type={self.type}, fraction={self.fraction})" diff --git a/api/src/shop/ordered_entity.py b/api/src/shop/ordered_entity.py index bcc532fc4..324c6be5f 100644 --- a/api/src/shop/ordered_entity.py +++ b/api/src/shop/ordered_entity.py @@ -2,7 +2,7 @@ from service.db import db_session from service.entity import Entity from service.error import InternalServerError -from sqlalchemy import func +from sqlalchemy import func, text class OrderedEntity(Entity): @@ -17,7 +17,7 @@ def create(self, data=None, commit=True): if data is None: data = request.json or {} - (status,) = db_session.execute("SELECT GET_LOCK('display_order', 20)").fetchone() + (status,) = db_session.execute(text("SELECT GET_LOCK('display_order', 20)")).fetchone() if not status: raise InternalServerError("Failed to create, try again later.", log="failed to aquire display_order lock") try: @@ -30,4 +30,4 @@ def create(self, data=None, commit=True): db_session.rollback() raise finally: - db_session.execute("DO RELEASE_LOCK('display_order')") + db_session.execute(text("DO RELEASE_LOCK('display_order')")) diff --git a/api/src/shop/pay.py b/api/src/shop/pay.py index 8c32e182d..b3c93dcb6 100644 --- a/api/src/shop/pay.py +++ b/api/src/shop/pay.py @@ -44,7 +44,7 @@ def make_purchase(member_id: int, purchase: Purchase) -> Transaction: # If this purchase will start a subscription, then the payment method should be attached to the customer so that it can be used for the subscription. starts_subscription = False for item in purchase.cart: - product = db_session.query(Product).get(item.id) + product = db_session.get(Product, item.id) assert product is not None starts_subscription |= product.get_metadata(MakerspaceMetadataKeys.SUBSCRIPTION_TYPE, None) is not None @@ -161,7 +161,7 @@ def setup_payment_method(data_dict: Any, member_id: int) -> SetupPaymentMethodRe except Exception as e: raise BadRequest(message=f"Invalid data: {e}") - member = db_session.query(Member).get(member_id) + member = db_session.get(Member, member_id) if member is None: raise BadRequest(f"Unable to find member with id {member_id}") @@ -206,7 +206,7 @@ def setup_payment_method(data_dict: Any, member_id: int) -> SetupPaymentMethodRe def cancel_subscriptions(data: CancelSubscriptionsRequest, user_id: int) -> None: - member = db_session.query(Member).get(user_id) + member = db_session.get(Member, user_id) if member is None: raise BadRequest(f"Unable to find member with id {user_id}") @@ -221,7 +221,7 @@ def cancel_subscriptions(data: CancelSubscriptionsRequest, user_id: int) -> None def start_subscriptions(data: StartSubscriptionsRequest, user_id: int) -> None: - member = db_session.query(Member).get(user_id) + member = db_session.get(Member, user_id) if member is None: raise BadRequest(f"Unable to find member with id {user_id}") diff --git a/api/src/shop/shop_data.py b/api/src/shop/shop_data.py index 0ee05ec45..ce50f6a24 100644 --- a/api/src/shop/shop_data.py +++ b/api/src/shop/shop_data.py @@ -18,7 +18,7 @@ transaction_content_entity, transaction_entity, ) -from shop.models import Product, ProductAction, ProductCategory, Transaction +from shop.models import Product, ProductAction, ProductCategory, Transaction, TransactionContent from shop.stripe_constants import MakerspaceMetadataKeys from shop.transactions import pending_actions_query @@ -51,7 +51,7 @@ def pending_actions(member_id: Optional[int] = None) -> List[Any]: def member_history(member_id: int): query = ( db_session.query(Transaction) - .options(joinedload("contents"), joinedload("contents.product")) + .options(joinedload(Transaction.contents).joinedload(TransactionContent.product)) .filter(Transaction.member_id == member_id) .order_by(desc(Transaction.id)) ) @@ -76,7 +76,7 @@ def receipt(member_id, transaction_id): transaction = ( db_session.query(Transaction) .filter_by(member_id=member_id, id=transaction_id) - .options(joinedload("contents"), joinedload("contents.product")) + .options(joinedload(Transaction.contents).joinedload(TransactionContent.product)) .one() ) except NoResultFound: diff --git a/api/src/shop/stripe_customer.py b/api/src/shop/stripe_customer.py index 8f9329374..e9fe64293 100644 --- a/api/src/shop/stripe_customer.py +++ b/api/src/shop/stripe_customer.py @@ -107,7 +107,7 @@ def update_stripe_customer(makeradmin_member: Member) -> stripe.Customer: def delete_stripe_customer(member_id: int) -> None: - member = db_session.query(Member).get(member_id) + member = db_session.get(Member, member_id) if member is None: raise NotFound(f"Unable to find member with id {member_id}") stripe_customer_id = member.stripe_customer_id diff --git a/api/src/shop/stripe_event.py b/api/src/shop/stripe_event.py index 0f1e00e03..45a7b34da 100644 --- a/api/src/shop/stripe_event.py +++ b/api/src/shop/stripe_event.py @@ -113,7 +113,7 @@ def stripe_invoice_event(subtype: EventSubtype, event: stripe.Event, current_tim logger.error(f"Unexpected error reading invoice metadata: {e}") continue - member: Optional[Member] = db_session.query(Member).get(member_id) + member: Optional[Member] = db_session.get(Member, member_id) if member is None: logger.error(f"Ignoring invoice which contains subscription for non-existing member (id={member_id}).") continue @@ -257,7 +257,7 @@ def stripe_customer_event(event_subtype: EventSubtype, event: stripe.Event) -> N return member_id = int(meta[MakerspaceMetadataKeys.USER_ID.value]) - member = db_session.query(Member).get(member_id) + member = db_session.get(Member, member_id) if member is None: logger.warning(f"Ignoring customer event {event['id']} for non-existing member (id={member_id}).") return @@ -335,7 +335,7 @@ def stripe_subscription_schedule_event(event_subtype: EventSubtype, event: strip return member_id = int(meta[MakerspaceMetadataKeys.USER_ID.value]) - member: Optional[Member] = db_session.query(Member).get(member_id) + member: Optional[Member] = db_session.get(Member, member_id) if member is None: logger.warning(f"Ignoring subscription schedule event {event['id']} for unknown member {member_id}") return diff --git a/api/src/shop/stripe_payment_intent.py b/api/src/shop/stripe_payment_intent.py index 511bf9107..e2f17ef62 100644 --- a/api/src/shop/stripe_payment_intent.py +++ b/api/src/shop/stripe_payment_intent.py @@ -114,7 +114,7 @@ def create_client_response(transaction: Transaction, payment_intent: PaymentInte def confirm_stripe_payment_intent(transaction_id: int) -> PartialPayment: """Called by client after payment_intent next_action has been handled""" pending = db_session.query(StripePending).filter_by(transaction_id=transaction_id).one() - transaction = db_session.query(Transaction).get(transaction_id) + transaction = db_session.get(Transaction, transaction_id) if not transaction: raise BadRequest(f"unknown transaction ({transaction_id})") if transaction.status == Transaction.FAILED: @@ -163,7 +163,7 @@ def pay_with_stripe( """Handle stripe payment""" try: - member = db_session.query(Member).get(transaction.member_id) + member = db_session.get(Member, transaction.member_id) assert member is not None stripe_customer = get_and_sync_stripe_customer(member) assert stripe_customer is not None diff --git a/api/src/shop/stripe_subscriptions.py b/api/src/shop/stripe_subscriptions.py index 8d1c3b51c..ea57fa149 100644 --- a/api/src/shop/stripe_subscriptions.py +++ b/api/src/shop/stripe_subscriptions.py @@ -439,7 +439,7 @@ def cancel_subscription( def open_stripe_customer_portal(member_id: int) -> str: """Create a customer portal session and return the URL to which the user should be redirected.""" - member = db_session.query(Member).get(member_id) + member = db_session.get(Member, member_id) if member is None: raise BadRequest(f"Unable to find member with id {member_id}") @@ -493,7 +493,7 @@ def get_subscription_info_from_subscription(sub_type: SubscriptionType, sub_id: def list_subscriptions(member_id: int) -> List[SubscriptionInfo]: - member = db_session.query(Member).get(member_id) + member = db_session.get(Member, member_id) if member is None: raise BadRequest(f"Unable to find member with id {member_id}") diff --git a/api/src/shop/transactions.py b/api/src/shop/transactions.py index 85a88046e..64d4c60bb 100644 --- a/api/src/shop/transactions.py +++ b/api/src/shop/transactions.py @@ -25,6 +25,7 @@ from service.config import config from service.db import db_session, nested_atomic from service.error import BadRequest, InternalServerError, NotFound +from sqlalchemy import text from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound from sqlalchemy.sql import func @@ -108,11 +109,13 @@ def commit_transaction_to_db(member_id: int, total_amount: Decimal, contents: Li db_session.flush() db_session.execute( - """ - INSERT INTO webshop_transaction_actions (content_id, action_type, value, status) - SELECT :content_id AS content_id, action_type, SUM(:count * value) AS value, :pending AS status - FROM webshop_product_actions WHERE product_id=:product_id AND deleted_at IS NULL GROUP BY action_type - """, + text( + """ + INSERT INTO webshop_transaction_actions (content_id, action_type, value, status) + SELECT :content_id AS content_id, action_type, SUM(:count * value) AS value, :pending AS status + FROM webshop_product_actions WHERE product_id=:product_id AND deleted_at IS NULL GROUP BY action_type + """ + ), { "content_id": content.id, "count": content.count, @@ -179,7 +182,7 @@ def complete_pending_action(action: TransactionAction) -> None: def activate_paused_labaccess_subscription(member_id: int, earliest_start_at: datetime) -> None: - member = db_session.query(Member).get(member_id) + member = db_session.get(Member, member_id) if member is None: raise BadRequest(f"Unable to find member with id {member_id}") if member.stripe_labaccess_subscription_id is not None: @@ -340,7 +343,7 @@ def payment_success(transaction: Transaction) -> None: def process_cart(member_id: int, cart: List[CartItem]) -> Tuple[Decimal, List[TransactionContent]]: contents = [] - member = db_session.query(Member).get(member_id) + member = db_session.get(Member, member_id) if member is None: raise NotFound(message=f"Could not find member with id {member_id}.") price_level = get_price_level_for_member(member) diff --git a/api/src/statistics/maker_statistics.py b/api/src/statistics/maker_statistics.py index 0526caf02..b9687bbc3 100644 --- a/api/src/statistics/maker_statistics.py +++ b/api/src/statistics/maker_statistics.py @@ -11,7 +11,7 @@ from service.logging import logger from shop.entities import category_entity, product_entity from shop.models import Product, ProductCategory, Transaction, TransactionContent -from sqlalchemy import func +from sqlalchemy import func, text def spans_by_date(span_type: str) -> List[Tuple[str, int]]: @@ -164,7 +164,8 @@ def membership_by_date_statistics(): def lasertime() -> List[Tuple[str, int]]: query = db_session.execute( - """ + text( + """ SELECT DATE_FORMAT(webshop_transactions.created_at, "%Y-%m"), sum(webshop_transaction_contents.count) FROM webshop_transaction_contents INNER JOIN webshop_transactions @@ -172,6 +173,7 @@ def lasertime() -> List[Tuple[str, int]]: WHERE webshop_transaction_contents.product_id=7 AND webshop_transactions.status='completed' GROUP BY DATE_FORMAT(webshop_transactions.created_at, "%Y-%m") """ + ) ) results = [(date, int(count)) for (date, count) in query] diff --git a/api/src/systest/api/password_reset_test.py b/api/src/systest/api/password_reset_test.py index 74f971a58..043e061f9 100644 --- a/api/src/systest/api/password_reset_test.py +++ b/api/src/systest/api/password_reset_test.py @@ -76,6 +76,6 @@ def test_reset_password_works_for_nice_password(self): self.assertIsNone(e) - member = db_session.query(Member).get(member_id) + member = db_session.get(Member, member_id) self.assertTrue(verify_password(unhashed_password, member.password)) diff --git a/api/src/test_aid/db.py b/api/src/test_aid/db.py index 310426107..ad70cae4e 100644 --- a/api/src/test_aid/db.py +++ b/api/src/test_aid/db.py @@ -18,6 +18,7 @@ TransactionContent, TransactionCostCenter, ) +from sqlalchemy import text from test_aid.obj import ObjFactory from test_aid.test_util import random_str @@ -165,7 +166,7 @@ def get_member_number(self) -> int: while True: member_number = randint(5000, 2000000) sql = "SELECT 1 FROM membership_members WHERE member_number = :number" - if db_session.execute(sql, params=dict(number=member_number)).first() is None: + if db_session.execute(text(sql), params=dict(number=member_number)).first() is None: break return member_number