Skip to content

Commit 369f180

Browse files
SeungjinYangtvtv
authored
add request name to admin policy (#7840)
* add request name to admin policy * format * fix ut * plumb to userrequest * add ut * fix restful test * add example policy using new field * make request names available on sky * better example admin policies * controller launch * add reprs * format --------- Co-authored-by: tv <[email protected]> Co-authored-by: tv <[email protected]>
1 parent a9b31ca commit 369f180

File tree

18 files changed

+257
-51
lines changed

18 files changed

+257
-51
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
admin_policy: example_policy.AddLabelsConditionalPolicy

examples/admin_policy/example_policy/example_policy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Example admin policy module and prebuilt policies."""
22
from example_policy.client_policy import UseLocalGcpCredentialsPolicy
3+
from example_policy.skypilot_policy import AddLabelsConditionalPolicy
34
from example_policy.skypilot_policy import AddLabelsPolicy
45
from example_policy.skypilot_policy import AddVolumesPolicy
56
from example_policy.skypilot_policy import DisablePublicIpPolicy

examples/admin_policy/example_policy/example_policy/skypilot_policy.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,27 @@ def validate_and_mutate(
4242
return sky.MutatedUserRequest(user_request.task, config)
4343

4444

45+
class AddLabelsConditionalPolicy(sky.AdminPolicy):
46+
"""Example policy: adds a kubernetes label for skypilot_config
47+
unless the request is a validate or optimize request."""
48+
49+
@classmethod
50+
def validate_and_mutate(
51+
cls, user_request: sky.UserRequest) -> sky.MutatedUserRequest:
52+
if user_request.request_name in [
53+
sky.AdminPolicyRequestName.VALIDATE,
54+
sky.AdminPolicyRequestName.OPTIMIZE
55+
]:
56+
return sky.MutatedUserRequest(user_request.task,
57+
user_request.skypilot_config)
58+
config = user_request.skypilot_config
59+
labels = config.get_nested(('kubernetes', 'custom_metadata', 'labels'),
60+
{})
61+
labels['app'] = 'skypilot'
62+
config.set_nested(('kubernetes', 'custom_metadata', 'labels'), labels)
63+
return sky.MutatedUserRequest(user_request.task, config)
64+
65+
4566
class DisablePublicIpPolicy(sky.AdminPolicy):
4667
"""Example policy: disables public IP for all AWS tasks."""
4768

@@ -94,15 +115,17 @@ def validate_and_mutate(
94115
policy is applied, we should expect a few seconds latency when a user
95116
run a request.
96117
"""
97-
request_options = user_request.request_options
98-
99-
# Request options is None when a task is executed with `jobs launch` or
100-
# `sky serve up`.
101-
if request_options is None:
118+
if user_request.request_name not in [
119+
sky.AdminPolicyRequestName.CLUSTER_LAUNCH,
120+
sky.AdminPolicyRequestName.CLUSTER_EXEC,
121+
]:
102122
return sky.MutatedUserRequest(
103123
task=user_request.task,
104124
skypilot_config=user_request.skypilot_config)
105125

126+
request_options = user_request.request_options
127+
# Request options is not None for cluster launch and cluster exec requests.
128+
assert request_options is not None
106129
# Get the cluster record to operate on.
107130
cluster_name = request_options.cluster_name
108131
cluster_records: List[responses.StatusResponse] = []

sky/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
122122
from sky.jobs import ManagedJobStatus
123123
from sky.optimizer import Optimizer
124124
from sky.resources import Resources
125+
from sky.server.requests.request_names import AdminPolicyRequestName
125126
from sky.skylet.job_lib import JobStatus
126127
from sky.task import Task
127128
from sky.utils.common import OptimizeTarget
@@ -228,6 +229,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
228229
'MutatedUserRequest',
229230
'AdminPolicy',
230231
'Config',
232+
'AdminPolicyRequestName',
231233
# Registry
232234
'CLOUD_REGISTRY',
233235
'JOBS_RECOVERY_STRATEGY_REGISTRY',

sky/admin_policy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sky import exceptions
1212
from sky import models
1313
from sky.adaptors import common as adaptors_common
14+
from sky.server.requests import request_names
1415
from sky.utils import common_utils
1516
from sky.utils import config_utils
1617
from sky.utils import ux_utils
@@ -50,6 +51,7 @@ class _UserRequestBody(pydantic.BaseModel):
5051
# will be converted to JSON string, which will lose the None key.
5152
task: str
5253
skypilot_config: str
54+
request_name: str
5355
request_options: Optional[RequestOptions] = None
5456
at_client_side: bool = False
5557
user: str
@@ -81,6 +83,7 @@ class UserRequest:
8183
"""
8284
task: 'sky.Task'
8385
skypilot_config: 'sky.Config'
86+
request_name: request_names.AdminPolicyRequestName
8487
request_options: Optional['RequestOptions'] = None
8588
at_client_side: bool = False
8689
user: Optional['models.User'] = None
@@ -90,6 +93,7 @@ def encode(self) -> str:
9093
task=yaml_utils.dump_yaml_str(self.task.to_yaml_config()),
9194
skypilot_config=yaml_utils.dump_yaml_str(dict(
9295
self.skypilot_config)),
96+
request_name=self.request_name.value,
9397
request_options=self.request_options,
9498
at_client_side=self.at_client_side,
9599
user=(yaml_utils.dump_yaml_str(self.user.to_dict())
@@ -110,6 +114,8 @@ def decode(cls, body: str) -> 'UserRequest':
110114
skypilot_config=config_utils.Config.from_dict(
111115
yaml_utils.read_yaml_all_str(
112116
user_request_body.skypilot_config)[0]),
117+
request_name=request_names.AdminPolicyRequestName(
118+
user_request_body.request_name),
113119
request_options=user_request_body.request_options,
114120
at_client_side=user_request_body.at_client_side,
115121
user=user,

sky/client/sdk.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from sky.server import rest
3838
from sky.server import versions
3939
from sky.server.requests import payloads
40+
from sky.server.requests import request_names
4041
from sky.server.requests import requests as requests_lib
4142
from sky.skylet import autostop_lib
4243
from sky.skylet import constants
@@ -603,7 +604,10 @@ def launch(
603604
down=down,
604605
dryrun=dryrun)
605606
with admin_policy_utils.apply_and_use_config_in_current_request(
606-
dag, request_options=request_options, at_client_side=True) as dag:
607+
dag,
608+
request_name=request_names.AdminPolicyRequestName.CLUSTER_LAUNCH,
609+
request_options=request_options,
610+
at_client_side=True) as dag:
607611
return _launch(
608612
dag,
609613
cluster_name,

sky/core.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sky.provision.kubernetes import constants as kubernetes_constants
2626
from sky.provision.kubernetes import utils as kubernetes_utils
2727
from sky.schemas.api import responses
28+
from sky.server.requests import request_names
2829
from sky.skylet import autostop_lib
2930
from sky.skylet import constants
3031
from sky.skylet import job_lib
@@ -84,7 +85,9 @@ def optimize(
8485
# but we do not apply the admin policy there. We should apply the admin
8586
# policy in the optimizer, but that will require some refactoring.
8687
with admin_policy_utils.apply_and_use_config_in_current_request(
87-
dag, request_options=request_options) as dag:
88+
dag,
89+
request_name=request_names.AdminPolicyRequestName.OPTIMIZE,
90+
request_options=request_options) as dag:
8891
dag.resolve_and_validate_volumes()
8992
return optimizer.Optimizer.optimize(dag=dag,
9093
minimize=minimize,

sky/execution.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from sky import optimizer
1717
from sky import sky_logging
1818
from sky.backends import backend_utils
19+
from sky.server.requests import request_names
1920
from sky.skylet import autostop_lib
2021
from sky.usage import usage_lib
2122
from sky.utils import admin_policy_utils
@@ -116,8 +117,10 @@ def _execute(
116117
no_setup: bool = False,
117118
clone_disk_from: Optional[str] = None,
118119
skip_unnecessary_provisioning: bool = False,
120+
*, #keyword only separator
119121
# Internal only:
120122
# pylint: disable=invalid-name
123+
_request_name: request_names.AdminPolicyRequestName,
121124
_quiet_optimizer: bool = False,
122125
_is_launched_by_jobs_controller: bool = False,
123126
_is_launched_by_sky_serve_controller: bool = False,
@@ -187,6 +190,7 @@ def _execute(
187190
idle_minutes_to_autostop = resource.autostop_config.idle_minutes
188191
with admin_policy_utils.apply_and_use_config_in_current_request(
189192
dag,
193+
request_name=_request_name,
190194
request_options=admin_policy.RequestOptions(
191195
cluster_name=cluster_name,
192196
idle_minutes_to_autostop=idle_minutes_to_autostop,
@@ -519,6 +523,51 @@ def _planner(_t: 'sky.Task'):
519523

520524
@timeline.event
521525
@usage_lib.entrypoint
526+
def cluster_launch(
527+
task: Union['sky.Task', 'sky.Dag'],
528+
cluster_name: Optional[str] = None,
529+
retry_until_up: bool = False,
530+
idle_minutes_to_autostop: Optional[int] = None,
531+
dryrun: bool = False,
532+
down: bool = False,
533+
stream_logs: bool = True,
534+
backend: Optional[backends.Backend] = None,
535+
optimize_target: common.OptimizeTarget = common.OptimizeTarget.COST,
536+
no_setup: bool = False,
537+
clone_disk_from: Optional[str] = None,
538+
fast: bool = False,
539+
*, #keyword only separator
540+
# Internal only:
541+
# pylint: disable=invalid-name
542+
_quiet_optimizer: bool = False,
543+
_is_launched_by_jobs_controller: bool = False,
544+
_is_launched_by_sky_serve_controller: bool = False,
545+
_disable_controller_check: bool = False,
546+
job_logger: logging.Logger = logger,
547+
) -> Tuple[Optional[int], Optional[backends.ResourceHandle]]:
548+
return launch(
549+
task=task,
550+
cluster_name=cluster_name,
551+
retry_until_up=retry_until_up,
552+
idle_minutes_to_autostop=idle_minutes_to_autostop,
553+
dryrun=dryrun,
554+
down=down,
555+
stream_logs=stream_logs,
556+
backend=backend,
557+
optimize_target=optimize_target,
558+
no_setup=no_setup,
559+
clone_disk_from=clone_disk_from,
560+
fast=fast,
561+
_quiet_optimizer=_quiet_optimizer,
562+
_is_launched_by_jobs_controller=_is_launched_by_jobs_controller,
563+
_is_launched_by_sky_serve_controller=
564+
_is_launched_by_sky_serve_controller,
565+
_disable_controller_check=_disable_controller_check,
566+
_request_name=request_names.AdminPolicyRequestName.CLUSTER_LAUNCH,
567+
job_logger=job_logger,
568+
)
569+
570+
522571
# A launch routine will share tempfiles between steps, so we init a tempdir
523572
# for the launch routine and gc the entire dir after launch.
524573
@tempstore.with_tempdir
@@ -535,12 +584,14 @@ def launch(
535584
no_setup: bool = False,
536585
clone_disk_from: Optional[str] = None,
537586
fast: bool = False,
587+
*, #keyword only separator
538588
# Internal only:
539589
# pylint: disable=invalid-name
540590
_quiet_optimizer: bool = False,
541591
_is_launched_by_jobs_controller: bool = False,
542592
_is_launched_by_sky_serve_controller: bool = False,
543593
_disable_controller_check: bool = False,
594+
_request_name: request_names.AdminPolicyRequestName,
544595
job_logger: logging.Logger = logger,
545596
) -> Tuple[Optional[int], Optional[backends.ResourceHandle]]:
546597
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
@@ -707,6 +758,7 @@ def launch(
707758
_is_launched_by_jobs_controller=_is_launched_by_jobs_controller,
708759
_is_launched_by_sky_serve_controller=
709760
_is_launched_by_sky_serve_controller,
761+
_request_name=_request_name,
710762
job_logger=job_logger)
711763

712764

@@ -794,4 +846,5 @@ def exec( # pylint: disable=redefined-builtin
794846
],
795847
cluster_name=cluster_name,
796848
job_logger=job_logger,
849+
_request_name=request_names.AdminPolicyRequestName.CLUSTER_EXEC,
797850
)

sky/jobs/client/sdk.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sky.server import rest
1616
from sky.server import versions
1717
from sky.server.requests import payloads
18+
from sky.server.requests import request_names
1819
from sky.skylet import constants
1920
from sky.usage import usage_lib
2021
from sky.utils import admin_policy_utils
@@ -84,7 +85,9 @@ def launch(
8485

8586
dag = dag_utils.convert_entrypoint_to_dag(task)
8687
with admin_policy_utils.apply_and_use_config_in_current_request(
87-
dag, at_client_side=True) as dag:
88+
dag,
89+
request_name=request_names.AdminPolicyRequestName.JOBS_LAUNCH,
90+
at_client_side=True) as dag:
8891
sdk.validate(dag)
8992
if _need_confirmation:
9093
job_identity = 'a managed job'

sky/jobs/server/core.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from sky.serve import serve_state
3636
from sky.serve import serve_utils
3737
from sky.serve.server import impl
38+
from sky.server.requests import request_names
3839
from sky.skylet import constants as skylet_constants
3940
from sky.usage import usage_lib
4041
from sky.utils import admin_policy_utils
@@ -237,7 +238,8 @@ def launch(
237238
# Always apply the policy again here, even though it might have been applied
238239
# in the CLI. This is to ensure that we apply the policy to the final DAG
239240
# and get the mutated config.
240-
dag, mutated_user_config = admin_policy_utils.apply(dag)
241+
dag, mutated_user_config = admin_policy_utils.apply(
242+
dag, request_name=request_names.AdminPolicyRequestName.JOBS_LAUNCH)
241243
dag.resolve_and_validate_volumes()
242244
if not dag.is_chain():
243245
with ux_utils.print_exception_no_traceback():
@@ -465,12 +467,15 @@ def _submit_one(
465467
# intermediate bucket and newly created bucket should be in
466468
# workspace A.
467469
if consolidation_mode_job_id is None:
468-
return execution.launch(task=controller_task,
469-
cluster_name=controller_name,
470-
stream_logs=stream_logs,
471-
retry_until_up=True,
472-
fast=True,
473-
_disable_controller_check=True)
470+
return execution.launch(
471+
task=controller_task,
472+
cluster_name=controller_name,
473+
stream_logs=stream_logs,
474+
retry_until_up=True,
475+
fast=True,
476+
_request_name=request_names.AdminPolicyRequestName.
477+
JOBS_LAUNCH_CONTROLLER,
478+
_disable_controller_check=True)
474479
# Manually launch the scheduler in consolidation mode.
475480
local_handle = backend_utils.is_controller_accessible(
476481
controller=controller, stopped_message='')

0 commit comments

Comments
 (0)