Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions bats_ai/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import logging

from django.http import HttpRequest
from ninja import NinjaAPI
from ninja.security import HttpBearer
from oauth2_provider.models import AccessToken
from oauth2_provider.oauth2_backends import get_oauthlib_core

from bats_ai.core.views import (
ConfigurationRouter,
Expand All @@ -20,21 +23,26 @@
logger = logging.getLogger(__name__)


def global_auth(request):
if request.user.is_anonymous:
token = request.headers.get('Authorization', '').replace('Bearer ', '')
if len(token) > 0:
try:
access_token = AccessToken.objects.get(token=token)
except AccessToken.DoesNotExist:
access_token = None
if access_token and access_token.user:
if not access_token.user.is_anonymous:
request.user = access_token.user
return not request.user.is_anonymous
class OAuth2Auth(HttpBearer):
def __init__(self, scopes: list[str] | None = None) -> None:
super().__init__()
self.scopes = scopes if scopes is not None else []

def authenticate(self, request: HttpRequest, token: str) -> AccessToken | None:
oauthlib_core = get_oauthlib_core()
# This also sets `request.user`,
# which Ninja does not: https://github.com/vitalik/django-ninja/issues/76
valid, r = oauthlib_core.verify_request(request, scopes=self.scopes)

api = NinjaAPI(auth=global_auth)
if valid:
# Any truthy return is success, but give the full AccessToken for Ninja to set as
# `request.auth`.
request.user = r.access_token.user
return r.access_token
return None


api = NinjaAPI(auth=OAuth2Auth())

api.add_router('/recording/', RecordingRouter)
api.add_router('/species/', SpeciesRouter)
Expand Down
25 changes: 18 additions & 7 deletions bats_ai/core/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from django.contrib.auth.models import User
from django.test import Client
from oauth2_provider.models import AccessToken
import pytest

from bats_ai.core.models import VettingDetails

from .factories import SuperuserFactory, UserFactory, VettingDetailsFactory
from .factories import AccessTokenFactory, SuperuserFactory, UserFactory, VettingDetailsFactory


@pytest.fixture
Expand All @@ -18,21 +19,31 @@ def user() -> User:


@pytest.fixture
def superuser() -> User:
return SuperuserFactory()
def user_token(user) -> AccessToken:
return AccessTokenFactory(user=user)


@pytest.fixture
def authenticated_client(user: User) -> Client:
def authenticated_client(user: User, user_token: AccessToken) -> Client:
client = Client()
client.force_login(user=user)
client.defaults['HTTP_AUTHORIZATION'] = f'Bearer {user_token.token}'
return client


@pytest.fixture
def authorized_client(superuser: User) -> Client:
def superuser() -> User:
return SuperuserFactory()


@pytest.fixture
def superuser_token(superuser) -> AccessToken:
return AccessTokenFactory(user=superuser)


@pytest.fixture
def authorized_client(superuser: User, superuser_token: AccessToken) -> Client:
client = Client()
client.force_login(user=superuser)
client.defaults['HTTP_AUTHORIZATION'] = f'Bearer {superuser_token.token}'
return client


Expand Down
14 changes: 14 additions & 0 deletions bats_ai/core/tests/factories.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from datetime import timedelta

from django.contrib.auth.models import User
from django.utils import timezone
import factory.django
from oauth2_provider.models import AccessToken

from bats_ai.core.models import VettingDetails

Expand Down Expand Up @@ -39,3 +43,13 @@ class Meta:

user = factory.SubFactory(UserFactory)
reference_materials = factory.Faker('paragraph', nb_sentences=3)


class AccessTokenFactory(factory.django.DjangoModelFactory[AccessToken]):
class Meta:
model = AccessToken

user = factory.SubFactory(UserFactory)
token = factory.Faker('uuid4')
scope = 'read write'
expires = factory.LazyFunction(lambda: timezone.now() + timedelta(hours=1))
17 changes: 17 additions & 0 deletions bats_ai/core/tests/test_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,20 @@ def test_is_admin(client_fixture, status_code, is_admin, request):
assert resp.status_code == status_code
if is_admin is not None:
assert resp.json()['is_admin'] == is_admin


@pytest.mark.parametrize(
'client_fixture,status_code',
[
('authenticated_client', 200),
('client', 401),
],
)
@pytest.mark.django_db
def test_get_current_user(client_fixture, status_code, user, request):
api_client = request.getfixturevalue(client_fixture)
resp = api_client.get('/api/v1/configuration/me')
assert resp.status_code == status_code
if status_code == 200:
assert resp.json()['name'] == user.username
assert resp.json()['email'] == user.email
8 changes: 5 additions & 3 deletions bats_ai/core/tests/test_vetting_details.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from .factories import UserFactory, VettingDetailsFactory
from .factories import AccessTokenFactory, UserFactory, VettingDetailsFactory


@pytest.mark.parametrize(
Expand Down Expand Up @@ -33,7 +33,8 @@ def test_create_vetting_details(client):
test_text = 'foo'
data = {'reference_materials': test_text}
test_user = UserFactory()
client.force_login(user=test_user)
test_token = AccessTokenFactory(user=test_user)
client.defaults['HTTP_AUTHORIZATION'] = f'Bearer {test_token.token}'
resp = client.post(
f'/api/v1/vetting/user/{test_user.id}', data=data, content_type='application/json'
)
Expand Down Expand Up @@ -67,8 +68,9 @@ def test_update_vetting_details(client):
test_text = 'bar'
data = {'reference_materials': 'bar'}
test_user = UserFactory()
test_token = AccessTokenFactory(user=test_user)
VettingDetailsFactory(user=test_user, reference_materials='foo')
client.force_login(test_user)
client.defaults['HTTP_AUTHORIZATION'] = f'Bearer {test_token.token}'

initial_resp = client.get(f'/api/v1/vetting/user/{test_user.id}')
assert initial_resp.status_code == 200
Expand Down
Loading