Add VertexEndpointHandler for online inference #6
@@ -0,0 +1,142 @@

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A module for handling model inference on Spark DataFrames using a
Vertex AI Endpoint."""

import logging
from typing import Dict, List, Optional

import pandas as pd
from pyspark.sql.types import ArrayType, DoubleType

from google.cloud import aiplatform
from google.cloud.dataproc_ml.inference.base_model_handler import (
    BaseModelHandler,
    Model,
)

logger = logging.getLogger(__name__)


class VertexEndpoint(Model):
    """A concrete implementation of the Model interface for a
    Vertex AI Endpoint."""

    def __init__(
        self,
        endpoint: str,
        project: Optional[str] = None,
        location: Optional[str] = None,
        predict_parameters: Optional[Dict] = None,
        batch_size: Optional[int] = None,
        use_dedicated_endpoint: bool = False,
    ):
        """Initializes the VertexEndpoint.

        Args:
            endpoint: The name of the Vertex AI Endpoint.
            project: The GCP project ID.
            location: The GCP location.
            predict_parameters: Parameters for the prediction call.
            batch_size: The number of instances to include in each prediction
                request. Defaults to 10.
            use_dedicated_endpoint: Whether to use the dedicated endpoint for
                prediction. Defaults to False.
        """
        aiplatform.init(project=project, location=location)
        self.endpoint_client = aiplatform.Endpoint(endpoint_name=endpoint)
        self.predict_parameters = predict_parameters
        self.batch_size = batch_size
        self.use_dedicated_endpoint = use_dedicated_endpoint

    def call(self, batch: pd.Series) -> pd.Series:
        """Overrides the base method to send instances to the
        Vertex AI Endpoint."""
        # Convert the pandas Series to a list of instances.
        instances: List = batch.tolist()

        all_predictions = []

        # Send the instances to the endpoint in chunks of batch_size.
        for i in range(0, len(instances), self.batch_size):
            batch_instances = instances[i : i + self.batch_size]
            prediction_result = self.endpoint_client.predict(
                instances=batch_instances,
                parameters=self.predict_parameters,
                use_dedicated_endpoint=self.use_dedicated_endpoint,
            )
            all_predictions.extend(prediction_result.predictions)

        assert len(all_predictions) == len(instances), (
            f"Mismatch between number of instances ({len(instances)}) and "
            f"predictions ({len(all_predictions)}). Potential API issue."
        )

        return pd.Series(all_predictions, index=batch.index)


class VertexEndpointHandler(BaseModelHandler):
    """A handler for running inference with a deployed model on a
    Vertex AI Endpoint."""

    def __init__(self, endpoint: str):
        super().__init__()
        self.endpoint = endpoint
        self._project = None
        self._location = None
        self._predict_parameters = None
        self._batch_size = 10
        self._use_dedicated_endpoint = False
        self.set_return_type(ArrayType(DoubleType()))

    def project(self, project: str) -> "VertexEndpointHandler":
        """Sets the Google Cloud project for the Vertex AI API call."""
        self._project = project
        return self

    def location(self, location: str) -> "VertexEndpointHandler":
        """Sets the Google Cloud location (region) for the Vertex AI API
        call."""
        self._location = location
        return self

    def predict_parameters(self, parameters: Dict) -> "VertexEndpointHandler":
        """Sets the parameters for the prediction call."""
        self._predict_parameters = parameters
        return self

    def batch_size(self, batch_size: int) -> "VertexEndpointHandler":
        """Sets the number of instances to send in each prediction request.

        Defaults to 10 if not set.
        """
        self._batch_size = batch_size
        return self

    def use_dedicated_endpoint(
        self, use_dedicated_endpoint: bool
    ) -> "VertexEndpointHandler":
        """Sets whether to use the dedicated endpoint for prediction."""
        self._use_dedicated_endpoint = use_dedicated_endpoint
        return self

    def _load_model(self) -> Model:
        """Loads the VertexEndpoint instance on each Spark executor."""
        return VertexEndpoint(
            self.endpoint,
            project=self._project,
            location=self._location,
            predict_parameters=self._predict_parameters,
            batch_size=self._batch_size,
            use_dedicated_endpoint=self._use_dedicated_endpoint,
        )
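
For review context, a minimal usage sketch of the fluent API this PR adds. The builder methods (input_cols, output_col, pre_processor, transform) come from BaseModelHandler as exercised in the integration test below; the endpoint ID, project, schema, and predict parameters are placeholders, not values from this PR:

    from pyspark.sql import SparkSession

    from google.cloud.dataproc_ml.inference import VertexEndpointHandler

    spark = SparkSession.builder.appName("VertexEndpointDemo").getOrCreate()

    # Placeholder feature vectors; the deployed model is assumed to accept a
    # list of doubles per instance (matching the default ArrayType(DoubleType())
    # return type set by the handler).
    df = spark.createDataFrame(
        [([1.0, 2.0],), ([3.0, 4.0],)],
        ["features"],
    )

    handler = (
        VertexEndpointHandler(endpoint="YOUR_ENDPOINT_ID")  # placeholder ID
        .input_cols("features")
        .output_col("predictions")
        .project("your-gcp-project")  # placeholder project
        .location("us-central1")
        .predict_parameters({"temperature": 0.2})  # placeholder, model-specific
        .batch_size(10)
    )

    # Each Spark executor lazily builds a VertexEndpoint via _load_model() and
    # sends prediction requests in batches of 10.
    result_df = handler.transform(df)
    result_df.show(truncate=False)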
@@ -0,0 +1,70 @@

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Integration test for VertexEndpointHandler."""

import os

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType

from google.cloud.dataproc_ml.inference import VertexEndpointHandler


def create_prompt(text_series: pd.Series) -> pd.Series:
    """A pre-processor that wraps each text input in a dictionary with
    a 'prompt' key."""
    return text_series.apply(lambda x: {"prompt": x, "max_tokens": 256})


def test_vertex_endpoint_handler():
    """Tests the VertexEndpointHandler with a live endpoint."""
    spark = SparkSession.builder.appName(
        "VertexEndpointHandlerTest"
    ).getOrCreate()

    project = os.getenv("GOOGLE_CLOUD_PROJECT")
    location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
    # TODO: Replace with endpoint creation during the test run, which
    # shouldn't take more than 20 minutes.
    endpoint_name = "1121351227238514688"
Reviewer comment: The integration test relies on a hardcoded endpoint name. This makes the test fragile and dependent on an external resource that might change or be deleted, causing test failures. While the TODO acknowledges this, creating the endpoint as part of the test run would make the test self-contained.
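
A rough sketch of what that setup could look like, using the standard Vertex AI SDK upload-and-deploy flow; the artifact path, serving container image, and machine type below are placeholders, not values from this PR:

    from google.cloud import aiplatform

    def create_test_endpoint(project: str, location: str) -> aiplatform.Endpoint:
        """Uploads a small model and deploys it to a fresh endpoint for the test."""
        aiplatform.init(project=project, location=location)
        model = aiplatform.Model.upload(
            display_name="vertex-endpoint-handler-test-model",
            artifact_uri="gs://your-bucket/model/",  # placeholder artifact path
            serving_container_image_uri=(
                # Placeholder prebuilt serving container.
                "us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-0:latest"
            ),
        )
        # deploy() blocks until the model is live, which is where most of the
        # setup time would go.
        return model.deploy(machine_type="n1-standard-4")

    # Teardown after the test:
    #   endpoint.undeploy_all()
    #   endpoint.delete()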
    # Create a sample DataFrame with text prompts.
    data = [
        ("Write a paragraph on India.",),
        ("Who is James Bond?",),
    ]
    df = spark.createDataFrame(data, ["features"])

    # Configure and apply the handler.
    handler = (
        VertexEndpointHandler(endpoint=endpoint_name)
        .input_cols("features")
        .output_col("predictions")
        .use_dedicated_endpoint(True)
        .pre_processor(create_prompt)
        .set_return_type(StringType())
        .project(project)
        .location(location)
    )

    result_df = handler.transform(df)
    results = result_df.collect()

    assert len(results) == 2
    assert "predictions" in result_df.columns
    assert len(results[0]["predictions"]) > 0  # Check for non-empty prediction

    spark.stop()
Reviewer comment: The batch_size parameter can be None, but the call method doesn't handle this case, which will lead to a TypeError when it's used in range(). The docstring for __init__ also states that batch_size defaults to 10, but this default is not applied here. To prevent runtime errors and align with the documentation, you should assign a default value if None is provided.
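
One minimal guard along the lines the comment suggests, keeping the documented default in a single place inside VertexEndpoint.__init__:

    # Fall back to the documented default of 10 rather than storing None, so
    # range(0, len(instances), self.batch_size) always receives a valid step.
    self.batch_size = batch_size if batch_size is not None else 10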