Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions db/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import logging
from datetime import UTC, datetime, timedelta
from typing import Sequence

from sqlmodel import Session, select

from auth0.client import Auth0Client
from db.models import BiocommonsUser

logger = logging.getLogger("uvicorn.error")


UNVERIFIED_REFRESH_INTERVAL_SECONDS = 60 * 5
LAST_UNVERIFIED_REFRESH_TIME: datetime | None = None


def refresh_unverified_users(session: Session, auth0_client: Auth0Client):
"""
Update all unverified users with their latest email_verified status from Auth0.
Run as a background task, triggered when an admin looks at unverified users.

Uses a simple time-based throttle to avoid excessive API calls.
"""
global LAST_UNVERIFIED_REFRESH_TIME
if LAST_UNVERIFIED_REFRESH_TIME is not None:
since_last_refresh = datetime.now(UTC) - LAST_UNVERIFIED_REFRESH_TIME
if since_last_refresh < timedelta(seconds=UNVERIFIED_REFRESH_INTERVAL_SECONDS):
logger.info(f"Skipping refresh of unverified users: last refresh was {LAST_UNVERIFIED_REFRESH_TIME}")
return
LAST_UNVERIFIED_REFRESH_TIME = datetime.now(UTC)
unverified_users: Sequence[BiocommonsUser] = session.exec(select(BiocommonsUser).where(BiocommonsUser.email_verified.is_(False))).all()
try:
for user in unverified_users:
auth0_data = auth0_client.get_user(user.id)
if auth0_data.email_verified != user.email_verified:
logger.info(f"Updating email_verified status for user {user.id}: {auth0_data.email_verified}")
user.email_verified = auth0_data.email_verified
session.add(user)
session.commit()
finally:
session.close()
15 changes: 13 additions & 2 deletions routers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import datetime, timezone
from typing import Annotated, Any

from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi.params import Query
from pydantic import BaseModel, Field, ValidationError, field_validator
from sqlalchemy import func, or_
Expand Down Expand Up @@ -40,6 +40,7 @@
GroupMembershipData,
PlatformMembershipData,
)
from db.utils import refresh_unverified_users
from schemas.biocommons import Auth0UserDataWithMemberships, ServiceIdParam, UserIdParam
from schemas.user import SessionUser
from services.email_queue import enqueue_email
Expand Down Expand Up @@ -580,13 +581,23 @@ def get_filtered_user_query(
response_model=list[BiocommonsUserResponse])
def get_users(db_session: Annotated[Session, Depends(get_db_session)],
query_params: Annotated[UserQueryParams, Depends()],
user_query: Annotated[SelectOfScalar[BiocommonsUser], Depends(get_filtered_user_query)]):
user_query: Annotated[SelectOfScalar[BiocommonsUser], Depends(get_filtered_user_query)],
auth0_client: Annotated[Auth0Client, Depends(get_auth0_client)],
background_tasks: BackgroundTasks):
"""
Get all users from the database with pagination and optional filtering.

The admin_user must have roles that allow access to either the platform or group to
see the users.
"""
# Refresh unverified status if specifically requested
if query_params.email_verified is not None:
logger.info("Refreshing unverified users")
background_tasks.add_task(
refresh_unverified_users,
session=db_session,
auth0_client=auth0_client,
)
# Check for missing IDs in the database (e.g. group ID not found) and raise 404
query_params.check_missing_ids(db_session)
users = db_session.exec(user_query).all()
Expand Down
5 changes: 3 additions & 2 deletions run_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

