diff --git a/bats_ai/api.py b/bats_ai/api.py index 9b1cc1ca..62b594ed 100644 --- a/bats_ai/api.py +++ b/bats_ai/api.py @@ -1,7 +1,10 @@ import logging +from django.http import HttpRequest from ninja import NinjaAPI +from ninja.security import HttpBearer from oauth2_provider.models import AccessToken +from oauth2_provider.oauth2_backends import get_oauthlib_core from bats_ai.core.views import ( ConfigurationRouter, @@ -20,21 +23,26 @@ logger = logging.getLogger(__name__) -def global_auth(request): - if request.user.is_anonymous: - token = request.headers.get('Authorization', '').replace('Bearer ', '') - if len(token) > 0: - try: - access_token = AccessToken.objects.get(token=token) - except AccessToken.DoesNotExist: - access_token = None - if access_token and access_token.user: - if not access_token.user.is_anonymous: - request.user = access_token.user - return not request.user.is_anonymous +class OAuth2Auth(HttpBearer): + def __init__(self, scopes: list[str] | None = None) -> None: + super().__init__() + self.scopes = scopes if scopes is not None else [] + def authenticate(self, request: HttpRequest, token: str) -> AccessToken | None: + oauthlib_core = get_oauthlib_core() + # This also sets `request.user`, + # which Ninja does not: https://github.com/vitalik/django-ninja/issues/76 + valid, r = oauthlib_core.verify_request(request, scopes=self.scopes) -api = NinjaAPI(auth=global_auth) + if valid: + # Any truthy return is success, but give the full AccessToken for Ninja to set as + # `request.auth`. + request.user = r.access_token.user + return r.access_token + return None + + +api = NinjaAPI(auth=OAuth2Auth()) api.add_router('/recording/', RecordingRouter) api.add_router('/species/', SpeciesRouter) diff --git a/bats_ai/core/tests/conftest.py b/bats_ai/core/tests/conftest.py index d1c5e8fe..9df7dc7a 100644 --- a/bats_ai/core/tests/conftest.py +++ b/bats_ai/core/tests/conftest.py @@ -1,10 +1,11 @@ from django.contrib.auth.models import User from django.test import Client +from oauth2_provider.models import AccessToken import pytest from bats_ai.core.models import VettingDetails -from .factories import SuperuserFactory, UserFactory, VettingDetailsFactory +from .factories import AccessTokenFactory, SuperuserFactory, UserFactory, VettingDetailsFactory @pytest.fixture @@ -18,21 +19,31 @@ def user() -> User: @pytest.fixture -def superuser() -> User: - return SuperuserFactory() +def user_token(user) -> AccessToken: + return AccessTokenFactory(user=user) @pytest.fixture -def authenticated_client(user: User) -> Client: +def authenticated_client(user: User, user_token: AccessToken) -> Client: client = Client() - client.force_login(user=user) + client.defaults['HTTP_AUTHORIZATION'] = f'Bearer {user_token.token}' return client @pytest.fixture -def authorized_client(superuser: User) -> Client: +def superuser() -> User: + return SuperuserFactory() + + +@pytest.fixture +def superuser_token(superuser) -> AccessToken: + return AccessTokenFactory(user=superuser) + + +@pytest.fixture +def authorized_client(superuser: User, superuser_token: AccessToken) -> Client: client = Client() - client.force_login(user=superuser) + client.defaults['HTTP_AUTHORIZATION'] = f'Bearer {superuser_token.token}' return client diff --git a/bats_ai/core/tests/factories.py b/bats_ai/core/tests/factories.py index 708fbacf..5db9777c 100644 --- a/bats_ai/core/tests/factories.py +++ b/bats_ai/core/tests/factories.py @@ -1,5 +1,9 @@ +from datetime import timedelta + from django.contrib.auth.models import User +from django.utils import timezone import factory.django +from oauth2_provider.models import AccessToken from bats_ai.core.models import VettingDetails @@ -39,3 +43,13 @@ class Meta: user = factory.SubFactory(UserFactory) reference_materials = factory.Faker('paragraph', nb_sentences=3) + + +class AccessTokenFactory(factory.django.DjangoModelFactory[AccessToken]): + class Meta: + model = AccessToken + + user = factory.SubFactory(UserFactory) + token = factory.Faker('uuid4') + scope = 'read write' + expires = factory.LazyFunction(lambda: timezone.now() + timedelta(hours=1)) diff --git a/bats_ai/core/tests/test_admin.py b/bats_ai/core/tests/test_admin.py index 2ba1d16b..c9653f6b 100644 --- a/bats_ai/core/tests/test_admin.py +++ b/bats_ai/core/tests/test_admin.py @@ -16,3 +16,20 @@ def test_is_admin(client_fixture, status_code, is_admin, request): assert resp.status_code == status_code if is_admin is not None: assert resp.json()['is_admin'] == is_admin + + +@pytest.mark.parametrize( + 'client_fixture,status_code', + [ + ('authenticated_client', 200), + ('client', 401), + ], +) +@pytest.mark.django_db +def test_get_current_user(client_fixture, status_code, user, request): + api_client = request.getfixturevalue(client_fixture) + resp = api_client.get('/api/v1/configuration/me') + assert resp.status_code == status_code + if status_code == 200: + assert resp.json()['name'] == user.username + assert resp.json()['email'] == user.email diff --git a/bats_ai/core/tests/test_vetting_details.py b/bats_ai/core/tests/test_vetting_details.py index e5537267..3fdaf32f 100644 --- a/bats_ai/core/tests/test_vetting_details.py +++ b/bats_ai/core/tests/test_vetting_details.py @@ -1,6 +1,6 @@ import pytest -from .factories import UserFactory, VettingDetailsFactory +from .factories import AccessTokenFactory, UserFactory, VettingDetailsFactory @pytest.mark.parametrize( @@ -33,7 +33,8 @@ def test_create_vetting_details(client): test_text = 'foo' data = {'reference_materials': test_text} test_user = UserFactory() - client.force_login(user=test_user) + test_token = AccessTokenFactory(user=test_user) + client.defaults['HTTP_AUTHORIZATION'] = f'Bearer {test_token.token}' resp = client.post( f'/api/v1/vetting/user/{test_user.id}', data=data, content_type='application/json' ) @@ -67,8 +68,9 @@ def test_update_vetting_details(client): test_text = 'bar' data = {'reference_materials': 'bar'} test_user = UserFactory() + test_token = AccessTokenFactory(user=test_user) VettingDetailsFactory(user=test_user, reference_materials='foo') - client.force_login(test_user) + client.defaults['HTTP_AUTHORIZATION'] = f'Bearer {test_token.token}' initial_resp = client.get(f'/api/v1/vetting/user/{test_user.id}') assert initial_resp.status_code == 200