diff --git a/db/utils.py b/db/utils.py
new file mode 100644
index 0000000..d5a54c4
--- /dev/null
+++ b/db/utils.py
@@ -0,0 +1,64 @@
+import logging
+from datetime import UTC, datetime, timedelta
+from typing import Sequence
+
+from sqlmodel import Session, select
+
+from auth0.client import Auth0Client
+from db.models import BiocommonsUser
+
+logger = logging.getLogger("uvicorn.error")
+
+
+# Minimum number of seconds between two refreshes of unverified users.
+UNVERIFIED_REFRESH_INTERVAL_SECONDS = 60 * 5
+# Start time (UTC) of the most recent refresh in this process; None until the first one.
+LAST_UNVERIFIED_REFRESH_TIME: datetime | None = None
+
+
+def refresh_unverified_users(session: Session, auth0_client: Auth0Client) -> None:
+    """
+    Update all unverified users with their latest email_verified status from Auth0.
+    Run as a background task, triggered when an admin looks at unverified users.
+
+    Uses a simple time-based throttle to avoid excessive API calls: a call made
+    less than UNVERIFIED_REFRESH_INTERVAL_SECONDS after the previous refresh
+    started returns without querying the database or Auth0.
+
+    A failure while refreshing one user is logged and that user is skipped, so
+    the remaining users are still refreshed.
+
+    :param session: database session used to load and update users. It is closed
+        before returning unless the call was throttled.
+    :param auth0_client: client used to fetch each user's current Auth0 record.
+    """
+    global LAST_UNVERIFIED_REFRESH_TIME
+    if LAST_UNVERIFIED_REFRESH_TIME is not None:
+        since_last_refresh = datetime.now(UTC) - LAST_UNVERIFIED_REFRESH_TIME
+        if since_last_refresh < timedelta(seconds=UNVERIFIED_REFRESH_INTERVAL_SECONDS):
+            logger.info(f"Skipping refresh of unverified users: last refresh was {LAST_UNVERIFIED_REFRESH_TIME}")
+            return
+    # Record the start time before doing any work so that calls arriving shortly
+    # afterwards are throttled even while this refresh is still running.
+    LAST_UNVERIFIED_REFRESH_TIME = datetime.now(UTC)
+    try:
+        # Query inside the try block so the session is closed even if the lookup fails.
+        unverified_users: Sequence[BiocommonsUser] = session.exec(
+            select(BiocommonsUser).where(BiocommonsUser.email_verified.is_(False))
+        ).all()
+        for user in unverified_users:
+            try:
+                auth0_data = auth0_client.get_user(user.id)
+                if auth0_data.email_verified != user.email_verified:
+                    logger.info(f"Updating email_verified status for user {user.id}: {auth0_data.email_verified}")
+                    user.email_verified = auth0_data.email_verified
+                    session.add(user)
+                    session.commit()
+            except Exception:
+                # Best-effort background job: log the failure and move on so one
+                # bad Auth0 lookup or commit does not block every remaining user.
+                logger.exception("Failed to refresh email_verified status for a user; skipping")
+                # Discard any half-applied change so the session stays usable.
+                session.rollback()
+    finally:
+        session.close()
diff --git a/routers/admin.py b/routers/admin.py
index 6918cf8..eaac858 100644
--- a/routers/admin.py
+++ b/routers/admin.py
@@ -4,7 +4,7 @@
 from datetime import datetime, timezone
 from typing import Annotated, Any
 
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
 from fastapi.params import Query
 from pydantic import BaseModel, Field, ValidationError, field_validator
 from sqlalchemy import func, or_
@@ -40,6 +40,7 @@
     GroupMembershipData,
     PlatformMembershipData,
 )
+from db.utils import refresh_unverified_users
 from schemas.biocommons import Auth0UserDataWithMemberships, ServiceIdParam, UserIdParam
 from schemas.user import SessionUser
 from services.email_queue import enqueue_email
