diff --git a/helm_chart/HyperPodHelmChart/charts/inference-operator/Chart.yaml b/helm_chart/HyperPodHelmChart/charts/inference-operator/Chart.yaml index 3717fd6c..7b4671c2 100644 --- a/helm_chart/HyperPodHelmChart/charts/inference-operator/Chart.yaml +++ b/helm_chart/HyperPodHelmChart/charts/inference-operator/Chart.yaml @@ -15,13 +15,11 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 1.0.0 +version: 1.1.0 -# This is the version number of the application being deployed. This version number should be -# incremented each time you make changes to the application. Versions are not expected to -# follow Semantic Versioning. They should reflect the version the application is using. -# It is recommended to use it with quotes. -appVersion: "2.0" +# This is the version number of the application being deployed. Keep this aligned +# with operator image MAJOR.MINOR version. +appVersion: "2.1" dependencies: - name: aws-mountpoint-s3-csi-driver diff --git a/helm_chart/HyperPodHelmChart/charts/inference-operator/config/crd/inference.sagemaker.aws.amazon.com_inferenceendpointconfigs.yaml b/helm_chart/HyperPodHelmChart/charts/inference-operator/config/crd/inference.sagemaker.aws.amazon.com_inferenceendpointconfigs.yaml index 7f43c89a..7616f134 100644 --- a/helm_chart/HyperPodHelmChart/charts/inference-operator/config/crd/inference.sagemaker.aws.amazon.com_inferenceendpointconfigs.yaml +++ b/helm_chart/HyperPodHelmChart/charts/inference-operator/config/crd/inference.sagemaker.aws.amazon.com_inferenceendpointconfigs.yaml @@ -696,7 +696,7 @@ spec: l2CacheBackend: description: L2 cache backend type. Required when L2CacheSpec is provided. 
- pattern: (?i)redis + pattern: (?i)redis|tieredstorage type: string l2CacheLocalUrl: description: Provide the L2 cache URL to local storage @@ -721,6 +721,12 @@ spec: - round_robin type: string type: object + maxDeployTimeInSeconds: + default: 3600 + description: Maximum allowed time in seconds for the deployment to + complete before timing out. Defaults to 1 hour (3600 seconds) + format: int32 + type: integer metrics: description: Configuration for metrics collection and exposure properties: @@ -1617,12 +1623,6 @@ spec: - round_robin type: string type: object - maxDeployTimeInSeconds: - default: 3600 - description: Maximum allowed time in seconds for the deployment to - complete before timing out. Defaults to 1 hour (3600 seconds) - format: int32 - type: integer metrics: description: Configuration for metrics collection and exposure properties: diff --git a/helm_chart/HyperPodHelmChart/charts/inference-operator/config/crd/inference.sagemaker.aws.amazon.com_jumpstartmodels.yaml b/helm_chart/HyperPodHelmChart/charts/inference-operator/config/crd/inference.sagemaker.aws.amazon.com_jumpstartmodels.yaml index 68ea257e..4e1b5443 100644 --- a/helm_chart/HyperPodHelmChart/charts/inference-operator/config/crd/inference.sagemaker.aws.amazon.com_jumpstartmodels.yaml +++ b/helm_chart/HyperPodHelmChart/charts/inference-operator/config/crd/inference.sagemaker.aws.amazon.com_jumpstartmodels.yaml @@ -350,6 +350,349 @@ spec: type: object maxItems: 100 type: array + intelligentRoutingSpec: + description: |- + Configuration for intelligent routing + This feature is currently not supported for existing deployments. + Adding this configuration to an existing deployment will be rejected. + properties: + autoScalingSpec: + properties: + cloudWatchTrigger: + description: CloudWatch metric trigger to use for autoscaling + properties: + activationTargetValue: + default: 0 + description: Activation Value for CloudWatch metric to + scale from 0 to 1. 
Only applicable if minReplicaCount + = 0 + type: number + dimensions: + description: Dimensions for Cloudwatch metrics + items: + properties: + name: + description: CloudWatch Metric dimension name + type: string + value: + description: CloudWatch Metric dimension value + type: string + required: + - name + - value + type: object + type: array + metricCollectionPeriod: + default: 300 + description: Defines the Period for CloudWatch query + format: int32 + type: integer + metricCollectionStartTime: + default: 300 + description: Defines the StartTime for CloudWatch query + format: int32 + type: integer + metricName: + description: Metric name to query for Cloudwatch trigger + type: string + metricStat: + default: Average + description: Statistics metric to be used by Trigger. + Used to define Stat for CloudWatch query. Default is + Average. + type: string + metricType: + default: Average + description: 'The type of metric to be used by HPA. Enum: + AverageValue - Uses average value of metric per pod, + Value - Uses absolute metric value' + enum: + - Value + - Average + type: string + minValue: + default: 0 + description: Minimum metric value used in case of empty + response from CloudWatch. Default is 0. + type: number + name: + description: Name for the CloudWatch trigger + type: string + namespace: + description: AWS CloudWatch namespace for metric + type: string + targetValue: + description: TargetValue for CloudWatch metric + type: number + useCachedMetrics: + default: true + description: Enable caching of metric values during polling + interval. Default is true + type: boolean + type: object + cloudWatchTriggerList: + description: Multiple CloudWatch metric triggers to use for + autoscaling. Takes priority over CloudWatchTrigger if both + are provided. + items: + properties: + activationTargetValue: + default: 0 + description: Activation Value for CloudWatch metric + to scale from 0 to 1. 
Only applicable if minReplicaCount + = 0 + type: number + dimensions: + description: Dimensions for Cloudwatch metrics + items: + properties: + name: + description: CloudWatch Metric dimension name + type: string + value: + description: CloudWatch Metric dimension value + type: string + required: + - name + - value + type: object + type: array + metricCollectionPeriod: + default: 300 + description: Defines the Period for CloudWatch query + format: int32 + type: integer + metricCollectionStartTime: + default: 300 + description: Defines the StartTime for CloudWatch query + format: int32 + type: integer + metricName: + description: Metric name to query for Cloudwatch trigger + type: string + metricStat: + default: Average + description: Statistics metric to be used by Trigger. + Used to define Stat for CloudWatch query. Default + is Average. + type: string + metricType: + default: Average + description: 'The type of metric to be used by HPA. + Enum: AverageValue - Uses average value of metric + per pod, Value - Uses absolute metric value' + enum: + - Value + - Average + type: string + minValue: + default: 0 + description: Minimum metric value used in case of empty + response from CloudWatch. Default is 0. + type: number + name: + description: Name for the CloudWatch trigger + type: string + namespace: + description: AWS CloudWatch namespace for metric + type: string + targetValue: + description: TargetValue for CloudWatch metric + type: number + useCachedMetrics: + default: true + description: Enable caching of metric values during + polling interval. Default is true + type: boolean + type: object + maxItems: 100 + type: array + cooldownPeriod: + default: 300 + description: The period to wait after the last trigger reported + active before scaling the resource back to 0. Default 300 + seconds. 
+ format: int32 + minimum: 0 + type: integer + initialCooldownPeriod: + default: 300 + description: The delay before the cooldownPeriod starts after + the initial creation of the ScaledObject. Default 300 seconds. + format: int32 + minimum: 0 + type: integer + maxReplicaCount: + default: 5 + description: The maximum number of model pods to scale to. + Default 5. + format: int32 + minimum: 0 + type: integer + minReplicaCount: + default: 1 + description: The minimum number of model pods to scale down + to. Default 1. + format: int32 + minimum: 0 + type: integer + pollingInterval: + default: 30 + description: This is the interval to check each trigger on. + Default 30 seconds. + format: int32 + minimum: 0 + type: integer + prometheusTrigger: + description: Prometheus metric trigger to use for autoscaling + properties: + activationTargetValue: + default: 0 + description: Activation Value for Prometheus metric to + scale from 0 to 1. Only applicable if minReplicaCount + = 0 + type: number + customHeaders: + description: Custom headers to include while querying + the prometheus endpoint. + type: string + metricType: + default: Average + description: 'The type of metric to be used by HPA. Enum: + AverageValue - Uses average value of metric per pod, + Value - Uses absolute metric value' + enum: + - Value + - Average + type: string + name: + description: Name for the Prometheus trigger + type: string + namespace: + description: Namespace for namespaced queries + type: string + query: + description: PromQLQuery for the metric. + type: string + serverAddress: + description: Server address for AMP workspace + pattern: ^https:\/\/aps-workspaces\.[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*\.amazonaws\.com\/workspaces\/ws-[a-zA-Z0-9-]+\/[a-zA-Z0-9-]+$|^$ + type: string + targetValue: + description: Target metric value for scaling + type: number + useCachedMetrics: + default: true + description: Enable caching of metric values during polling + interval. 
Default is true + type: boolean + type: object + prometheusTriggerList: + description: Multiple Prometheus metric triggers to use for + autoscaling. Takes priority over PrometheusTrigger if both + are provided. + items: + properties: + activationTargetValue: + default: 0 + description: Activation Value for Prometheus metric + to scale from 0 to 1. Only applicable if minReplicaCount + = 0 + type: number + customHeaders: + description: Custom headers to include while querying + the prometheus endpoint. + type: string + metricType: + default: Average + description: 'The type of metric to be used by HPA. + Enum: AverageValue - Uses average value of metric + per pod, Value - Uses absolute metric value' + enum: + - Value + - Average + type: string + name: + description: Name for the Prometheus trigger + type: string + namespace: + description: Namespace for namespaced queries + type: string + query: + description: PromQLQuery for the metric. + type: string + serverAddress: + description: Server address for AMP workspace + pattern: ^https:\/\/aps-workspaces\.[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*\.amazonaws\.com\/workspaces\/ws-[a-zA-Z0-9-]+\/[a-zA-Z0-9-]+$|^$ + type: string + targetValue: + description: Target metric value for scaling + type: number + useCachedMetrics: + default: true + description: Enable caching of metric values during + polling interval. Default is true + type: boolean + type: object + maxItems: 100 + type: array + scaleDownStabilizationTime: + default: 300 + description: The time window to stabilize for HPA before scaling + down. Default 300 seconds. + format: int32 + minimum: 0 + type: integer + scaleUpStabilizationTime: + default: 0 + description: The time window to stabilize for HPA before scaling + up. Default 0 seconds. 
+ format: int32 + minimum: 0 + type: integer + type: object + enabled: + default: false + description: Once set, the enabled field cannot be modified + type: boolean + routingStrategy: + default: prefixaware + enum: + - prefixaware + - kvaware + - session + - roundrobin + type: string + type: object + kvCacheSpec: + description: |- + Configuration for KV Cache specification + By default L1CacheOffloading will be enabled + properties: + cacheConfigFile: + description: KVCache configuration file path. If specified, override + other configurations provided via spec + type: string + enableL1Cache: + default: true + description: Enable CPU offloading + type: boolean + enableL2Cache: + default: false + type: boolean + l2CacheSpec: + description: Configuration for providing L2 Cache offloading + properties: + l2CacheBackend: + description: L2 cache backend type. Required when L2CacheSpec + is provided. + pattern: (?i)redis|tieredstorage + type: string + l2CacheLocalUrl: + description: Provide the L2 cache URL to local storage + type: string + type: object + type: object loadBalancer: description: Configuration for Application Load Balancer properties: @@ -477,6 +820,10 @@ spec: type: object server: properties: + acceleratorPartitionType: + description: MIG profile to use for GPU partitioning + pattern: ^mig-.*$ + type: string executionRole: description: The Amazon Resource Name (ARN) of an IAM role that will be used to deploy and manage the inference server @@ -489,6 +836,15 @@ spec: Must be one of the supported types. pattern: ^ml\..* type: string + validations: + description: Validations configuration for the server + properties: + acceleratorPartitionValidation: + default: true + description: Enable MIG validation for GPU partitioning. Default + is true. 
+ type: boolean + type: object required: - instanceType type: object diff --git a/helm_chart/HyperPodHelmChart/charts/inference-operator/config/manager/manager.yaml b/helm_chart/HyperPodHelmChart/charts/inference-operator/config/manager/manager.yaml index 9fe34cdb..24075cef 100644 --- a/helm_chart/HyperPodHelmChart/charts/inference-operator/config/manager/manager.yaml +++ b/helm_chart/HyperPodHelmChart/charts/inference-operator/config/manager/manager.yaml @@ -48,6 +48,94 @@ spec: # versions < 1.19 or on vendors versions which do NOT support this field by default (i.e. Openshift < 4.11 ). # seccompProfile: # type: RuntimeDefault + initContainers: + - command: + - bash + - -lc + - | + set -euo pipefail + KUBECTL="$(command -v kubectl || true)" + if [ -z "${KUBECTL}" ]; then + for p in /opt/bitnami/kubectl/bin/kubectl /usr/local/bin/kubectl /usr/bin/kubectl /bin/kubectl; do + if [ -x "$p" ]; then KUBECTL="$p"; break; fi + done + fi + if [ -z "${KUBECTL}" ]; then + echo "kubectl not found in PATH or common locations" > /dev/termination-log + exit 2 + fi + + CHECKS="${CHECKS:-drivers crds}" + + log() { echo "$1" > /dev/termination-log; } + + require_csidriver() { + local provisioner="$1" + local friendly="$2" + + # Try with error capture so we can disambiguate RBAC vs missing + if "${KUBECTL}" get csidriver "$provisioner" >/dev/null 2>&1 || \ + "${KUBECTL}" get csidrivers.storage.k8s.io "$provisioner" >/dev/null 2>&1; then + return 0 + fi + + + # Final attempt to capture the real error + err_msg="$("${KUBECTL}" get csidriver "$provisioner" 2>&1 || true)" + [ -z "$err_msg" ] && err_msg="$("${KUBECTL}" get csidrivers.storage.k8s.io "$provisioner" 2>&1 || true)" + + if echo "$err_msg" | grep -qiE 'forbidden|permission|unauthorized|cannot.*get'; then + log "$friendly check failed: RBAC insufficient to read CSIDriver $provisioner. "${KUBECTL}" said: ${err_msg}" + exit 2 + fi + + log "$friendly not installed (missing CSIDriver $provisioner). 
kubectl said: ${err_msg}" + exit 1 + } + + require_crd() { + local crd="$1" + # Same idea: attempt and parse error text + if "${KUBECTL}" get crd "$crd" >/dev/null 2>&1; then + return 0 + fi + err="$("${KUBECTL}" get crd "$crd" 2>&1 || true)" + if echo "$err" | grep -qiE 'forbidden|permission|unauthorized|cannot.*get'; then + log "CRD check failed: RBAC insufficient to read $crd. "${KUBECTL}" said: ${err}" + exit 2 + fi + log "Missing required CRD: $crd. "${KUBECTL}" said: ${err}" + exit 1 + } + + # Dispatch selected checks + for c in $CHECKS; do + case "$c" in + drivers) + require_csidriver "s3.csi.aws.com" "S3 CSI driver" + require_csidriver "fsx.csi.aws.com" "FSx CSI driver" + ;; + crds) + require_crd "certificaterequests.cert-manager.io" "cert-manager CRD" + require_crd "certificates.cert-manager.io" "cert-manager CRD" + ;; + *) + log "Unknown check: $c" + exit 1 + ;; + esac + done + + log "Checks passed: $CHECKS" + exit 0 + env: + - name: CHECKS + value: "drivers crds" + image: "public.ecr.aws/bitnami/kubectl:1.30" + imagePullPolicy: Always + name: check-csi-drivers + resources: { } + terminationMessagePath: /dev/termination-log containers: - command: - /hyperpod-inference-manager @@ -93,7 +181,7 @@ spec: resources: limits: cpu: 500m - memory: 128Mi + memory: 256Mi requests: cpu: 10m memory: 64Mi @@ -125,4 +213,4 @@ spec: volumes: - name: webhook-certs secret: - secretName: webhook-server-cert \ No newline at end of file + secretName: webhook-server-cert diff --git a/helm_chart/HyperPodHelmChart/charts/inference-operator/values.yaml b/helm_chart/HyperPodHelmChart/charts/inference-operator/values.yaml index 868b7765..878fb183 100644 --- a/helm_chart/HyperPodHelmChart/charts/inference-operator/values.yaml +++ b/helm_chart/HyperPodHelmChart/charts/inference-operator/values.yaml @@ -21,7 +21,7 @@ image: ap-southeast-4: 311141544681.dkr.ecr.ap-southeast-4.amazonaws.com ap-southeast-3: 158128612970.dkr.ecr.ap-southeast-3.amazonaws.com eu-south-2: 
025050981094.dkr.ecr.eu-south-2.amazonaws.com - tag: v2.0 + tag: v2.1 pullPolicy: Always repository: hyperpodClusterArn: diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/registry.py b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/registry.py index d1abfdea..96b80a47 100644 --- a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/registry.py +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/registry.py @@ -10,13 +10,17 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from hyperpod_jumpstart_inference_template.v1_0 import model as v1 -from hyperpod_jumpstart_inference_template.v1_0.template import TEMPLATE_CONTENT as v1_template +from hyperpod_jumpstart_inference_template.v1_0 import model as v1_0 +from hyperpod_jumpstart_inference_template.v1_1 import model as v1_1 +from hyperpod_jumpstart_inference_template.v1_0.template import TEMPLATE_CONTENT as v1_0_template +from hyperpod_jumpstart_inference_template.v1_1.template import TEMPLATE_CONTENT as v1_1_template SCHEMA_REGISTRY = { - "1.0": v1.FlatHPJumpStartEndpoint, + "1.0": v1_0.FlatHPJumpStartEndpoint, + "1.1": v1_1.FlatHPJumpStartEndpoint, } TEMPLATE_REGISTRY = { - "1.0": v1_template + "1.0": v1_0_template, + "1.1": v1_1_template, } diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/__init__.py b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/__init__.py new file mode 100644 index 00000000..68054b98 --- /dev/null +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/__init__.py @@ -0,0 +1,12 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. \ No newline at end of file diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/model.py b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/model.py new file mode 100644 index 00000000..3b428f13 --- /dev/null +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/model.py @@ -0,0 +1,136 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from pydantic import BaseModel, Field, model_validator, ConfigDict +from typing import Optional + +# reuse the nested types +from sagemaker.hyperpod.inference.config.hp_jumpstart_endpoint_config import ( + Model, + SageMakerEndpoint, + Server, + TlsConfig, + Validations, +) +from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint +from sagemaker.hyperpod.common.config.metadata import Metadata + + +class FlatHPJumpStartEndpoint(BaseModel): + model_config = ConfigDict(extra="forbid") + + namespace: Optional[str] = Field( + default=None, description="Kubernetes namespace", min_length=1 + ) + + accept_eula: bool = Field( + False, + alias="accept_eula", + description="Whether model terms of use have been accepted", + ) + + metadata_name: Optional[str] = Field( + None, + alias="metadata_name", + description="Name of the jumpstart endpoint object", + max_length=63, + pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + ) + + model_id: str = Field( + ..., + alias="model_id", + description="Unique identifier of the model within the hub", + min_length=1, + max_length=63, + pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + ) + + model_version: Optional[str] = Field( + None, + alias="model_version", + description="Semantic version of the model to deploy (e.g. 1.0.0)", + min_length=5, + max_length=14, + pattern=r"^\d{1,4}\.\d{1,4}\.\d{1,4}$", + ) + + instance_type: str = Field( + ..., + alias="instance_type", + description="EC2 instance type for the inference server", + pattern=r"^ml\..*", + ) + + accelerator_partition_type: Optional[str] = Field( + None, + alias="accelerator_partition_type", + description="MIG profile to use for GPU partitioning", + pattern=r"^mig-.*$", + ) + + accelerator_partition_validation: Optional[bool] = Field( + True, + alias="accelerator_partition_validation", + description="Enable MIG validation for GPU partitioning. Default is true." 
+ ) + + endpoint_name: Optional[str] = Field( + None, + alias="endpoint_name", + description="Name of SageMaker endpoint; empty string means no creation", + max_length=63, + pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + ) + tls_certificate_output_s3_uri: Optional[str] = Field( + None, + alias="tls_certificate_output_s3_uri", + description="S3 URI to write the TLS certificate", + pattern=r"^s3://([^/]+)/?(.*)$", + ) + + @model_validator(mode="after") + def validate_name(self): + if not self.metadata_name and not self.endpoint_name: + raise ValueError("Either metadata_name or endpoint_name must be provided") + return self + + def to_domain(self) -> HPJumpStartEndpoint: + if self.endpoint_name and not self.metadata_name: + self.metadata_name = self.endpoint_name + + metadata = Metadata(name=self.metadata_name, namespace=self.namespace) + + model = Model( + accept_eula=self.accept_eula, + model_id=self.model_id, + model_version=self.model_version, + ) + validations = Validations( + accelerator_partition_validation=self.accelerator_partition_validation, + ) + server = Server( + instance_type=self.instance_type, + accelerator_partition_type=self.accelerator_partition_type, + validations=validations, + ) + sage_ep = SageMakerEndpoint(name=self.endpoint_name) + tls = TlsConfig( + tls_certificate_output_s3_uri=self.tls_certificate_output_s3_uri + ) + return HPJumpStartEndpoint( + metadata=metadata, + model=model, + server=server, + sage_maker_endpoint=sage_ep, + tls_config=tls, + ) diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/schema.json b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/schema.json new file mode 100644 index 00000000..df966f63 --- /dev/null +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/schema.json @@ -0,0 +1,132 @@ +{ + "additionalProperties": false, + "properties": { + "namespace": { + "anyOf": [ + { + "minLength": 1, + "type": 
"string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Kubernetes namespace", + "title": "Namespace" + }, + "accept_eula": { + "default": false, + "description": "Whether model terms of use have been accepted", + "title": "Accept Eula", + "type": "boolean" + }, + "metadata_name": { + "anyOf": [ + { + "maxLength": 63, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Name of the jumpstart endpoint object", + "title": "Metadata Name" + }, + "model_id": { + "description": "Unique identifier of the model within the hub", + "maxLength": 63, + "minLength": 1, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "title": "Model Id", + "type": "string" + }, + "model_version": { + "anyOf": [ + { + "maxLength": 14, + "minLength": 5, + "pattern": "^\\d{1,4}\\.\\d{1,4}\\.\\d{1,4}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Semantic version of the model to deploy (e.g. 1.0.0)", + "title": "Model Version" + }, + "instance_type": { + "description": "EC2 instance type for the inference server", + "pattern": "^ml\\..*", + "title": "Instance Type", + "type": "string" + }, + "accelerator_partition_type": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "MIG profile to use for GPU partitioning", + "pattern": "^mig-.*$", + "title": "Accelerator Partition Type" + }, + "accelerator_partition_validation": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": true, + "description": "Enable MIG validation for GPU partitioning. 
Default is true.", + "title": "Accelerator Partition Validation" + }, + "endpoint_name": { + "anyOf": [ + { + "maxLength": 63, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Name of SageMaker endpoint; empty string means no creation", + "title": "Endpoint Name" + }, + "tls_certificate_output_s3_uri": { + "anyOf": [ + { + "pattern": "^s3://([^/]+)/?(.*)$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "S3 URI to write the TLS certificate", + "title": "Tls Certificate Output S3 Uri" + } + }, + "required": [ + "model_id", + "instance_type" + ], + "title": "FlatHPJumpStartEndpoint", + "type": "object" +} \ No newline at end of file diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/template.py b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/template.py new file mode 100644 index 00000000..580cf514 --- /dev/null +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/template.py @@ -0,0 +1,21 @@ +TEMPLATE_CONTENT = """ +apiVersion: inference.sagemaker.aws.amazon.com/v1alpha1 +kind: JumpStartModel +metadata: + name: {{ model_id }} + namespace: {{ namespace or "default" }} +spec: + model: + acceptEula: {{ accept_eula or false }} + modelHubName: "SageMakerPublicHub" + modelId: {{ model_id }} + modelVersion: {{ model_version or "" }} + sageMakerEndpoint: + name: {{ endpoint_name or "" }} + server: + instanceType: {{ instance_type }} + {% if accelerator_partition_type is not none %}acceleratorPartitionType: "{{ accelerator_partition_type }}"{% endif %} + {% if accelerator_partition_validation is not none %}validations: + {% if accelerator_partition_validation is not none %} acceleratorPartitionValidation: {{ accelerator_partition_validation }}{% endif %} + {% endif %} +""" \ No newline at end of file diff --git 
a/src/sagemaker/hyperpod/cli/commands/inference.py b/src/sagemaker/hyperpod/cli/commands/inference.py index f63cb590..20440dc4 100644 --- a/src/sagemaker/hyperpod/cli/commands/inference.py +++ b/src/sagemaker/hyperpod/cli/commands/inference.py @@ -20,7 +20,7 @@ # CREATE @click.command("hyp-jumpstart-endpoint") -@click.option("--version", default="1.0", help="Schema version to use") +@click.option("--version", default="1.1", help="Schema version to use") @click.option("--debug", default=False, help="Enable debug mode") @generate_click_command( schema_pkg="hyperpod_jumpstart_inference_template", @@ -37,7 +37,7 @@ def js_create(version, debug, js_endpoint): @click.command("hyp-custom-endpoint") -@click.option("--version", default="1.0", help="Schema version to use") +@click.option("--version", default="1.1", help="Schema version to use") @click.option("--debug", default=False, help="Enable debug mode") @generate_click_command( schema_pkg="hyperpod_custom_inference_template", diff --git a/src/sagemaker/hyperpod/inference/config/hp_jumpstart_endpoint_config.py b/src/sagemaker/hyperpod/inference/config/hp_jumpstart_endpoint_config.py index ff4e4fc6..5e971868 100644 --- a/src/sagemaker/hyperpod/inference/config/hp_jumpstart_endpoint_config.py +++ b/src/sagemaker/hyperpod/inference/config/hp_jumpstart_endpoint_config.py @@ -255,6 +255,16 @@ class SageMakerEndpoint(BaseModel): ) +class Validations(BaseModel): + model_config = ConfigDict(extra='forbid') + + acceleratorPartitionValidation: Optional[bool] = Field( + default=True, + alias="accelerator_partition_validation", + description="Enable MIG validation for GPU partitioning. Default is true." + ) + + class Server(BaseModel): model_config = ConfigDict(extra="forbid") @@ -268,6 +278,17 @@ class Server(BaseModel): description="The EC2 instance type to use for the inference server. 
Must be one of the supported types.", ) + acceleratorPartitionType: Optional[str] = Field( + default=None, + alias="accelerator_partition_type", + description="MIG profile to use for GPU partitioning" + ) + + validations: Optional[Validations] = Field( + default=None, + description="Validations configuration for the server" + ) + class TlsConfig(BaseModel): model_config = ConfigDict(extra="forbid") diff --git a/src/sagemaker/hyperpod/inference/constant.py b/src/sagemaker/hyperpod/inference/constant.py new file mode 100644 index 00000000..edf6fa78 --- /dev/null +++ b/src/sagemaker/hyperpod/inference/constant.py @@ -0,0 +1,58 @@ +INSTANCE_MIG_PROFILES = { + "ml.p4d.24xlarge": [ + "mig-1g.5gb", + "mig-1g.10gb", + "mig-2g.10gb", + "mig-3g.20gb", + "mig-4g.20gb", + "mig-7g.40gb" + ], + "ml.p4de.24xlarge": [ + "mig-1g.5gb", + "mig-1g.10gb", + "mig-2g.10gb", + "mig-3g.20gb", + "mig-4g.20gb", + "mig-7g.40gb" + ], + "ml.p5.48xlarge": [ + "mig-1g.10gb", + "mig-1g.20gb", + "mig-2g.20gb", + "mig-3g.40gb", + "mig-4g.40gb", + "mig-7g.80gb" + ], + "ml.p5e.48xlarge": [ + "mig-1g.18gb", + "mig-1g.35gb", + "mig-2g.35gb", + "mig-3g.71gb", + "mig-4g.71gb", + "mig-7g.141gb" + ], + "ml.p5en.48xlarge": [ + "mig-1g.18gb", + "mig-1g.35gb", + "mig-2g.35gb", + "mig-3g.71gb", + "mig-4g.71gb", + "mig-7g.141gb" + ], + "ml.p6-b200.48xlarge": [ + "mig-1g.23gb", + "mig-1g.47gb", + "mig-2g.47gb", + "mig-3g.93gb", + "mig-4g.93gb", + "mig-7g.186gb" + ], + "ml.p6e-gb200.36xlarge": [ + "mig-1g.23gb", + "mig-1g.47gb", + "mig-2g.47gb", + "mig-3g.93gb", + "mig-4g.93gb", + "mig-7g.186gb" + ] +} \ No newline at end of file diff --git a/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py b/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py index d406dc07..c58a3061 100644 --- a/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py +++ b/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py @@ -1,6 +1,7 @@ from typing import Dict, List, Optional from pydantic import Field, ValidationError
from sagemaker.hyperpod.inference.config.constants import * +from sagemaker.hyperpod.inference.constant import INSTANCE_MIG_PROFILES from sagemaker.hyperpod.inference.hp_endpoint_base import HPEndpointBase from sagemaker.hyperpod.common.config.metadata import Metadata from sagemaker.hyperpod.common.utils import ( @@ -29,6 +30,9 @@ class HPJumpStartEndpoint(_HPJumpStartEndpoint, HPEndpointBase): def _create_internal(self, spec, debug=False): """Shared internal create logic""" + + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_js_endpoint") + def create(self, debug=False) -> None: logger = self.get_logger() logger = setup_logging(logger, debug) @@ -40,7 +44,7 @@ def _create_internal(self, spec, debug=False): endpoint_name = spec.sageMakerEndpoint.name if not endpoint_name and not name: - raise Exception('Either metadata name or endpoint name must be provided') + raise Exception("Either metadata name or endpoint name must be provided") if not name: name = endpoint_name @@ -56,7 +60,11 @@ def _create_internal(self, spec, debug=False): annotations=self.metadata.annotations if self.metadata else None, ) - self.validate_instance_type(spec.model.modelId, spec.server.instanceType) + # Only validate instance type if accelerator_partition_validation is provided + if not spec.server.acceleratorPartitionType: + self.validate_instance_type(spec.model.modelId, spec.server.instanceType) + else: + self.validate_mig_profile(spec.server.acceleratorPartitionType, spec.server.instanceType) self.call_create_api( metadata=metadata, @@ -85,9 +93,50 @@ def create_from_dict( input: Dict, debug = False ) -> None: + logger = self.get_logger() + logger = setup_logging(logger, debug) + spec = _HPJumpStartEndpoint.model_validate(input, by_name=True) self._create_internal(spec, debug) + endpoint_name = "" + name = self.metadata.name if self.metadata else None + namespace = self.metadata.namespace if self.metadata else None + + if spec.sageMakerEndpoint and spec.sageMakerEndpoint.name: + 
endpoint_name = spec.sageMakerEndpoint.name + + if not endpoint_name and not name: + raise Exception('Input "name" is required if endpoint name is not provided') + + if not name: + name = endpoint_name + + if not namespace: + namespace = get_default_namespace() + + # Only validate instance type if accelerator_partition_validation is provided + if not spec.server.acceleratorPartitionType: + self.validate_instance_type(spec.model.modelId, spec.server.instanceType) + else: + self.validate_mig_profile(spec.server.acceleratorPartitionType, spec.server.instanceType) + + self.call_create_api( + name=name, # use model name as metadata name + kind=JUMPSTART_MODEL_KIND, + namespace=namespace, + spec=spec, + debug=debug, + ) + + self.metadata = Metadata( + name=name, + namespace=namespace, + ) + + logger.info( + f"Creating JumpStart model and sagemaker endpoint. Endpoint name: {endpoint_name}.\n The process may take a few minutes..." + ) def refresh(self): if not self.metadata: @@ -224,6 +273,40 @@ def validate_instance_type(self, model_id: str, instance_type: str): f"Current HyperPod cluster does not have instance type {instance_type}. Supported instance types are {cluster_instance_types}" ) + def validate_mig_profile(self, mig_profile: str, instance_type: str): + """ + Validate if the MIG profile is supported for the given instance type. + + Args: + instance_type: SageMaker instance type (e.g., "ml.p4d.24xlarge") + mig_profile: MIG profile (e.g., "1g.10gb") + + Raises: + ValueError: If the instance type doesn't support MIG profiles or if the MIG profile is not supported for the instance type + """ + logger = self.get_logger() + logger = setup_logging(logger) + + if instance_type not in INSTANCE_MIG_PROFILES: + error_msg = ( + f"Instance type '{instance_type}' does not support MIG profiles. 
" + f"Supported instance types: {list(INSTANCE_MIG_PROFILES.keys())}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + + if mig_profile not in INSTANCE_MIG_PROFILES[instance_type]: + error_msg = ( + f"MIG profile '{mig_profile}' is not supported for instance type '{instance_type}'. " + f"Supported MIG profiles for {instance_type}: {INSTANCE_MIG_PROFILES[instance_type]}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + + logger.info( + f"MIG profile '{mig_profile}' is valid for instance type '{instance_type}'" + ) + @classmethod @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pods_endpoint") def list_pods(cls, namespace=None, endpoint_name=None): diff --git a/test/unit_tests/cli/test_inference.py b/test/unit_tests/cli/test_inference.py index c9e3e695..e4b3a162 100644 --- a/test/unit_tests/cli/test_inference.py +++ b/test/unit_tests/cli/test_inference.py @@ -47,11 +47,80 @@ def test_js_create_with_required_args(): "sagemaker.hyperpod.common.cli_decorators._is_valid_jumpstart_model_id" ) as mock_model_validation, patch( "sagemaker.hyperpod.common.cli_decorators._namespace_exists" - ) as mock_namespace_exists: + ) as mock_namespace_exists, patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.HPJumpStartEndpoint.validate_instance_type" + ) as mock_validate_instance, patch( + "sagemaker.hyperpod.common.utils.get_jumpstart_model_instance_types" + ) as mock_get_instance_types, patch( + "sagemaker.hyperpod.common.utils.get_cluster_instance_types" + ) as mock_get_cluster_types, patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.HPJumpStartEndpoint.create" + ) as mock_create: # Mock enhanced error handling mock_model_validation.return_value = True # Allow test model-id mock_namespace_exists.return_value = True # Allow test namespace + mock_validate_instance.return_value = None # Skip validation + mock_get_instance_types.return_value = [ + "ml.p4d.24xlarge" + ] # Mock supported types + mock_get_cluster_types.return_value = 
["ml.p4d.24xlarge"] # Mock cluster types + mock_create.return_value = None # Mock successful creation + + # Prepare mock model-to-domain mapping + mock_model_class = Mock() + mock_model_instance = Mock() + domain_obj = Mock() + domain_obj.create = mock_create + mock_model_instance.to_domain.return_value = domain_obj + mock_model_class.return_value = mock_model_instance + + # Set up the registry for version 1.0 + jreg.SCHEMA_REGISTRY["1.0"] = mock_model_class + + runner = CliRunner() + result = runner.invoke( + js_create, + [ + "--namespace", + "test-ns", + "--version", + "1.0", + "--model-id", + "test-model-id", + "--instance-type", + "ml.p4d.24xlarge", # Use a supported instance type + "--endpoint-name", + "test-endpoint", + ], + ) + + assert result.exit_code == 0, result.output + mock_create.assert_called_once_with(debug=False) + + +def test_js_create_missing_required_args(): + runner = CliRunner() + result = runner.invoke(js_create, []) + assert result.exit_code != 0 + assert "Missing option" in result.output + + +def test_js_create_with_mig_profile(): + """ + Test js_create with MIG profile (accelerator partition) options using v1.1 schema. 
+ """ + with patch( + "sagemaker.hyperpod.cli.commands.inference.HPJumpStartEndpoint" + ) as mock_endpoint_class, patch( + "sagemaker.hyperpod.common.cli_decorators._is_valid_jumpstart_model_id" + ) as mock_model_validation, patch( + "sagemaker.hyperpod.common.cli_decorators._namespace_exists" + ) as mock_namespace_exists: + + # Mock enhanced error handling + mock_model_validation.return_value = True + mock_namespace_exists.return_value = True # Mock schema loading mock_load_schema.return_value = { @@ -71,7 +140,8 @@ def test_js_create_with_required_args(): mock_endpoint_class.model_construct.return_value = domain_obj jreg.SCHEMA_REGISTRY.clear() - jreg.SCHEMA_REGISTRY["1.0"] = mock_model_class + # Set up the registry for version 1.1 + jreg.SCHEMA_REGISTRY["1.1"] = mock_model_class runner = CliRunner() result = runner.invoke( @@ -80,11 +150,15 @@ def test_js_create_with_required_args(): "--namespace", "test-ns", "--version", - "1.0", + "1.1", "--model-id", "test-model-id", "--instance-type", - "ml.t2.micro", + "ml.p4d.24xlarge", + "--accelerator-partition-type", + "mig-1g.5gb", + "--accelerator-partition-validation", + "true", "--endpoint-name", "test-endpoint", ], @@ -93,6 +167,11 @@ def test_js_create_with_required_args(): assert result.exit_code == 0, result.output domain_obj.create.assert_called_once_with(debug=False) + # Verify the model instance was created with MIG profile parameters + mock_model_class.assert_called_once() + call_args = mock_model_class.call_args[1] + assert "accelerator_partition_type" in call_args + assert "accelerator_partition_validation" in call_args def test_js_create_missing_required_args(): runner = CliRunner() @@ -100,6 +179,62 @@ def test_js_create_missing_required_args(): assert result.exit_code != 0 assert "Missing option" in result.output +def test_js_create_mig_validation_error_handling(): + """ + Test js_create properly handles MIG profile validation errors using v1.1 schema. 
+ """ + with patch( + "sagemaker.hyperpod.cli.commands.inference.HPJumpStartEndpoint" + ) as mock_endpoint_class, patch( + "sagemaker.hyperpod.common.cli_decorators._is_valid_jumpstart_model_id" + ) as mock_model_validation, patch( + "sagemaker.hyperpod.common.cli_decorators._namespace_exists" + ) as mock_namespace_exists: + + # Mock enhanced error handling + mock_model_validation.return_value = True + mock_namespace_exists.return_value = True + + # Prepare mock model-to-domain mapping that raises validation error + mock_model_class = Mock() + mock_model_instance = Mock() + domain_obj = Mock() + # Simulate MIG validation error during create + domain_obj.create.side_effect = ValueError( + "MIG profile '1g.5gb' is not supported for instance type 'ml.c5.2xlarge'" + ) + mock_model_instance.to_domain.return_value = domain_obj + mock_model_class.return_value = mock_model_instance + mock_endpoint_class.model_construct.return_value = domain_obj + + # Set up the registry for version 1.1 + jreg.SCHEMA_REGISTRY["1.1"] = mock_model_class + + runner = CliRunner() + result = runner.invoke( + js_create, + [ + "--namespace", + "test-ns", + "--version", + "1.1", + "--model-id", + "test-model-id", + "--instance-type", + "ml.c5.2xlarge", # Instance type that doesn't support MIG + "--accelerator-partition-type", + "1g.5gb", # Invalid MIG profile for this instance + "--accelerator-partition-validation", + "true", + "--endpoint-name", + "test-endpoint", + ], + ) + + # Should fail due to MIG validation error + assert result.exit_code != 0 + assert "MIG profile" in result.output or "not supported" in result.output + @patch("sagemaker.hyperpod.common.cli_decorators._namespace_exists") @patch("sagemaker.hyperpod.cli.commands.inference.HPJumpStartEndpoint") diff --git a/test/unit_tests/inference/test_hp_jumpstart_endpoint.py b/test/unit_tests/inference/test_hp_jumpstart_endpoint.py index 09999b56..86c24267 100644 --- a/test/unit_tests/inference/test_hp_jumpstart_endpoint.py +++ 
b/test/unit_tests/inference/test_hp_jumpstart_endpoint.py @@ -7,9 +7,11 @@ Server, SageMakerEndpoint, TlsConfig, + Validations, ) from sagemaker.hyperpod.common.config import Metadata + class TestHPJumpStartEndpoint(unittest.TestCase): def setUp(self): @@ -35,8 +37,13 @@ def setUp(self): @patch.object(HPJumpStartEndpoint, "validate_instance_type") @patch.object(HPJumpStartEndpoint, "call_create_api") - @patch('sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace', return_value='default') - def test_create(self, mock_get_namespace, mock_create_api, mock_validate_instance_type): + @patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace", + return_value="default", + ) + def test_create( + self, mock_get_namespace, mock_create_api, mock_validate_instance_type + ): self.endpoint.create() @@ -48,18 +55,17 @@ def test_create(self, mock_get_namespace, mock_create_api, mock_validate_instanc ) self.assertEqual(self.endpoint.metadata.name, "bert-testing-jumpstart-7-2-2") - @patch.object(HPJumpStartEndpoint, "validate_instance_type") @patch.object(HPJumpStartEndpoint, "call_create_api") def test_create_with_metadata(self, mock_create_api, mock_validate_instance_type): """Test create_from_dict uses metadata name and namespace when endpoint name not provided""" - + # Create endpoint without sageMakerEndpoint name to force using metadata endpoint_without_name = HPJumpStartEndpoint( model=Model(model_id="test-model"), server=Server(instance_type="ml.c5.2xlarge"), tls_config=TlsConfig(tls_certificate_output_s3_uri="s3://test-bucket"), - metadata=Metadata(name="metadata-test-name", namespace="metadata-test-ns") + metadata=Metadata(name="metadata-test-name", namespace="metadata-test-ns"), ) endpoint_without_name.create() @@ -73,8 +79,13 @@ def test_create_with_metadata(self, mock_create_api, mock_validate_instance_type @patch.object(HPJumpStartEndpoint, "validate_instance_type") @patch.object(HPJumpStartEndpoint, "call_create_api") - 
@patch('sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace', return_value='default') - def test_create_from_dict(self, mock_get_namespace, mock_create_api, mock_validate_instance_type): + @patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace", + return_value="default", + ) + def test_create_from_dict( + self, mock_get_namespace, mock_create_api, mock_validate_instance_type + ): input_dict = self.endpoint.model_dump(exclude_none=True) @@ -178,13 +189,7 @@ def test_list_pods(self, mock_verify_config, mock_core_api, mock_list_api): mock_pod3, ] - mock_list_api.return_value = { - "items": [ - { - "metadata": {"name": "js-endpoint"} - } - ] - } + mock_list_api.return_value = {"items": [{"metadata": {"name": "js-endpoint"}}]} result = self.endpoint.list_pods(namespace="test-ns") @@ -211,9 +216,280 @@ def test_list_pods_with_endpoint_name(self, mock_verify_config, mock_core_api): mock_pod3, ] - result = self.endpoint.list_pods(namespace="test-ns", endpoint_name="js-endpoint1") + result = self.endpoint.list_pods( + namespace="test-ns", endpoint_name="js-endpoint1" + ) self.assertEqual(result, ["js-endpoint1-pod1", "js-endpoint1-pod2"]) mock_core_api.return_value.list_namespaced_pod.assert_called_once_with( namespace="test-ns" ) + + def test_validate_mig_profile_valid(self): + """Test validate_mig_profile with valid instance type and MIG profile""" + # Test with valid combinations + self.endpoint.validate_mig_profile("mig-1g.5gb", "ml.p4d.24xlarge") + self.endpoint.validate_mig_profile("mig-7g.40gb", "ml.p4d.24xlarge") + self.endpoint.validate_mig_profile("mig-1g.10gb", "ml.p4de.24xlarge") + self.endpoint.validate_mig_profile("mig-7g.80gb", "ml.p5.48xlarge") + + def test_validate_mig_profile_invalid_instance_type(self): + """Test validate_mig_profile with unsupported instance type""" + with self.assertRaises(ValueError) as context: + self.endpoint.validate_mig_profile("1g.5gb", "ml.c5.2xlarge") + + self.assertIn( + 
"Instance type 'ml.c5.2xlarge' does not support MIG profiles", + str(context.exception), + ) + self.assertIn("Supported instance types:", str(context.exception)) + + def test_validate_mig_profile_invalid_mig_profile(self): + """Test validate_mig_profile with unsupported MIG profile for valid instance type""" + with self.assertRaises(ValueError) as context: + self.endpoint.validate_mig_profile("invalid.profile", "ml.p4d.24xlarge") + + self.assertIn( + "MIG profile 'invalid.profile' is not supported for instance type 'ml.p4d.24xlarge'", + str(context.exception), + ) + self.assertIn( + "Supported MIG profiles for ml.p4d.24xlarge:", str(context.exception) + ) + + def test_validate_mig_profile_wrong_profile_for_instance(self): + """Test validate_mig_profile with MIG profile that exists but not for the specific instance type""" + # 7g.80gb is valid for p4de but not p4d + with self.assertRaises(ValueError) as context: + self.endpoint.validate_mig_profile("7g.80gb", "ml.p4d.24xlarge") + + self.assertIn( + "MIG profile '7g.80gb' is not supported for instance type 'ml.p4d.24xlarge'", + str(context.exception), + ) + + @patch.object(HPJumpStartEndpoint, "validate_mig_profile") + @patch.object(HPJumpStartEndpoint, "call_create_api") + @patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace", + return_value="default", + ) + def test_create_with_accelerator_partition_validation( + self, mock_get_namespace, mock_create_api, mock_validate_mig + ): + """Test create method uses MIG validation when accelerator_partition_validation is True""" + # Create endpoint with accelerator partition validation enabled + model = Model(model_id="test-model") + validations = Validations( + accelerator_partition_validation=True, + ) + server = Server( + instance_type="ml.p4d.24xlarge", + validations=validations, + accelerator_partition_type="1g.5gb", + ) + endpoint = HPJumpStartEndpoint( + model=model, + server=server, + 
sage_maker_endpoint=SageMakerEndpoint(name="test-endpoint"), + tls_config=TlsConfig(tls_certificate_output_s3_uri="s3://test-bucket"), + ) + + endpoint.create() + + # Should call validate_mig_profile instead of validate_instance_type + mock_validate_mig.assert_called_once_with("1g.5gb", "ml.p4d.24xlarge") + mock_create_api.assert_called_once() + + @patch.object(HPJumpStartEndpoint, "validate_instance_type") + @patch.object(HPJumpStartEndpoint, "call_create_api") + @patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace", + return_value="default", + ) + def test_create_without_accelerator_partition_validation( + self, mock_get_namespace, mock_create_api, mock_validate_instance + ): + """Test create method uses instance type validation when accelerator_partition_validation is False/None""" + # Create endpoint without accelerator partition validation (default behavior) + model = Model(model_id="test-model") + server = Server(instance_type="ml.c5.2xlarge") + endpoint = HPJumpStartEndpoint( + model=model, + server=server, + sage_maker_endpoint=SageMakerEndpoint(name="test-endpoint"), + tls_config=TlsConfig(tls_certificate_output_s3_uri="s3://test-bucket"), + ) + + endpoint.create() + + # Should call validate_instance_type instead of validate_mig_profile + mock_validate_instance.assert_called_once_with("test-model", "ml.c5.2xlarge") + mock_create_api.assert_called_once() + + @patch.object(HPJumpStartEndpoint, "validate_mig_profile") + @patch.object(HPJumpStartEndpoint, "call_create_api") + @patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace", + return_value="default", + ) + def test_create_from_dict_with_accelerator_partition_validation( + self, mock_get_namespace, mock_create_api, mock_validate_mig + ): + """Test create_from_dict method uses MIG validation when accelerator_partition_validation is True""" + input_dict = { + "model": {"modelId": "test-model"}, + "server": { + "instanceType": "ml.p4d.24xlarge", + 
"validations": { + "acceleratorPartitionValidation": True + }, + "acceleratorPartitionType": "1g.5gb", + }, + "sageMakerEndpoint": {"name": "test-endpoint"}, + "tlsConfig": {"tlsCertificateOutputS3Uri": "s3://test-bucket"}, + } + + endpoint = HPJumpStartEndpoint( + model=Model(model_id="dummy"), + server=Server(instance_type="dummy"), + tls_config=TlsConfig(tls_certificate_output_s3_uri="s3://dummy"), + ) + endpoint.create_from_dict(input_dict) + + # Should call validate_mig_profile instead of validate_instance_type + mock_validate_mig.assert_called_once_with("1g.5gb", "ml.p4d.24xlarge") + mock_create_api.assert_called_once() + + @patch.object(HPJumpStartEndpoint, "validate_instance_type") + @patch.object(HPJumpStartEndpoint, "call_create_api") + @patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace", + return_value="default", + ) + def test_create_from_dict_without_accelerator_partition_validation( + self, mock_get_namespace, mock_create_api, mock_validate_instance + ): + """Test create_from_dict method uses instance type validation when accelerator_partition_validation is False/None""" + input_dict = { + "model": {"modelId": "test-model"}, + "server": {"instanceType": "ml.c5.2xlarge"}, + "sageMakerEndpoint": {"name": "test-endpoint"}, + "tlsConfig": {"tlsCertificateOutputS3Uri": "s3://test-bucket"}, + } + + endpoint = HPJumpStartEndpoint( + model=Model(model_id="dummy"), + server=Server(instance_type="dummy"), + tls_config=TlsConfig(tls_certificate_output_s3_uri="s3://dummy"), + ) + endpoint.create_from_dict(input_dict) + + # Should call validate_instance_type instead of validate_mig_profile + mock_validate_instance.assert_called_once_with("test-model", "ml.c5.2xlarge") + mock_create_api.assert_called_once() + + def test_validate_mig_profile_edge_cases(self): + """Test validate_mig_profile with various edge cases""" + # Test with different instance types and their specific profiles + test_cases = [ + ("ml.p4de.24xlarge", 
"mig-1g.5gb"), + ("ml.p5.48xlarge", "mig-3g.40gb"), + ("ml.p5e.48xlarge", "mig-1g.18gb"), + ("ml.p5en.48xlarge", "mig-7g.141gb"), + ("p6-b200.48xlarge", "mig-1g.23gb"), + ("ml.p6e-gb200.36xlarge", "mig-7g.186gb"), + ] + + for instance_type, mig_profile in test_cases: + with self.subTest(instance_type=instance_type, mig_profile=mig_profile): + # Should not raise any exception + self.endpoint.validate_mig_profile(mig_profile, instance_type) + + def test_validate_mig_profile_case_sensitivity(self): + """Test that MIG profile validation is case sensitive""" + with self.assertRaises(ValueError): + # Test uppercase - should fail as profiles are lowercase + self.endpoint.validate_mig_profile("1G.5GB", "ml.p4d.24xlarge") + + @patch.object(HPJumpStartEndpoint, "validate_mig_profile") + @patch.object(HPJumpStartEndpoint, "validate_instance_type") + @patch.object(HPJumpStartEndpoint, "call_create_api") + @patch( + "sagemaker.hyperpod.inference.hp_jumpstart_endpoint.get_default_namespace", + return_value="default", + ) + def test_create_validation_logic_priority( + self, + mock_get_namespace, + mock_create_api, + mock_validate_instance, + mock_validate_mig, + ): + """Test that accelerator_partition_validation takes priority over regular validation""" + # Create endpoint with both accelerator partition validation and regular fields + model = Model(model_id="test-model") + validations = Validations( + accelerator_partition_validation=True, + ) + server = Server( + instance_type="ml.p4d.24xlarge", + validations=validations, + accelerator_partition_type="1g.5gb", + ) + endpoint = HPJumpStartEndpoint( + model=model, + server=server, + sage_maker_endpoint=SageMakerEndpoint(name="test-endpoint"), + tls_config=TlsConfig(tls_certificate_output_s3_uri="s3://test-bucket"), + ) + + endpoint.create() + + # Should only call validate_mig_profile, not validate_instance_type + mock_validate_mig.assert_called_once_with("1g.5gb", "ml.p4d.24xlarge") + mock_validate_instance.assert_not_called() + 
mock_create_api.assert_called_once() + + def test_create_missing_name_and_endpoint_name(self): + """Test create method raises exception when both metadata name and endpoint name are missing""" + model = Model(model_id="test-model") + server = Server(instance_type="ml.c5.2xlarge") + endpoint = HPJumpStartEndpoint( + model=model, + server=server, + tls_config=TlsConfig(tls_certificate_output_s3_uri="s3://test-bucket"), + # No sageMakerEndpoint name and no metadata + ) + + with self.assertRaises(Exception) as context: + endpoint.create() + + self.assertIn( + "Either metadata name or endpoint name must be provided", + str(context.exception), + ) + + def test_create_from_dict_missing_name_and_endpoint_name(self): + """Test create_from_dict method raises exception when both name and endpoint name are missing""" + input_dict = { + "model": {"modelId": "test-model"}, + "server": {"instanceType": "ml.c5.2xlarge"}, + "tlsConfig": {"tlsCertificateOutputS3Uri": "s3://test-bucket"}, + # No sageMakerEndpoint name + } + + endpoint = HPJumpStartEndpoint( + model=Model(model_id="dummy"), + server=Server(instance_type="dummy"), + tls_config=TlsConfig(tls_certificate_output_s3_uri="s3://dummy"), + # No metadata + ) + + with self.assertRaises(Exception) as context: + endpoint.create_from_dict(input_dict) + + self.assertIn( + 'Input "name" is required if endpoint name is not provided', + str(context.exception), + )