Skip to content

Commit fcbf4df

Browse files
authored
Fix broken user token rotation API (#1487)
1 parent 5aebcee commit fcbf4df

File tree

3 files changed

+59
-4
lines changed

3 files changed

+59
-4
lines changed

src/dstack/_internal/server/routers/users.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,9 @@ async def update_user(
8585
async def refresh_token(
8686
body: RefreshTokenRequest,
8787
session: AsyncSession = Depends(get_session),
88-
user: UserModel = Depends(GlobalAdmin()),
88+
user: UserModel = Depends(Authenticated()),
8989
) -> UserWithCreds:
90-
res = await users.refresh_user_token(session=session, username=body.username)
90+
res = await users.refresh_user_token(session=session, user=user, username=body.username)
9191
if res is None:
9292
raise ResourceNotExistsError()
9393
return users.user_model_to_user_with_creds(res)

src/dstack/_internal/server/services/users.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from dstack._internal.core.errors import ResourceExistsError
99
from dstack._internal.core.models.users import GlobalRole, User, UserTokenCreds, UserWithCreds
1010
from dstack._internal.server.models import UserModel
11+
from dstack._internal.server.utils.routers import error_forbidden
1112

1213
_ADMIN_USERNAME = "admin"
1314

@@ -90,9 +91,15 @@ async def update_user(
9091
return await get_user_model_by_name_or_error(session=session, username=username)
9192

9293

93-
async def refresh_user_token(session: AsyncSession, username: str) -> Optional[UserModel]:
94+
async def refresh_user_token(
95+
session: AsyncSession,
96+
user: UserModel,
97+
username: str,
98+
) -> Optional[UserModel]:
99+
if user.global_role != GlobalRole.ADMIN and user.name != username:
100+
raise error_forbidden()
94101
await session.execute(
95-
update(UserModel).where(UserModel.name == username).values(token=uuid.uuid4())
102+
update(UserModel).where(UserModel.name == username).values(token=str(uuid.uuid4()))
96103
)
97104
await session.commit()
98105
return await get_user_model_by_name(session=session, username=username)

src/tests/_internal/server/routers/test_users.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,51 @@ async def test_deletes_users(self, test_db, session: AsyncSession):
172172
assert response.status_code == 200
173173
res = await session.execute(select(UserModel).where(UserModel.name == user.name))
174174
assert len(res.scalars().all()) == 0
175+
176+
177+
class TestRefreshToken:
178+
def test_returns_40x_if_not_authenticated(self):
179+
response = client.post("/api/users/refresh_token")
180+
assert response.status_code in [401, 403]
181+
182+
@pytest.mark.asyncio
183+
async def test_refreshes_token(self, test_db, session: AsyncSession):
184+
user1 = await create_user(name="user1", session=session)
185+
old_token = user1.token
186+
response = client.post(
187+
"/api/users/refresh_token",
188+
headers=get_auth_headers(user1.token),
189+
json={"username": user1.name},
190+
)
191+
assert response.status_code == 200
192+
assert response.json()["creds"]["token"] != old_token
193+
await session.refresh(user1)
194+
assert user1.token != old_token
195+
196+
@pytest.mark.asyncio
197+
async def test_returns_403_if_non_admin_refreshes_for_other_user(
198+
self, test_db, session: AsyncSession
199+
):
200+
user1 = await create_user(name="user1", session=session, global_role=GlobalRole.USER)
201+
user2 = await create_user(name="user2", session=session)
202+
response = client.post(
203+
"/api/users/refresh_token",
204+
headers=get_auth_headers(user1.token),
205+
json={"username": user2.name},
206+
)
207+
assert response.status_code == 403
208+
209+
@pytest.mark.asyncio
210+
async def test_global_admin_refreshes_token(self, test_db, session: AsyncSession):
211+
user1 = await create_user(name="user1", session=session, global_role=GlobalRole.ADMIN)
212+
user2 = await create_user(name="user2", session=session)
213+
old_token = user2.token
214+
response = client.post(
215+
"/api/users/refresh_token",
216+
headers=get_auth_headers(user1.token),
217+
json={"username": user2.name},
218+
)
219+
assert response.status_code == 200
220+
assert response.json()["creds"]["token"] != old_token
221+
await session.refresh(user2)
222+
assert user2.token != old_token

0 commit comments

Comments
 (0)