@@ -580,13 +581,23 @@ def get_filtered_user_query(
             response_model=list[BiocommonsUserResponse])
 def get_users(db_session: Annotated[Session, Depends(get_db_session)],
               query_params: Annotated[UserQueryParams, Depends()],
-              user_query: Annotated[SelectOfScalar[BiocommonsUser], Depends(get_filtered_user_query)]):
+              user_query: Annotated[SelectOfScalar[BiocommonsUser], Depends(get_filtered_user_query)],
+              auth0_client: Annotated[Auth0Client, Depends(get_auth0_client)],
+              background_tasks: BackgroundTasks):
     """
     Get all users from the database with pagination and optional filtering.
 
     The admin_user must have roles that allow access to either the platform or group
     to see the users.
     """
+    # Refresh unverified status if specifically requested
+    if query_params.email_verified is not None:
+        logger.info("Refreshing unverified users")
+        background_tasks.add_task(
+            refresh_unverified_users,
+            session=db_session,
+            auth0_client=auth0_client,
+        )
     # Check for missing IDs in the database (e.g. group ID not found) and raise 404
     query_params.check_missing_ids(db_session)
     users = db_session.exec(user_query).all()
diff --git a/run_scheduler.py b/run_scheduler.py
index 3a1e929..2d16766 100644
--- a/run_scheduler.py
+++ b/run_scheduler.py
@@ -27,6 +27,7 @@
 
 def schedule_jobs(scheduler: AsyncIOScheduler):
     hourly_trigger = IntervalTrigger(minutes=60)
+    half_hourly_trigger = IntervalTrigger(minutes=30)
     email_trigger = IntervalTrigger(minutes=1)
     logger.info("Adding one-off job: populate DB groups")
     scheduler.add_job(
@@ -50,10 +51,10 @@ def schedule_jobs(scheduler: AsyncIOScheduler):
         replace_existing=True,
         next_run_time=datetime.now(UTC)
     )