def schedule_jobs(scheduler: AsyncIOScheduler):
hourly_trigger = IntervalTrigger(minutes=60)
half_hourly_trigger = IntervalTrigger(minutes=30)
email_trigger = IntervalTrigger(minutes=1)
logger.info("Adding one-off job: populate DB groups")
scheduler.add_job(
Expand All @@ -50,10 +51,10 @@ def schedule_jobs(scheduler: AsyncIOScheduler):
replace_existing=True,
next_run_time=datetime.now(UTC)
)
logger.info("Adding hourly job: sync_auth0_users")
logger.info("Adding half-hourly job: sync_auth0_users")
scheduler.add_job(
sync_auth0_users,
trigger=hourly_trigger,
trigger=half_hourly_trigger,
id="sync_auth0_users",
replace_existing=True,
next_run_time=datetime.now(UTC) + timedelta(minutes=15)
Expand Down
12 changes: 11 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import warnings
from datetime import datetime
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -330,6 +330,16 @@ def persistent_factories(test_db_session):
for factory in factories:
factory.__session__ = None


@pytest.fixture
def mock_background_tasks():
"""
Mock BackgroundTasks - tasks will not be run but you can check add_task was called
"""
with patch("fastapi.BackgroundTasks.add_task") as mocked:
yield mocked


@pytest.fixture(scope="function")
def aws_credentials():
"""Mocked AWS Credentials for moto."""
Expand Down
102 changes: 102 additions & 0 deletions tests/db/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from unittest.mock import MagicMock, patch

import pytest
from sqlmodel import Session, select

from db.models import BiocommonsUser
from db.utils import refresh_unverified_users
from tests.db.datagen import BiocommonsUserFactory


@pytest.fixture(autouse=True)
def reset_refresh_time():
"""Reset the global throttle variable before each test."""
with patch("db.utils.LAST_UNVERIFIED_REFRESH_TIME", None):
yield


def test_refresh_unverified_users_throttling(test_db_session):
"""Verifies that the function returns early if called within the refresh interval."""
auth0_client = MagicMock()

# Patch session.exec to see if it gets called
with patch.object(test_db_session, 'exec', wraps=test_db_session.exec) as spy_exec:
# First call: should proceed and update LAST_UNVERIFIED_REFRESH_TIME
refresh_unverified_users(test_db_session, auth0_client)
assert spy_exec.call_count == 1

spy_exec.reset_mock()

# Second call: should hit the global throttle and return immediately
with patch("db.utils.logger") as mock_logger:
refresh_unverified_users(test_db_session, auth0_client)

# Verify the "Skipping refresh" log was emitted
mock_logger.info.assert_called_once()
assert "Skipping refresh" in mock_logger.info.call_args[0][0]

# Crucially: verify the database was NOT queried a second time
spy_exec.assert_not_called()
# And Auth0 was never touched
auth0_client.get_user.assert_not_called()


def test_refresh_unverified_users_updates_status(test_db_session: Session):
"""Verifies that a user's status is updated when Auth0 reports they are now verified."""
# Create an unverified user
user = BiocommonsUser(
id="auth0|123",
email="[email protected]",
username="testuser",
email_verified=False,
)
test_db_session.add(user)
test_db_session.commit()

# Mock Auth0 to return verified=True
auth0_client = MagicMock()
auth0_data = MagicMock()
auth0_data.email_verified = True
auth0_client.get_user.return_value = auth0_data

refresh_unverified_users(test_db_session, auth0_client)

# Re-fetch user to check update
db_session_new = Session(test_db_session.bind) # Use fresh session to avoid cache
updated_user = db_session_new.exec(select(BiocommonsUser).where(BiocommonsUser.id == "auth0|123")).one()
assert updated_user.email_verified is True


def test_refresh_unverified_users_skips_verified_users(test_db_session):
"""Verifies that the query only targets users where email_verified is False."""
user = BiocommonsUser(
id="auth0|456",
email="[email protected]",
username="verified",
email_verified=True,
)
test_db_session.add(user)
test_db_session.commit()

auth0_client = MagicMock()
refresh_unverified_users(test_db_session, auth0_client)

# Auth0 should not have been called for this user
auth0_client.get_user.assert_not_called()


def test_refresh_unverified_users_no_change(test_db_session, persistent_factories):
"""Verifies no DB update occurs if Auth0 still says False."""
user = BiocommonsUserFactory.create_sync(email_verified=False)
test_db_session.add(user)
test_db_session.commit()

auth0_client = MagicMock()
auth0_data = MagicMock()
auth0_data.email_verified = False
auth0_client.get_user.return_value = auth0_data

with patch.object(test_db_session, 'add') as mock_add:
refresh_unverified_users(test_db_session, auth0_client)
# session.add(user) should NOT be called if status is the same
mock_add.assert_not_called()
16 changes: 15 additions & 1 deletion tests/test_admin_user_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from db.models import BiocommonsUser
from db.types import ApprovalStatusEnum, GroupEnum, PlatformEnum
from db.utils import refresh_unverified_users
from routers.admin import UserQueryParams
from tests.db.datagen import (
Auth0RoleFactory,
Expand Down Expand Up @@ -36,7 +37,7 @@ def test_user_query_params_missing_method(monkeypatch):
UserQueryParams(email_verified=True)


def test_user_query_multiple_filters(test_client, mock_auth0_client, as_admin_user, test_db_session, persistent_factories):
def test_user_query_multiple_filters(test_client, mock_auth0_client, as_admin_user, test_db_session, mock_background_tasks, persistent_factories):
"""
Test that multiple conditions can be combined correctly
"""
Expand Down Expand Up @@ -164,3 +165,16 @@ def test_get_users_combined_group_filters(
matching_user_ids = {u.id for u in matching_users}
returned_ids = {u["id"] for u in data}
assert returned_ids == matching_user_ids


def test_get_users_email_verified_refreshes_status(test_client, as_admin_user, test_db_session, persistent_factories, mock_background_tasks):
"""
Check that explicitly searching based on email_verified status triggers a refresh of unverified users.
"""
resp = test_client.get(
"/admin/users?email_verified=true"
)
assert resp.status_code == 200
assert mock_background_tasks.called
print(mock_background_tasks.call_args)
assert mock_background_tasks.call_args[0][0] == refresh_unverified_users