Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion google/cloud/dataproc_ml/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,11 @@
from .gen_ai_model_handler import GenAiModelHandler
from .pytorch_model_handler import PyTorchModelHandler
from .tensorflow_model_handler import TensorFlowModelHandler
from .vertex_endpoint_handler import VertexEndpointHandler

__all__ = ("GenAiModelHandler", "PyTorchModelHandler", "TensorFlowModelHandler")
# Public API of the inference subpackage; kept alphabetically sorted.
__all__ = (
    "GenAiModelHandler",
    "PyTorchModelHandler",
    "TensorFlowModelHandler",
    "VertexEndpointHandler",
)
142 changes: 142 additions & 0 deletions google/cloud/dataproc_ml/inference/vertex_endpoint_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A module for handling model inference on Spark DataFrames using a
Vertex AI Endpoint."""
import logging
from typing import Dict, List, Optional

import pandas as pd
from pyspark.sql.types import ArrayType, DoubleType

from google.cloud import aiplatform
from google.cloud.dataproc_ml.inference.base_model_handler import (
BaseModelHandler,
Model,
)

logger = logging.getLogger(__name__)


class VertexEndpoint(Model):
    """A concrete implementation of the Model interface for a
    Vertex AI Endpoint."""

    def __init__(
        self,
        endpoint: str,
        project: Optional[str] = None,
        location: Optional[str] = None,
        predict_parameters: Optional[Dict] = None,
        batch_size: Optional[int] = None,
        use_dedicated_endpoint: bool = False,
    ):
        """Initializes the VertexEndpoint.

        Args:
            endpoint: The name of the Vertex AI Endpoint.
            project: The GCP project ID.
            location: The GCP location.
            predict_parameters: Parameters for the prediction call.
            batch_size: The number of instances to include in each prediction
                request. Defaults to 10.
            use_dedicated_endpoint: Whether to use the dedicated endpoint for
                prediction. Defaults to False.
        """
        aiplatform.init(project=project, location=location)
        self.endpoint_client = aiplatform.Endpoint(endpoint_name=endpoint)
        self.predict_parameters = predict_parameters
        # Apply the documented default of 10 when unset: a None batch_size
        # would make range() in call() raise TypeError at inference time.
        self.batch_size = batch_size if batch_size is not None else 10
        self.use_dedicated_endpoint = use_dedicated_endpoint

    def call(self, batch: pd.Series) -> pd.Series:
        """Overrides the base method to send instances to the
        Vertex AI Endpoint.

        Args:
            batch: Input instances, one per Series element.

        Returns:
            A Series of predictions aligned to the input's index.

        Raises:
            RuntimeError: If the endpoint returns a different number of
                predictions than instances sent.
        """

        # Convert the pandas Series to a list of instances.
        instances: List = batch.tolist()

        all_predictions = []

        # Send instances in chunks of self.batch_size to bound request size.
        for i in range(0, len(instances), self.batch_size):
            batch_instances = instances[i : i + self.batch_size]
            prediction_result = self.endpoint_client.predict(
                instances=batch_instances,
                parameters=self.predict_parameters,
                use_dedicated_endpoint=self.use_dedicated_endpoint,
            )
            all_predictions.extend(prediction_result.predictions)

        # Explicit raise instead of assert: asserts are stripped under
        # `python -O`, and this check guards against a real API-side issue.
        if len(all_predictions) != len(instances):
            raise RuntimeError(
                f"Mismatch between number of instances ({len(instances)}) and "
                f"predictions ({len(all_predictions)}). Potential API issue."
            )

        return pd.Series(all_predictions, index=batch.index)


class VertexEndpointHandler(BaseModelHandler):
    """A handler for running inference with a deployed model on a
    Vertex AI Endpoint.

    Configuration is applied through fluent setters; each returns the
    handler itself so calls can be chained.
    """

    def __init__(self, endpoint: str):
        super().__init__()
        self.endpoint = endpoint
        # Builder state; every field below is overridable via a setter.
        self._project: Optional[str] = None
        self._location: Optional[str] = None
        self._predict_parameters: Optional[Dict] = None
        self._batch_size: int = 10
        self._use_dedicated_endpoint: bool = False
        # Default output schema: a vector of doubles per prediction.
        self.set_return_type(ArrayType(DoubleType()))

    def _load_model(self) -> Model:
        """Loads the VertexEndpoint instance on each Spark executor."""
        return VertexEndpoint(
            self.endpoint,
            project=self._project,
            location=self._location,
            predict_parameters=self._predict_parameters,
            batch_size=self._batch_size,
            use_dedicated_endpoint=self._use_dedicated_endpoint,
        )

    def project(self, project: str) -> "VertexEndpointHandler":
        """Sets the Google Cloud project for the Vertex AI API call."""
        self._project = project
        return self

    def location(self, location: str) -> "VertexEndpointHandler":
        """Sets the Google Cloud location (region) for Vertex AI API call."""
        self._location = location
        return self

    def predict_parameters(self, parameters: Dict) -> "VertexEndpointHandler":
        """Sets the parameters for the prediction call."""
        self._predict_parameters = parameters
        return self

    def batch_size(self, batch_size: int) -> "VertexEndpointHandler":
        """Sets the number of instances to send in each prediction request.

        Defaults to 10 if not set.
        """
        self._batch_size = batch_size
        return self

    def use_dedicated_endpoint(
        self, use_dedicated_endpoint: bool
    ) -> "VertexEndpointHandler":
        """Sets whether to use the dedicated endpoint for prediction."""
        self._use_dedicated_endpoint = use_dedicated_endpoint
        return self
70 changes: 70 additions & 0 deletions tests/integration/inference/test_vertex_endpoint_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Integration test for VertexEndpointHandler."""

import os

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType

from google.cloud.dataproc_ml.inference import VertexEndpointHandler


def create_prompt(text_series: pd.Series) -> pd.Series:
    """A pre-processor that wraps each text input in a dictionary with
    a 'prompt' key."""

    def _to_instance(text):
        # Each Vertex instance carries the prompt plus a generation cap.
        return {"prompt": text, "max_tokens": 256}

    return text_series.map(_to_instance)


def test_vertex_endpoint_handler():
    """Tests the VertexEndpointHandler with a live endpoint."""
    spark = (
        SparkSession.builder.appName("VertexEndpointHandlerTest").getOrCreate()
    )

    project = os.getenv("GOOGLE_CLOUD_PROJECT")
    location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
    # TODO: Replace with endpoint creation during test run which shouldn't
    # take more than 20 mins
    endpoint_name = "1121351227238514688"

    # Build a two-row input DataFrame of text prompts.
    rows = [
        ("Write a paragraph on India.",),
        ("Who is James Bond?",),
    ]
    input_df = spark.createDataFrame(rows, ["features"])

    # Configure the handler via its fluent setters and run inference.
    handler = (
        VertexEndpointHandler(endpoint=endpoint_name)
        .project(project)
        .location(location)
        .input_cols("features")
        .output_col("predictions")
        .pre_processor(create_prompt)
        .set_return_type(StringType())
        .use_dedicated_endpoint(True)
    )

    result_df = handler.transform(input_df)
    collected = result_df.collect()

    assert "predictions" in result_df.columns
    assert len(collected) == 2
    # Check for non-empty prediction
    assert len(collected[0]["predictions"]) > 0

    spark.stop()
Loading