diff --git a/google/cloud/dataproc_ml/inference/__init__.py b/google/cloud/dataproc_ml/inference/__init__.py index ae1e1d6..e609964 100644 --- a/google/cloud/dataproc_ml/inference/__init__.py +++ b/google/cloud/dataproc_ml/inference/__init__.py @@ -17,5 +17,11 @@ from .gen_ai_model_handler import GenAiModelHandler from .pytorch_model_handler import PyTorchModelHandler from .tensorflow_model_handler import TensorFlowModelHandler +from .vertex_endpoint_handler import VertexEndpointHandler -__all__ = ("GenAiModelHandler", "PyTorchModelHandler", "TensorFlowModelHandler") +__all__ = ( + "GenAiModelHandler", + "PyTorchModelHandler", + "VertexEndpointHandler", + "TensorFlowModelHandler", +) diff --git a/google/cloud/dataproc_ml/inference/vertex_endpoint_handler.py b/google/cloud/dataproc_ml/inference/vertex_endpoint_handler.py new file mode 100644 index 0000000..59abd92 --- /dev/null +++ b/google/cloud/dataproc_ml/inference/vertex_endpoint_handler.py @@ -0,0 +1,142 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""A module for handling model inference on Spark DataFrames using a
Vertex AI Endpoint."""
import logging
from typing import Dict, List, Optional

import pandas as pd
from pyspark.sql.types import ArrayType, DoubleType

from google.cloud import aiplatform
from google.cloud.dataproc_ml.inference.base_model_handler import (
    BaseModelHandler,
    Model,
)

logger = logging.getLogger(__name__)

# Number of instances sent per prediction request when the caller does not
# pick a batch size. Shared by VertexEndpoint and VertexEndpointHandler so
# the documented default and the effective default cannot drift apart.
_DEFAULT_BATCH_SIZE = 10


def _validate_batch_size(batch_size: int) -> int:
    """Returns ``batch_size`` unchanged if it can be used as a chunk size.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. A step of 0 makes
            ``range()`` fail and a negative step silently skips every
            instance, so both are rejected up front with a clear message.
    """
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}."
        )
    return batch_size


class VertexEndpoint(Model):
    """A concrete implementation of the Model interface for a
    Vertex AI Endpoint."""

    def __init__(
        self,
        endpoint: str,
        project: Optional[str] = None,
        location: Optional[str] = None,
        predict_parameters: Optional[Dict] = None,
        batch_size: Optional[int] = None,
        use_dedicated_endpoint: bool = False,
    ):
        """Initializes the VertexEndpoint.

        Args:
            endpoint: The name of the Vertex AI Endpoint.
            project: The GCP project ID.
            location: The GCP location.
            predict_parameters: Parameters for the prediction call.
            batch_size: The number of instances to include in each
                prediction request. Defaults to 10 when omitted or None.
            use_dedicated_endpoint: Whether to use the dedicated endpoint
                for prediction. Defaults to False.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Resolve and validate the batch size before touching the SDK so a
        # bad argument fails without any client being created.
        self.batch_size = _validate_batch_size(
            _DEFAULT_BATCH_SIZE if batch_size is None else batch_size
        )
        aiplatform.init(project=project, location=location)
        self.endpoint_client = aiplatform.Endpoint(endpoint_name=endpoint)
        self.predict_parameters = predict_parameters
        self.use_dedicated_endpoint = use_dedicated_endpoint

    def call(self, batch: pd.Series) -> pd.Series:
        """Overrides the base method to send instances to the
        Vertex AI Endpoint.

        The batch is split into chunks of at most ``self.batch_size``
        instances and one ``predict`` request is issued per chunk.

        Args:
            batch: The instances to predict on, one per element.

        Returns:
            A Series of predictions carrying the same index as ``batch``.

        Raises:
            AssertionError: If a request returns a different number of
                predictions than the instances it was sent.
        """
        # Convert the pandas Series to a list of instances.
        instances: List = batch.tolist()
        all_predictions: List = []

        for start in range(0, len(instances), self.batch_size):
            chunk = instances[start : start + self.batch_size]
            result = self.endpoint_client.predict(
                instances=chunk,
                parameters=self.predict_parameters,
                use_dedicated_endpoint=self.use_dedicated_endpoint,
            )
            predictions = result.predictions
            # Raised explicitly (not via `assert`) so the check survives
            # `python -O`; AssertionError is kept because callers and tests
            # already rely on that type. Checking per request keeps
            # predictions aligned with their instances and stops early.
            if len(predictions) != len(chunk):
                raise AssertionError(
                    f"Mismatch between number of instances ({len(chunk)}) "
                    f"and predictions ({len(predictions)}). "
                    "Potential API issue."
                )
            all_predictions.extend(predictions)

        return pd.Series(all_predictions, index=batch.index)