-    logger.info("Adding hourly job: sync_auth0_users")
+    logger.info("Adding half-hourly job: sync_auth0_users")
     scheduler.add_job(
         sync_auth0_users,
-        trigger=hourly_trigger,
+        trigger=half_hourly_trigger,
         id="sync_auth0_users",
         replace_existing=True,
         next_run_time=datetime.now(UTC) + timedelta(minutes=15)
diff --git a/tests/conftest.py b/tests/conftest.py
index 9362787..a4c09d4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,7 @@
 import os
 import warnings
 from datetime import datetime
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import pytest
 from fastapi.testclient import TestClient
@@ -330,6 +330,16 @@ def persistent_factories(test_db_session):
     for factory in factories:
         factory.__session__ = None
 
+
+@pytest.fixture
+def mock_background_tasks():
+    """
+    Mock BackgroundTasks - tasks will not be run but you can check add_task was called
+    """
+    with patch("fastapi.BackgroundTasks.add_task") as mocked:
+        yield mocked
+
+
 @pytest.fixture(scope="function")
 def aws_credentials():
     """Mocked AWS Credentials for moto."""
diff --git a/tests/db/test_utils.py b/tests/db/test_utils.py
new file mode 100644
index 0000000..216ba2a
--- /dev/null
+++ b/tests/db/test_utils.py
@@ -0,0 +1,102 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from sqlmodel import Session, select
+
+from db.models import BiocommonsUser
+from db.utils import refresh_unverified_users
+from tests.db.datagen import BiocommonsUserFactory
+
+
+@pytest.fixture(autouse=True)
+def reset_refresh_time():
+    """Reset the global throttle variable before each test."""
+    with patch("db.utils.LAST_UNVERIFIED_REFRESH_TIME", None):
+        yield
+
+
+def test_refresh_unverified_users_throttling(test_db_session):
+    """Verifies that the function returns early if called within the refresh interval."""
+    auth0_client = MagicMock()
+
+    # Patch session.exec to see if it gets called
+    with patch.object(test_db_session, 'exec', wraps=test_db_session.exec) as spy_exec:
+        # First call: should proceed and update LAST_UNVERIFIED_REFRESH_TIME
+        refresh_unverified_users(test_db_session, auth0_client)
+        assert spy_exec.call_count == 1
+
+        spy_exec.reset_mock()
+
+        # Second call: should hit the global throttle and return immediately
+        with patch("db.utils.logger") as mock_logger:
+            refresh_unverified_users(test_db_session, auth0_client)
+
+            # Verify the "Skipping refresh" log was emitted
+            mock_logger.info.assert_called_once()
+            assert "Skipping refresh" in mock_logger.info.call_args[0][0]
+
+        # Crucially: verify the database was NOT queried a second time
+        spy_exec.assert_not_called()
+        # And Auth0 was never touched
+        auth0_client.get_user.assert_not_called()
+
+
+def test_refresh_unverified_users_updates_status(test_db_session: Session):
+    """Verifies that a user's status is updated when Auth0 reports they are now verified."""
+    # Create an unverified user
+    user = BiocommonsUser(
+        id="auth0|123",
+        email="test@example.com",
+        username="testuser",
+        email_verified=False,
+    )
+    test_db_session.add(user)
+    test_db_session.commit()
+
+    # Mock Auth0 to return verified=True
+    auth0_client = MagicMock()
+    auth0_data = MagicMock()
+    auth0_data.email_verified = True
+    auth0_client.get_user.return_value = auth0_data
+
+    refresh_unverified_users(test_db_session, auth0_client)
+
+    # Re-fetch the user in a fresh session (avoids the cache); the context manager closes it afterwards
+    with Session(test_db_session.bind) as db_session_new:
+        updated_user = db_session_new.exec(select(BiocommonsUser).where(BiocommonsUser.id == "auth0|123")).one()
+        assert updated_user.email_verified is True
+
+
+def test_refresh_unverified_users_skips_verified_users(test_db_session):
+    """Verifies that the query only targets users where email_verified is False."""
+    user = BiocommonsUser(
+        id="auth0|456",
+        email="verified@example.com",
+        username="verified",
+        email_verified=True,
+    )
+    test_db_session.add(user)
+    test_db_session.commit()
+
+    auth0_client = MagicMock()
+    refresh_unverified_users(test_db_session, auth0_client)
+
+    # Auth0 should not have been called for this user
+    auth0_client.get_user.assert_not_called()
+
+
+def test_refresh_unverified_users_no_change(test_db_session, persistent_factories):
+    """Verifies no DB update occurs if Auth0 still says False."""
+    user = BiocommonsUserFactory.create_sync(email_verified=False)
+    test_db_session.add(user)
+    test_db_session.commit()
+
+    auth0_client = MagicMock()
+    auth0_data = MagicMock()
+    auth0_data.email_verified = False
+    auth0_client.get_user.return_value = auth0_data
+
+    with patch.object(test_db_session, 'add') as mock_add:
+        refresh_unverified_users(test_db_session, auth0_client)
+        # session.add(user) should NOT be called if status is the same
+        mock_add.assert_not_called()
diff --git a/tests/test_admin_user_filters.py b/tests/test_admin_user_filters.py
index ce5fe65..bbad655 100644
--- a/tests/test_admin_user_filters.py
+++ b/tests/test_admin_user_filters.py
@@ -4,6 +4,7 @@
 
 from db.models import BiocommonsUser
 from db.types import ApprovalStatusEnum, GroupEnum, PlatformEnum
+from db.utils import refresh_unverified_users
 from routers.admin import UserQueryParams
 from tests.db.datagen import (
     Auth0RoleFactory,
@@ -36,7 +37,7 @@ def test_user_query_params_missing_method(monkeypatch):
         UserQueryParams(email_verified=True)
 
 
-def test_user_query_multiple_filters(test_client, mock_auth0_client, as_admin_user, test_db_session, persistent_factories):
+def test_user_query_multiple_filters(test_client, mock_auth0_client, as_admin_user, test_db_session, mock_background_tasks, persistent_factories):
     """
     Test that multiple conditions can be combined correctly
     """
@@ -164,3 +165,15 @@ def test_get_users_combined_group_filters(
     matching_user_ids = {u.id for u in matching_users}
     returned_ids = {u["id"] for u in data}
     assert returned_ids == matching_user_ids
+
+
+def test_get_users_email_verified_refreshes_status(test_client, as_admin_user, test_db_session, persistent_factories, mock_background_tasks):
+    """
+    Check that explicitly searching based on email_verified status triggers a refresh of unverified users.
+    """
+    resp = test_client.get(
+        "/admin/users?email_verified=true"
+    )
+    assert resp.status_code == 200
+    assert mock_background_tasks.called
+    assert mock_background_tasks.call_args[0][0] == refresh_unverified_users