Skip to content

Commit a9b31ca

Browse files
SeungjinYangtv
andauthored
add user data to admin policy (#7800)
* add user info * encode/decode test * another ut * add comment * fix syntax warning --------- Co-authored-by: tv <[email protected]>
1 parent 16cfa24 commit a9b31ca

File tree

4 files changed

+79
-1
lines changed

4 files changed

+79
-1
lines changed

sky/admin_policy.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import sky
1111
from sky import exceptions
12+
from sky import models
1213
from sky.adaptors import common as adaptors_common
1314
from sky.utils import common_utils
1415
from sky.utils import config_utils
@@ -51,6 +52,7 @@ class _UserRequestBody(pydantic.BaseModel):
5152
skypilot_config: str
5253
request_options: Optional[RequestOptions] = None
5354
at_client_side: bool = False
55+
user: str
5456

5557

5658
@dataclasses.dataclass
@@ -73,11 +75,15 @@ class UserRequest:
7375
skypilot_config: Global skypilot config to be used in this request.
7476
request_options: Request options. It is None for jobs and services.
7577
at_client_side: Is the request intercepted by the policy at client-side?
78+
user: User who made the request.
79+
Only available on the server side.
80+
This value is None if at_client_side is True.
7681
"""
7782
task: 'sky.Task'
7883
skypilot_config: 'sky.Config'
7984
request_options: Optional['RequestOptions'] = None
8085
at_client_side: bool = False
86+
user: Optional['models.User'] = None
8187

8288
def encode(self) -> str:
8389
return _UserRequestBody(
@@ -86,11 +92,18 @@ def encode(self) -> str:
8692
self.skypilot_config)),
8793
request_options=self.request_options,
8894
at_client_side=self.at_client_side,
95+
user=(yaml_utils.dump_yaml_str(self.user.to_dict())
96+
if self.user is not None else ''),
8997
).model_dump_json()
9098

9199
@classmethod
92100
def decode(cls, body: str) -> 'UserRequest':
93101
user_request_body = _UserRequestBody.model_validate_json(body)
102+
user_dict = yaml_utils.read_yaml_str(
103+
user_request_body.user) if user_request_body.user != '' else None
104+
user = models.User(
105+
id=user_dict['id'],
106+
name=user_dict['name']) if user_dict is not None else None
94107
return cls(
95108
task=sky.Task.from_yaml_config(
96109
yaml_utils.read_yaml_all_str(user_request_body.task)[0]),
@@ -99,6 +112,7 @@ def decode(cls, body: str) -> 'UserRequest':
99112
user_request_body.skypilot_config)[0]),
100113
request_options=user_request_body.request_options,
101114
at_client_side=user_request_body.at_client_side,
115+
user=user,
102116
)
103117

104118

sky/utils/admin_policy_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import contextlib
33
import copy
44
import importlib
5+
import typing
56
from typing import Iterator, Optional, Tuple, Union
67
import urllib.parse
78

@@ -19,6 +20,9 @@
1920

2021
logger = sky_logging.init_logger(__name__)
2122

23+
if typing.TYPE_CHECKING:
24+
from sky import models
25+
2226

2327
def _is_url(policy_string: str) -> bool:
2428
"""Check if the policy string is a URL."""
@@ -126,9 +130,13 @@ def apply(
126130
if policy is None:
127131
return dag, skypilot_config.to_dict()
128132

133+
user = None
129134
if at_client_side:
130135
logger.info(f'Applying client admin policy: {policy}')
131136
else:
137+
# When being called by the server, the middleware has set the
138+
# current user and this information is available at this point.
139+
user = common_utils.get_current_user()
132140
logger.info(f'Applying server admin policy: {policy}')
133141
config = copy.deepcopy(skypilot_config.to_dict())
134142
mutated_dag = dag_lib.Dag()
@@ -137,7 +145,7 @@ def apply(
137145
mutated_config = None
138146
for task in dag.tasks:
139147
user_request = admin_policy.UserRequest(task, config, request_options,
140-
at_client_side)
148+
at_client_side, user)
141149
try:
142150
mutated_user_request = policy.apply(user_request)
143151
# Avoid duplicate exception wrapping.

tests/unit_tests/test_admin_policy.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import sky
1616
from sky import exceptions
17+
from sky import models
1718
from sky import sky_logging
1819
from sky import skypilot_config
1920
from sky.utils import admin_policy_utils
@@ -379,6 +380,21 @@ def test_use_local_gcp_credentials_policy(add_example_policy_paths, task):
379380
policy.apply(user_request)
380381

381382

383+
def test_user_request_encode_decode(task):
384+
with mock.patch('sky.utils.common_utils.get_current_user',
385+
return_value=models.User(id='123', name='test')):
386+
user_request = sky.UserRequest(task=task,
387+
skypilot_config=sky.Config(),
388+
at_client_side=False,
389+
user=models.User(id='123', name='test'))
390+
encoded_request = user_request.encode()
391+
decoded_request = sky.UserRequest.decode(encoded_request)
392+
assert repr(decoded_request.task) == repr(task)
393+
assert decoded_request.skypilot_config == sky.Config()
394+
assert decoded_request.at_client_side == False
395+
assert decoded_request.user == models.User(id='123', name='test')
396+
397+
382398
def test_restful_policy(add_example_policy_paths, task):
383399
"""Test RestfulAdminPolicy for various scenarios."""
384400

tests/unit_tests/test_admin_policy_restful.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import threading
2020
import time
2121
from typing import Optional, Tuple
22+
from unittest import mock
2223

2324
from fastapi import FastAPI
2425
from fastapi import Request
@@ -28,6 +29,7 @@
2829

2930
import sky
3031
from sky import admin_policy
32+
from sky import models
3133
from sky import skypilot_config
3234
from sky.utils import admin_policy_utils
3335
from sky.utils import common_utils
@@ -371,6 +373,44 @@ def test_restful_policy_with_request_options(monkeypatch):
371373
os.unlink(config_path)
372374

373375

376+
def test_restful_policy_with_user(monkeypatch):
377+
"""Test RESTful admin policy receiving user information."""
378+
with mock.patch('sky.utils.common_utils.get_current_user',
379+
return_value=models.User(id='123', name='test')):
380+
with PolicyServer() as server:
381+
ImageIdInspectorPolicy.received_requests.clear()
382+
383+
# Create a test task
384+
task = create_test_task()
385+
386+
# Create temporary config and apply policy using existing function
387+
with tempfile.NamedTemporaryFile(mode='w',
388+
suffix='.yaml',
389+
delete=False) as f:
390+
f.write(f'admin_policy: http://127.0.0.1:{server.port}\n')
391+
config_path = f.name
392+
393+
try:
394+
dag, config = _load_task_and_apply_policy(
395+
task, config_path, monkeypatch)
396+
397+
# Verify the policy was called with proper request structure
398+
assert len(ImageIdInspectorPolicy.received_requests) == 1
399+
request = ImageIdInspectorPolicy.received_requests[0]
400+
401+
# Check that user information was properly included
402+
assert request.user is not None
403+
assert request.user.id == '123'
404+
assert request.user.name == 'test'
405+
406+
# Check that we got valid results back
407+
assert dag is not None
408+
assert len(dag.tasks) == 1
409+
assert config is not None
410+
finally:
411+
os.unlink(config_path)
412+
413+
374414
def test_restful_policy_basic_functionality(monkeypatch):
375415
"""Test basic RESTful admin policy functionality using real patterns."""
376416
with PolicyServer() as server:

0 commit comments

Comments
 (0)