class VertexEndpointHandler(BaseModelHandler):
    """A handler for running inference with a deployed model on a
    Vertex AI Endpoint.

    Options are set through chainable builder methods; each one stores its
    value and returns the handler itself.
    """

    def __init__(self, endpoint: str):
        """Initializes the handler with its defaults.

        Args:
            endpoint: The name of the Vertex AI Endpoint to call.
        """
        super().__init__()
        self.endpoint = endpoint
        self._project: Optional[str] = None
        self._location: Optional[str] = None
        self._predict_parameters: Optional[Dict] = None
        self._batch_size: int = _DEFAULT_BATCH_SIZE
        self._use_dedicated_endpoint: bool = False
        # Default output type; callers may replace it via set_return_type().
        self.set_return_type(ArrayType(DoubleType()))

    def project(self, project: str) -> "VertexEndpointHandler":
        """Sets the Google Cloud project for the Vertex AI API call."""
        self._project = project
        return self

    def location(self, location: str) -> "VertexEndpointHandler":
        """Sets the Google Cloud location (region) for Vertex AI API call."""
        self._location = location
        return self

    def predict_parameters(self, parameters: Dict) -> "VertexEndpointHandler":
        """Sets the parameters for the prediction call."""
        self._predict_parameters = parameters
        return self

    def batch_size(self, batch_size: int) -> "VertexEndpointHandler":
        """Sets the number of instances to send in each prediction request.

        Defaults to 10 if not set.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Validate here so a bad value is reported where it is configured
        # rather than later when the model is loaded and called.
        self._batch_size = _validate_batch_size(batch_size)
        return self

    def use_dedicated_endpoint(
        self, use_dedicated_endpoint: bool
    ) -> "VertexEndpointHandler":
        """Sets whether to use the dedicated endpoint for prediction."""
        self._use_dedicated_endpoint = use_dedicated_endpoint
        return self

    def _load_model(self) -> Model:
        """Loads the VertexEndpoint instance on each Spark executor."""
        return VertexEndpoint(
            self.endpoint,
            project=self._project,
            location=self._location,
            predict_parameters=self._predict_parameters,
            batch_size=self._batch_size,
            use_dedicated_endpoint=self._use_dedicated_endpoint,
        )
"""Integration test for VertexEndpointHandler."""

import os

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType

from google.cloud.dataproc_ml.inference import VertexEndpointHandler

# TODO: Replace with endpoint creation during test run which shouldn't
# take more than 20 mins
_DEFAULT_ENDPOINT_ID = "1121351227238514688"


def create_prompt(text_series: pd.Series, max_tokens: int = 256) -> pd.Series:
    """A pre-processor that wraps each text input in a request dictionary.

    Args:
        text_series: The raw prompt texts, one per element.
        max_tokens: Value stored under the 'max_tokens' key of every
            request. Defaults to 256.

    Returns:
        A Series of dictionaries with 'prompt' and 'max_tokens' keys, in
        the same order and with the same index as ``text_series``.
    """
    return text_series.apply(
        lambda text: {"prompt": text, "max_tokens": max_tokens}
    )


def test_vertex_endpoint_handler():
    """Tests the VertexEndpointHandler with a live endpoint.

    The project and location are read from GOOGLE_CLOUD_PROJECT and
    GOOGLE_CLOUD_LOCATION. The endpoint ID can be overridden with
    VERTEX_ENDPOINT_ID and otherwise falls back to the fixed test endpoint.
    """
    spark = SparkSession.builder.appName(
        "VertexEndpointHandlerTest"
    ).getOrCreate()

    # Stop the session even when the transform or an assertion fails, so a
    # failing run does not leave a Spark session behind.
    try:
        project = os.getenv("GOOGLE_CLOUD_PROJECT")
        location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
        endpoint_name = os.getenv("VERTEX_ENDPOINT_ID", _DEFAULT_ENDPOINT_ID)

        # Create a sample DataFrame of text prompts.
        data = [
            ("Write a paragraph on India.",),
            ("Who is James Bond?",),
        ]
        df = spark.createDataFrame(data, ["features"])

        # Configure and apply the handler
        handler = (
            VertexEndpointHandler(endpoint=endpoint_name)
            .input_cols("features")
            .output_col("predictions")
            .use_dedicated_endpoint(True)
            .pre_processor(create_prompt)
            .set_return_type(StringType())
            .project(project)
            .location(location)
        )

        result_df = handler.transform(df)
        results = result_df.collect()

        assert len(results) == len(data)
        assert "predictions" in result_df.columns
        # Every row, not just the first, must carry a non-empty prediction.
        for row in results:
            assert row["predictions"], f"Empty prediction in row: {row}"
    finally:
        spark.stop()
"""Unit tests for VertexEndpointHandler."""

import unittest
from unittest.mock import MagicMock, patch

import pandas as pd
from pyspark.sql.types import ArrayType, DoubleType, StringType

from google.cloud.dataproc_ml.inference.vertex_endpoint_handler import (
    VertexEndpoint,
    VertexEndpointHandler,
)

# The path for patching must be where the object is *looked up*.
ENDPOINT_HANDLER_PATH = (
    "google.cloud.dataproc_ml.inference.vertex_endpoint_handler"
)


class TestVertexEndpoint(unittest.TestCase):
    """Exercises VertexEndpoint against a mocked aiplatform module."""

    @patch(f"{ENDPOINT_HANDLER_PATH}.aiplatform")
    def test_init(self, mock_aiplatform):
        """Construction initializes the SDK and creates the endpoint."""
        endpoint_name = "my-endpoint"
        sdk_kwargs = {"project": "my-project", "location": "us-central1"}

        VertexEndpoint(endpoint=endpoint_name, **sdk_kwargs)

        mock_aiplatform.init.assert_called_once_with(**sdk_kwargs)
        mock_aiplatform.Endpoint.assert_called_once_with(
            endpoint_name=endpoint_name
        )

    @patch(f"{ENDPOINT_HANDLER_PATH}.aiplatform")
    def test_call_sends_batched_requests(self, mock_aiplatform):
        """call() splits the input into batch_size-sized requests."""

        def fake_predict(instances, **kwargs):
            # One prediction per instance, whatever the chunk size is.
            return MagicMock(predictions=[[0.1, 0.9] for _ in instances])

        endpoint_client = mock_aiplatform.Endpoint.return_value
        endpoint_client.predict.side_effect = fake_predict

        predict_params = {"param1": "value1"}
        model = VertexEndpoint(
            endpoint="test-endpoint",
            batch_size=2,
            predict_parameters=predict_params,
            use_dedicated_endpoint=True,
        )
        row_index = [10, 20, 30]

        output_series = model.call(
            pd.Series([[1], [2], [3]], index=row_index)
        )

        # Three instances with batch_size=2 -> one full and one partial call.
        self.assertEqual(endpoint_client.predict.call_count, 2)
        for expected_chunk in ([[1], [2]], [[3]]):
            endpoint_client.predict.assert_any_call(
                instances=expected_chunk,
                parameters=predict_params,
                use_dedicated_endpoint=True,
            )

        pd.testing.assert_series_equal(
            output_series,
            pd.Series([[0.1, 0.9]] * 3, index=row_index),
        )

    @patch(f"{ENDPOINT_HANDLER_PATH}.aiplatform")
    def test_call_raises_error_on_prediction_mismatch(self, mock_aiplatform):
        """call() fails when fewer predictions than instances come back."""
        endpoint_client = mock_aiplatform.Endpoint.return_value
        # A single prediction is returned for a two-instance request.
        endpoint_client.predict.return_value = MagicMock(
            predictions=[[0.1, 0.9]]
        )

        model = VertexEndpoint(endpoint="test-endpoint", batch_size=2)
        two_instances = pd.Series([[1], [2]])

        with self.assertRaisesRegex(AssertionError, "Mismatch between number"):
            model.call(two_instances)


class TestVertexEndpointHandler(unittest.TestCase):
    """Exercises the builder API of VertexEndpointHandler."""

    def setUp(self):
        """Creates a fresh handler before every test."""
        self.handler = VertexEndpointHandler(endpoint="test-endpoint")

    def test_initialization_defaults(self):
        """A new handler carries the documented default settings."""
        handler = self.handler

        self.assertEqual(handler.endpoint, "test-endpoint")
        for unset_attr in ("_project", "_location", "_predict_parameters"):
            self.assertIsNone(getattr(handler, unset_attr))
        self.assertEqual(handler._batch_size, 10)
        self.assertFalse(handler._use_dedicated_endpoint)

        return_type = handler._return_type
        self.assertIsInstance(return_type, ArrayType)
        self.assertIsInstance(return_type.elementType, DoubleType)

    def test_builder_methods_chaining(self):
        """Builder methods store their values and return the handler."""
        project = "test-project"
        location = "us-central1"
        params = {"key": "value"}
        return_type = StringType()

        handler = self.handler
        chained_handler = handler.project(project)
        chained_handler = chained_handler.location(location)
        chained_handler = chained_handler.predict_parameters(params)
        chained_handler = chained_handler.batch_size(50)
        chained_handler = chained_handler.use_dedicated_endpoint(True)
        chained_handler = chained_handler.set_return_type(return_type)

        self.assertIs(chained_handler, handler)
        self.assertEqual(handler._project, project)
        self.assertEqual(handler._location, location)
        self.assertEqual(handler._predict_parameters, params)
        self.assertEqual(handler._batch_size, 50)
        self.assertTrue(handler._use_dedicated_endpoint)
        self.assertEqual(handler._return_type, return_type)

    @patch(f"{ENDPOINT_HANDLER_PATH}.VertexEndpoint")
    def test_load_model_success(self, mock_vertex_endpoint):
        """_load_model() forwards every configured option to the model."""
        expected_kwargs = {
            "project": "my-project",
            "location": "us-east1",
            "predict_parameters": {"a": 1},
            "batch_size": 20,
            "use_dedicated_endpoint": True,
        }

        handler = self.handler
        handler.project(expected_kwargs["project"])
        handler.location(expected_kwargs["location"])
        handler.predict_parameters(expected_kwargs["predict_parameters"])
        handler.batch_size(expected_kwargs["batch_size"])
        handler.use_dedicated_endpoint(
            expected_kwargs["use_dedicated_endpoint"]
        )

        loaded_model = handler._load_model()

        mock_vertex_endpoint.assert_called_once_with(
            handler.endpoint, **expected_kwargs
        )
        self.assertEqual(loaded_model, mock_vertex_endpoint.return_value)