input);
+
+}
diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java
new file mode 100644
index 000000000000..46ee72c73001
--- /dev/null
+++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference.base;
+
+import java.io.Serializable;
+
+/**
+ * Base interface for defining model-specific parameters used to configure remote inference clients.
+ *
+ * Implementations of this interface encapsulate all configuration needed to initialize
+ * and communicate with a remote model inference service. This typically includes:
+ *
+ * - Authentication credentials (API keys, tokens)
+ * - Model identifiers or names
+ * - Endpoint URLs or connection settings
+ * - Inference configuration (temperature, max tokens, timeout values, etc.)
+ *
+ *
+ * Parameters must be serializable. Consider using
+ * the builder pattern for complex parameter objects.
+ *
+ */
+public interface BaseModelParameters extends Serializable {
+
+}
diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java
new file mode 100644
index 000000000000..b3c050c45c11
--- /dev/null
+++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference.base;
+
+import java.io.Serializable;
+
+/**
+ * Base class for defining response types returned from remote inference operations.
+
+ *
Implementations:
+ *
+ * - Contain the inference results (predictions, classifications, generated text, etc.)
+ * - Includes any relevant metadata
+ *
+ *
+ */
+public abstract class BaseResponse implements Serializable {
+
+}
diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java
new file mode 100644
index 000000000000..edc30fd11246
--- /dev/null
+++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference.base;
+
+import java.io.Serializable;
+
+/**
+ * Pairs an input with its corresponding inference output.
+ *
+ * This class maintains the association between input data and its model's results
+ * for Downstream processing
+ */
+public class PredictionResult implements Serializable {
+
+ private final InputT input;
+ private final OutputT output;
+
+ private PredictionResult(InputT input, OutputT output) {
+ this.input = input;
+ this.output = output;
+
+ }
+
+ /* Returns input to handler */
+ public InputT getInput() {
+ return input;
+ }
+
+ /* Returns model handler's response*/
+ public OutputT getOutput() {
+ return output;
+ }
+
+ /* Creates a PredictionResult instance of provided input, output and types */
+ public static PredictionResult create(InputT input, OutputT output) {
+ return new PredictionResult<>(input, output);
+ }
+}
diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java
new file mode 100644
index 000000000000..87616ee693d1
--- /dev/null
+++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference.openai;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.openai.client.OpenAIClient;
+import com.openai.client.okhttp.OpenAIOkHttpClient;
+import com.openai.core.JsonSchemaLocalValidation;
+import com.openai.models.responses.ResponseCreateParams;
+import com.openai.models.responses.StructuredResponseCreateParams;
+import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler;
+import org.apache.beam.sdk.ml.remoteinference.base.PredictionResult;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Model handler for OpenAI API inference requests.
+ *
+ * This handler manages communication with OpenAI's API, including client initialization,
+ * request formatting, and response parsing. It uses OpenAI's structured output feature to
+ * ensure reliable input-output pairing.
+ *
+ *
Usage
+ * {@code
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ * .apiKey("sk-...")
+ * .modelName("gpt-4")
+ * .instructionPrompt("Classify the following text into one of the categories: {CATEGORIES}")
+ * .build();
+ *
+ * PCollection inputs = ...;
+ * PCollection>> results =
+ * inputs.apply(
+ * RemoteInference.invoke()
+ * .handler(OpenAIModelHandler.class)
+ * .withParameters(params)
+ * );
+ * }
+ *
+ */
+public class OpenAIModelHandler
+ implements BaseModelHandler {
+
+ private transient OpenAIClient client;
+ private transient StructuredResponseCreateParams clientParams;
+ private OpenAIModelParameters modelParameters;
+
+ /**
+ * Initializes the OpenAI client with the provided parameters.
+ *
+ * This method is called once during setup. It creates an authenticated
+ * OpenAI client using the API key from the parameters.
+ *
+ * @param parameters the configuration parameters including API key and model name
+ */
+ @Override
+ public void createClient(OpenAIModelParameters parameters) {
+ this.modelParameters = parameters;
+ this.client = OpenAIOkHttpClient.builder()
+ .apiKey(this.modelParameters.getApiKey())
+ .build();
+ }
+
+ /**
+ * Performs inference on a batch of inputs using the OpenAI Client.
+ *
+ *
This method serializes the input batch to JSON string, sends it to OpenAI with structured
+ * output requirements, and parses the response into {@link PredictionResult} objects
+ * that pair each input with its corresponding output.
+ *
+ * @param input the list of inputs to process
+ * @return an iterable of model results and input pairs
+ */
+ @Override
+ public Iterable> request(List input) {
+
+ try {
+ // Convert input list to JSON string
+ String inputBatch = new ObjectMapper()
+ .writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList());
+
+ // Build structured response parameters
+ this.clientParams = ResponseCreateParams.builder()
+ .model(modelParameters.getModelName())
+ .input(inputBatch)
+ .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO)
+ .instructions(modelParameters.getInstructionPrompt())
+ .build();
+
+ // Get structured output from the model
+ StructuredInputOutput structuredOutput = client.responses()
+ .create(clientParams)
+ .output()
+ .stream()
+ .flatMap(item -> item.message().stream())
+ .flatMap(message -> message.content().stream())
+ .flatMap(content -> content.outputText().stream())
+ .findFirst()
+ .orElse(null);
+
+ if (structuredOutput == null || structuredOutput.responses == null) {
+ throw new RuntimeException("Model returned no structured responses");
+ }
+
+ // Map responses to PredictionResults
+ List> results = structuredOutput.responses.stream()
+ .map(response -> PredictionResult.create(
+ OpenAIModelInput.create(response.input),
+ OpenAIModelResponse.create(response.output)))
+ .collect(Collectors.toList());
+
+ return results;
+
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException("Failed to serialize input batch", e);
+ }
+ }
+
+ /**
+ * Schema class for structured output response.
+ *
+ * Represents a single input-output pair returned by the OpenAI API.
+ */
+ public static class Response {
+ @JsonProperty(required = true)
+ @JsonPropertyDescription("The input string")
+ public String input;
+
+ @JsonProperty(required = true)
+ @JsonPropertyDescription("The output string")
+ public String output;
+ }
+
+ /**
+ * Schema class for structured output containing multiple responses.
+ *
+ *
This class defines the expected JSON structure for OpenAI's structured output,
+ * ensuring reliable parsing of batched inference results.
+ */
+ public static class StructuredInputOutput {
+ @JsonProperty(required = true)
+ @JsonPropertyDescription("Array of input-output pairs")
+ public List responses;
+ }
+
+}
diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java
new file mode 100644
index 000000000000..1ef59c89da66
--- /dev/null
+++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference.openai;
+
+import org.apache.beam.sdk.ml.remoteinference.base.BaseInput;
+
+/**
+ * Input for OpenAI model inference requests.
+ *
+ * This class encapsulates text input to be sent to OpenAI models.
+ *
+ *
Example Usage
+ * {@code
+ * OpenAIModelInput input = OpenAIModelInput.create("Translate to French: Hello");
+ * String text = input.getModelInput(); // "Translate to French: Hello"
+ * }
+ *
+ * @see OpenAIModelHandler
+ * @see OpenAIModelResponse
+ */
+public class OpenAIModelInput extends BaseInput {
+
+ private final String input;
+
+ private OpenAIModelInput(String input) {
+
+ this.input = input;
+ }
+
+ /**
+ * Returns the text input for the model.
+ *
+ * @return the input text string
+ */
+ public String getModelInput() {
+ return input;
+ }
+
+ /**
+ * Creates a new input instance with the specified text.
+ *
+ * @param input the text to send to the model
+ * @return a new {@link OpenAIModelInput} instance
+ */
+ public static OpenAIModelInput create(String input) {
+ return new OpenAIModelInput(input);
+ }
+
+}
diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java
new file mode 100644
index 000000000000..3aac4112cba3
--- /dev/null
+++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference.openai;
+
+import org.apache.beam.sdk.ml.remoteinference.base.BaseModelParameters;
+
+/**
+ * Configuration parameters required for OpenAI model inference.
+ *
+ * This class encapsulates all configuration needed to initialize and communicate with
+ * OpenAI's API, including authentication credentials, model selection, and inference instructions.
+ *
+ *
Example Usage
+ * {@code
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ * .apiKey("sk-...")
+ * .modelName("gpt-4")
+ * .instructionPrompt("Translate the following text to French:")
+ * .build();
+ * }
+ *
+ * @see OpenAIModelHandler
+ */
+public class OpenAIModelParameters implements BaseModelParameters {
+
+ private final String apiKey;
+ private final String modelName;
+ private final String instructionPrompt;
+
+ private OpenAIModelParameters(Builder builder) {
+ this.apiKey = builder.apiKey;
+ this.modelName = builder.modelName;
+ this.instructionPrompt = builder.instructionPrompt;
+ }
+
+ public String getApiKey() {
+ return apiKey;
+ }
+
+ public String getModelName() {
+ return modelName;
+ }
+
+ public String getInstructionPrompt() {
+ return instructionPrompt;
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+
+ public static class Builder {
+ private String apiKey;
+ private String modelName;
+ private String instructionPrompt;
+
+ private Builder() {
+ }
+
+ /**
+ * Sets the OpenAI API key for authentication.
+ *
+ * @param apiKey the API key (required)
+ */
+ public Builder apiKey(String apiKey) {
+ this.apiKey = apiKey;
+ return this;
+ }
+
+ /**
+ * Sets the name of the OpenAI model to use.
+ *
+ * @param modelName the model name, e.g., "gpt-4" (required)
+ */
+ public Builder modelName(String modelName) {
+ this.modelName = modelName;
+ return this;
+ }
+ /**
+ * Sets the instruction prompt for the model.
+ * This prompt provides context or instructions to the model about how to process
+ * the input text.
+ *
+ * @param prompt the instruction text (required)
+ */
+ public Builder instructionPrompt(String prompt) {
+ this.instructionPrompt = prompt;
+ return this;
+ }
+
+ /**
+ * Builds the {@link OpenAIModelParameters} instance.
+ */
+ public OpenAIModelParameters build() {
+ return new OpenAIModelParameters(this);
+ }
+ }
+}
diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java
new file mode 100644
index 000000000000..abfd9d7cb5c3
--- /dev/null
+++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference.openai;
+
+import org.apache.beam.sdk.ml.remoteinference.base.BaseResponse;
+
+/**
+ * Response from OpenAI model inference results.
+ * This class encapsulates the text output returned from OpenAI models..
+ *
+ *
Example Usage
+ * {@code
+ * OpenAIModelResponse response = OpenAIModelResponse.create("Bonjour");
+ * String output = response.getModelResponse(); // "Bonjour"
+ * }
+ *
+ * @see OpenAIModelHandler
+ * @see OpenAIModelInput
+ */
+public class OpenAIModelResponse extends BaseResponse {
+
+ private final String output;
+
+ private OpenAIModelResponse(String output) {
+ this.output = output;
+ }
+
+ /**
+ * Returns the text output from the model.
+ *
+ * @return the output text string
+ */
+ public String getModelResponse() {
+ return output;
+ }
+
+ /**
+ * Creates a new response instance with the specified output text.
+ *
+ * @param output the text returned by the model
+ * @return a new {@link OpenAIModelResponse} instance
+ */
+ public static OpenAIModelResponse create(String output) {
+ return new OpenAIModelResponse(output);
+ }
+}
diff --git a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/RemoteInferenceTest.java b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/RemoteInferenceTest.java
new file mode 100644
index 000000000000..3f351b9f88a3
--- /dev/null
+++ b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/RemoteInferenceTest.java
@@ -0,0 +1,586 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
+
+import org.apache.beam.sdk.coders.SerializableCoder;
+import org.apache.beam.sdk.ml.remoteinference.base.*;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.PCollection;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+
+
+@RunWith(JUnit4.class)
+public class RemoteInferenceTest {
+
+ @Rule
+ public final transient TestPipeline pipeline = TestPipeline.create();
+
+ // Test input class
+ public static class TestInput extends BaseInput {
+ private final String value;
+
+ private TestInput(String value) {
+ this.value = value;
+ }
+
+ public static TestInput create(String value) {
+ return new TestInput(value);
+ }
+
+ public String getModelInput() {
+ return value;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o)
+ return true;
+ if (!(o instanceof TestInput))
+ return false;
+ TestInput testInput = (TestInput) o;
+ return value.equals(testInput.value);
+ }
+
+ @Override
+ public int hashCode() {
+ return value.hashCode();
+ }
+
+ @Override
+ public String toString() {
+ return "TestInput{value='" + value + "'}";
+ }
+ }
+
+ // Test output class
+ public static class TestOutput extends BaseResponse {
+ private final String result;
+
+ private TestOutput(String result) {
+ this.result = result;
+ }
+
+ public static TestOutput create(String result) {
+ return new TestOutput(result);
+ }
+
+ public String getModelResponse() {
+ return result;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o)
+ return true;
+ if (!(o instanceof TestOutput))
+ return false;
+ TestOutput that = (TestOutput) o;
+ return result.equals(that.result);
+ }
+
+ @Override
+ public int hashCode() {
+ return result.hashCode();
+ }
+
+ @Override
+ public String toString() {
+ return "TestOutput{result='" + result + "'}";
+ }
+ }
+
+ // Test parameters class
+ public static class TestParameters implements BaseModelParameters {
+ private final String config;
+
+ private TestParameters(Builder builder) {
+ this.config = builder.config;
+ }
+
+ public String getConfig() {
+ return config;
+ }
+
+ @Override
+ public String toString() {
+ return "TestParameters{config='" + config + "'}";
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o)
+ return true;
+ if (!(o instanceof TestParameters))
+ return false;
+ TestParameters that = (TestParameters) o;
+ return config.equals(that.config);
+ }
+
+ @Override
+ public int hashCode() {
+ return config.hashCode();
+ }
+
+ // Builder
+ public static class Builder {
+ private String config;
+
+ public Builder setConfig(String config) {
+ this.config = config;
+ return this;
+ }
+
+ public TestParameters build() {
+ return new TestParameters(this);
+ }
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+ }
+
+ // Mock handler for successful inference
+ public static class MockSuccessHandler
+ implements BaseModelHandler {
+
+ private TestParameters parameters;
+ private boolean clientCreated = false;
+
+ @Override
+ public void createClient(TestParameters parameters) {
+ this.parameters = parameters;
+ this.clientCreated = true;
+ }
+
+ @Override
+ public Iterable> request(List input) {
+ if (!clientCreated) {
+ throw new IllegalStateException("Client not initialized");
+ }
+ return input.stream()
+ .map(i -> PredictionResult.create(
+ i,
+ new TestOutput("processed-" + i.getModelInput())))
+ .collect(Collectors.toList());
+ }
+ }
+
+ // Mock handler that returns empty results
+ public static class MockEmptyResultHandler
+ implements BaseModelHandler {
+
+ @Override
+ public void createClient(TestParameters parameters) {
+ // Setup succeeds
+ }
+
+ @Override
+ public Iterable> request(List input) {
+ return Collections.emptyList();
+ }
+ }
+
+ // Mock handler that throws exception during setup
+ public static class MockFailingSetupHandler
+ implements BaseModelHandler {
+
+ @Override
+ public void createClient(TestParameters parameters) {
+ throw new RuntimeException("Setup failed intentionally");
+ }
+
+ @Override
+ public Iterable> request(List input) {
+ return Collections.emptyList();
+ }
+ }
+
+ // Mock handler that throws exception during request
+ public static class MockFailingRequestHandler
+ implements BaseModelHandler {
+
+ @Override
+ public void createClient(TestParameters parameters) {
+ // Setup succeeds
+ }
+
+ @Override
+ public Iterable> request(List input) {
+ throw new RuntimeException("Request failed intentionally");
+ }
+ }
+
+ // Mock handler without default constructor (to test error handling)
+ public static class MockNoDefaultConstructorHandler
+ implements BaseModelHandler {
+
+ private final String required;
+
+ public MockNoDefaultConstructorHandler(String required) {
+ this.required = required;
+ }
+
+ @Override
+ public void createClient(TestParameters parameters) {
+ }
+
+ @Override
+ public Iterable> request(List input) {
+ return Collections.emptyList();
+ }
+ }
+
+ @Test
+ public void testInvokeWithSingleElement() {
+ TestInput input = TestInput.create("test-value");
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline.apply(Create.of(input));
+
+ PCollection>> results = inputCollection
+ .apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params));
+
+ // Verify the output contains expected predictions
+ PAssert.thatSingleton(results).satisfies(batch -> {
+ List> resultList = StreamSupport.stream(batch.spliterator(), false)
+ .collect(Collectors.toList());
+
+ assertEquals("Expected exactly 1 result", 1, resultList.size());
+
+ PredictionResult result = resultList.get(0);
+ assertEquals("test-value", result.getInput().getModelInput());
+ assertEquals("processed-test-value", result.getOutput().getModelResponse());
+
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testInvokeWithMultipleElements() {
+ List inputs = Arrays.asList(
+ new TestInput("input1"),
+ new TestInput("input2"),
+ new TestInput("input3"));
+
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline
+ .apply("CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class)));
+
+ PCollection>> results = inputCollection
+ .apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params));
+
+ // Count total results across all batches
+ PAssert.that(results).satisfies(batches -> {
+ int totalCount = 0;
+ for (Iterable> batch : batches) {
+ for (PredictionResult result : batch) {
+ totalCount++;
+ assertTrue("Output should start with 'processed-'",
+ result.getOutput().getModelResponse().startsWith("processed-"));
+ assertNotNull("Input should not be null", result.getInput());
+ assertNotNull("Output should not be null", result.getOutput());
+ }
+ }
+ assertEquals("Expected 3 total results", 3, totalCount);
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testInvokeWithEmptyCollection() {
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline
+ .apply("CreateEmptyInput", Create.empty(SerializableCoder.of(TestInput.class)));
+
+ PCollection>> results = inputCollection
+ .apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params));
+
+ // assertion for empty PCollection
+ PAssert.that(results).empty();
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testHandlerReturnsEmptyResults() {
+ TestInput input = new TestInput("test-value");
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline
+ .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
+
+ PCollection>> results = inputCollection
+ .apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockEmptyResultHandler.class)
+ .withParameters(params));
+
+ // Verify we still get a result, but it's empty
+ PAssert.thatSingleton(results).satisfies(batch -> {
+ List> resultList = StreamSupport.stream(batch.spliterator(), false)
+ .collect(Collectors.toList());
+ assertEquals("Expected empty result list", 0, resultList.size());
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testHandlerSetupFailure() {
+ TestInput input = new TestInput("test-value");
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline
+ .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
+
+ inputCollection.apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockFailingSetupHandler.class)
+ .withParameters(params));
+
+ // Verify pipeline fails with expected error
+ try {
+ pipeline.run().waitUntilFinish();
+ fail("Expected pipeline to fail due to handler setup failure");
+ } catch (Exception e) {
+ String message = e.getMessage();
+ assertTrue("Exception should mention setup failure or handler instantiation failure",
+ message != null && (message.contains("Setup failed intentionally") ||
+ message.contains("Failed to instantiate handler")));
+ }
+ }
+
+ @Test
+ public void testHandlerRequestFailure() {
+ TestInput input = new TestInput("test-value");
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline
+ .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
+
+ inputCollection.apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockFailingRequestHandler.class)
+ .withParameters(params));
+
+ // Verify pipeline fails with expected error
+ try {
+ pipeline.run().waitUntilFinish();
+ fail("Expected pipeline to fail due to request failure");
+ } catch (Exception e) {
+ String message = e.getMessage();
+ assertTrue("Exception should mention request failure",
+ message != null && message.contains("Request failed intentionally"));
+ }
+ }
+
+ @Test
+ public void testHandlerWithoutDefaultConstructor() {
+ TestInput input = new TestInput("test-value");
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline
+ .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
+
+ inputCollection.apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockNoDefaultConstructorHandler.class)
+ .withParameters(params));
+
+ // Verify pipeline fails when handler cannot be instantiated
+ try {
+ pipeline.run().waitUntilFinish();
+ fail("Expected pipeline to fail due to missing default constructor");
+ } catch (Exception e) {
+ String message = e.getMessage();
+ assertTrue("Exception should mention handler instantiation failure",
+ message != null && message.contains("Failed to instantiate handler"));
+ }
+ }
+
+ @Test
+ public void testBuilderPattern() {
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ RemoteInference.Invoke transform = RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params);
+
+ assertNotNull("Transform should not be null", transform);
+ }
+
+ @Test
+ public void testPredictionResultMapping() {
+ TestInput input = new TestInput("mapping-test");
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline
+ .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
+
+ PCollection>> results = inputCollection
+ .apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params));
+
+ PAssert.thatSingleton(results).satisfies(batch -> {
+ for (PredictionResult result : batch) {
+ // Verify that input is preserved in the result
+ assertNotNull("Input should not be null", result.getInput());
+ assertNotNull("Output should not be null", result.getOutput());
+ assertEquals("mapping-test", result.getInput().getModelInput());
+ assertTrue("Output should contain input value",
+ result.getOutput().getModelResponse().contains("mapping-test"));
+ }
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ // Temporary behaviour until we introduce java BatchElements transform
+ // to batch elements in RemoteInference
+ @Test
+ public void testMultipleInputsProduceSeparateBatches() {
+ List inputs = Arrays.asList(
+ new TestInput("input1"),
+ new TestInput("input2"));
+
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline
+ .apply("CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class)));
+
+ PCollection>> results = inputCollection
+ .apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params));
+
+ PAssert.that(results).satisfies(batches -> {
+ int batchCount = 0;
+ for (Iterable> batch : batches) {
+ batchCount++;
+ int elementCount = 0;
+ elementCount += StreamSupport.stream(batch.spliterator(), false).count();
+ // Each batch should contain exactly 1 element
+ assertEquals("Each batch should contain 1 element", 1, elementCount);
+ }
+ assertEquals("Expected 2 batches", 2, batchCount);
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testWithEmptyParameters() {
+
+ pipeline.enableAbandonedNodeEnforcement(false);
+
+ TestInput input = TestInput.create("test-value");
+ PCollection inputCollection = pipeline.apply(Create.of(input));
+
+ IllegalArgumentException thrown = assertThrows(
+ IllegalArgumentException.class,
+ () -> inputCollection.apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)));
+
+ assertTrue(
+ "Expected message to contain 'withParameters() is required', but got: " + thrown.getMessage(),
+ thrown.getMessage().contains("withParameters() is required"));
+ }
+
+ @Test
+ public void testWithEmptyHandler() {
+
+ pipeline.enableAbandonedNodeEnforcement(false);
+
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ TestInput input = TestInput.create("test-value");
+ PCollection inputCollection = pipeline.apply(Create.of(input));
+
+ IllegalArgumentException thrown = assertThrows(
+ IllegalArgumentException.class,
+ () -> inputCollection.apply("RemoteInference",
+ RemoteInference.invoke()
+ .withParameters(params)));
+
+ assertTrue(
+ "Expected message to contain 'handler() is required', but got: " + thrown.getMessage(),
+ thrown.getMessage().contains("handler() is required"));
+ }
+}
diff --git a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java
new file mode 100644
index 000000000000..fb24b090cb68
--- /dev/null
+++ b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference.openai;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TypeDescriptor;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.junit.Assume.assumeNotNull;
+import static org.junit.Assume.assumeTrue;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.beam.sdk.ml.remoteinference.base.*;
+import org.apache.beam.sdk.ml.remoteinference.RemoteInference;
+import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.StructuredInputOutput;
+import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.Response;
+
+public class OpenAIModelHandlerIT {
+ private static final Logger LOG = LoggerFactory.getLogger(OpenAIModelHandlerIT.class);
+
+ @Rule
+ public final transient TestPipeline pipeline = TestPipeline.create();
+
+ private String apiKey;
+ private static final String API_KEY_ENV = "OPENAI_API_KEY";
+ private static final String DEFAULT_MODEL = "gpt-4o-mini";
+
+
+ @Before
+ public void setUp() {
+ // Get API key
+ apiKey = System.getenv(API_KEY_ENV);
+
+ // Skip tests if API key is not provided
+ assumeNotNull(
+ "OpenAI API key not found. Set " + API_KEY_ENV
+ + " environment variable to run integration tests.",
+ apiKey);
+ assumeTrue("OpenAI API key is empty. Set " + API_KEY_ENV
+ + " environment variable to run integration tests.",
+ !apiKey.trim().isEmpty());
+ }
+
+ @Test
+ public void testSentimentAnalysisWithSingleInput() {
+ String input = "This product is absolutely amazing! I love it!";
+
+ PCollection inputs = pipeline
+ .apply("CreateSingleInput", Create.of(input))
+ .apply("MapToInput", MapElements
+ .into(TypeDescriptor.of(OpenAIModelInput.class))
+ .via(OpenAIModelInput::create));
+
+ PCollection>> results = inputs
+ .apply("SentimentInference",
+ RemoteInference.invoke()
+ .handler(OpenAIModelHandler.class)
+ .withParameters(OpenAIModelParameters.builder()
+ .apiKey(apiKey)
+ .modelName(DEFAULT_MODEL)
+ .instructionPrompt(
+ "Analyze the sentiment as 'positive' or 'negative'. Return only one word.")
+ .build()));
+
+ // Verify results
+ PAssert.that(results).satisfies(batches -> {
+ int count = 0;
+ for (Iterable> batch : batches) {
+ for (PredictionResult result : batch) {
+ count++;
+ assertNotNull("Input should not be null", result.getInput());
+ assertNotNull("Output should not be null", result.getOutput());
+ assertNotNull("Output text should not be null",
+ result.getOutput().getModelResponse());
+
+ String sentiment = result.getOutput().getModelResponse().toLowerCase();
+ assertTrue("Sentiment should be positive or negative, got: " + sentiment,
+ sentiment.contains("positive")
+ || sentiment.contains("negative"));
+ }
+ }
+ assertEquals("Should have exactly 1 result", 1, count);
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testSentimentAnalysisWithMultipleInputs() {
+ List inputs = Arrays.asList(
+ "An excellent B2B SaaS solution that streamlines business processes efficiently.",
+ "The customer support is terrible. I've been waiting for days without any response.",
+ "The application works as expected. Installation was straightforward.",
+ "Really impressed with the innovative features! The AI capabilities are groundbreaking!",
+ "Mediocre product with occasional glitches. Documentation could be better.");
+
+ PCollection inputCollection = pipeline
+ .apply("CreateMultipleInputs", Create.of(inputs))
+ .apply("MapToInputs", MapElements
+ .into(TypeDescriptor.of(OpenAIModelInput.class))
+ .via(OpenAIModelInput::create));
+
+ PCollection>> results = inputCollection
+ .apply("SentimentInference",
+ RemoteInference.invoke()
+ .handler(OpenAIModelHandler.class)
+ .withParameters(OpenAIModelParameters.builder()
+ .apiKey(apiKey)
+ .modelName(DEFAULT_MODEL)
+ .instructionPrompt(
+ "Analyze sentiment as positive or negative")
+ .build()));
+
+ // Verify we get results for all inputs
+ PAssert.that(results).satisfies(batches -> {
+ int totalCount = 0;
+ for (Iterable> batch : batches) {
+ for (PredictionResult result : batch) {
+ totalCount++;
+ assertNotNull("Input should not be null", result.getInput());
+ assertNotNull("Output should not be null", result.getOutput());
+ assertFalse("Output should not be empty",
+ result.getOutput().getModelResponse().trim().isEmpty());
+ }
+ }
+ assertEquals("Should have results for all 5 inputs", 5, totalCount);
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testTextClassification() {
+ List inputs = Arrays.asList(
+ "How do I reset my password?",
+ "Your product is broken and I want a refund!",
+ "Thank you for the excellent service!");
+
+ PCollection inputCollection = pipeline
+ .apply("CreateInputs", Create.of(inputs))
+ .apply("MapToInputs", MapElements
+ .into(TypeDescriptor.of(OpenAIModelInput.class))
+ .via(OpenAIModelInput::create));
+
+ PCollection>> results = inputCollection
+ .apply("ClassificationInference",
+ RemoteInference.invoke()
+ .handler(OpenAIModelHandler.class)
+ .withParameters(OpenAIModelParameters.builder()
+ .apiKey(apiKey)
+ .modelName(DEFAULT_MODEL)
+ .instructionPrompt(
+ "Classify each text into one category: 'question', 'complaint', or 'praise'. Return only the category.")
+ .build()));
+
+ PAssert.that(results).satisfies(batches -> {
+ List categories = new ArrayList<>();
+ for (Iterable> batch : batches) {
+ for (PredictionResult result : batch) {
+ String category = result.getOutput().getModelResponse().toLowerCase();
+ categories.add(category);
+ }
+ }
+
+ assertEquals("Should have 3 categories", 3, categories.size());
+
+ // Verify expected categories
+ boolean hasQuestion = categories.stream().anyMatch(c -> c.contains("question"));
+ boolean hasComplaint = categories.stream().anyMatch(c -> c.contains("complaint"));
+ boolean hasPraise = categories.stream().anyMatch(c -> c.contains("praise"));
+
+ assertTrue("Should have at least one recognized category",
+ hasQuestion || hasComplaint || hasPraise);
+
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testInputOutputMapping() {
+ List inputs = Arrays.asList("apple", "banana", "cherry");
+
+ PCollection inputCollection = pipeline
+ .apply("CreateInputs", Create.of(inputs))
+ .apply("MapToInputs", MapElements
+ .into(TypeDescriptor.of(OpenAIModelInput.class))
+ .via(OpenAIModelInput::create));
+
+ PCollection>> results = inputCollection
+ .apply("MappingInference",
+ RemoteInference.invoke()
+ .handler(OpenAIModelHandler.class)
+ .withParameters(OpenAIModelParameters.builder()
+ .apiKey(apiKey)
+ .modelName(DEFAULT_MODEL)
+ .instructionPrompt(
+ "Return the input word in uppercase")
+ .build()));
+
+ // Verify input-output pairing is preserved
+ PAssert.that(results).satisfies(batches -> {
+ for (Iterable> batch : batches) {
+ for (PredictionResult result : batch) {
+ String input = result.getInput().getModelInput();
+ String output = result.getOutput().getModelResponse().toLowerCase();
+
+ // Verify the output relates to the input
+ assertTrue("Output should relate to input '" + input + "', got: " + output,
+ output.contains(input.toLowerCase()));
+ }
+ }
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testWithDifferentModel() {
+ // Test with a different model
+ String input = "Explain quantum computing in one sentence.";
+
+ PCollection inputs = pipeline
+ .apply("CreateInput", Create.of(input))
+ .apply("MapToInput", MapElements
+ .into(TypeDescriptor.of(OpenAIModelInput.class))
+ .via(OpenAIModelInput::create));
+
+ PCollection>> results = inputs
+ .apply("DifferentModelInference",
+ RemoteInference.invoke()
+ .handler(OpenAIModelHandler.class)
+ .withParameters(OpenAIModelParameters.builder()
+ .apiKey(apiKey)
+ .modelName("gpt-5")
+ .instructionPrompt("Respond concisely")
+ .build()));
+
+ PAssert.that(results).satisfies(batches -> {
+ for (Iterable> batch : batches) {
+ for (PredictionResult result : batch) {
+ assertNotNull("Output should not be null",
+ result.getOutput().getModelResponse());
+ assertFalse("Output should not be empty",
+ result.getOutput().getModelResponse().trim().isEmpty());
+ }
+ }
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testWithInvalidApiKey() {
+ String input = "Test input";
+
+ PCollection inputs = pipeline
+ .apply("CreateInput", Create.of(input))
+ .apply("MapToInput", MapElements
+ .into(TypeDescriptor.of(OpenAIModelInput.class))
+ .via(OpenAIModelInput::create));
+
+ inputs.apply("InvalidKeyInference",
+ RemoteInference.invoke()
+ .handler(OpenAIModelHandler.class)
+ .withParameters(OpenAIModelParameters.builder()
+ .apiKey("invalid-api-key-12345")
+ .modelName(DEFAULT_MODEL)
+ .instructionPrompt("Test")
+ .build()));
+
+ // Expect pipeline to fail with authentication error
+ try {
+ pipeline.run().waitUntilFinish();
+ fail("Expected pipeline to fail with invalid API key");
+ } catch (Exception e) {
+ // Expected - verify it's an authentication error
+ String message = e.getMessage().toLowerCase();
+ assertTrue("Exception should mention authentication or API key issue, got: " + message,
+ message.contains("auth") ||
+ message.contains("api") ||
+ message.contains("key") ||
+ message.contains("401"));
+ }
+ }
+
+ /**
+ * Test with custom instruction formats
+ */
+ @Test
+ public void testWithJsonOutputFormat() {
+ String input = "Paris is the capital of France";
+
+ PCollection inputs = pipeline
+ .apply("CreateInput", Create.of(input))
+ .apply("MapToInput", MapElements
+ .into(TypeDescriptor.of(OpenAIModelInput.class))
+ .via(OpenAIModelInput::create));
+
+ PCollection>> results = inputs
+ .apply("JsonFormatInference",
+ RemoteInference.invoke()
+ .handler(OpenAIModelHandler.class)
+ .withParameters(OpenAIModelParameters.builder()
+ .apiKey(apiKey)
+ .modelName(DEFAULT_MODEL)
+ .instructionPrompt(
+ "Extract the city and country. Return as: City: [city], Country: [country]")
+ .build()));
+
+ PAssert.that(results).satisfies(batches -> {
+ for (Iterable> batch : batches) {
+ for (PredictionResult result : batch) {
+ String output = result.getOutput().getModelResponse();
+ LOG.info("Structured output: " + output);
+
+ // Verify output contains expected information
+ assertTrue("Output should mention Paris: " + output,
+ output.toLowerCase().contains("paris"));
+ assertTrue("Output should mention France: " + output,
+ output.toLowerCase().contains("france"));
+ }
+ }
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+}
diff --git a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerTest.java b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerTest.java
new file mode 100644
index 000000000000..3bff5600aa45
--- /dev/null
+++ b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerTest.java
@@ -0,0 +1,450 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference.openai;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import org.apache.beam.sdk.ml.remoteinference.base.*;
+import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.StructuredInputOutput;
+import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.Response;
+
+
+
+@RunWith(JUnit4.class)
+public class OpenAIModelHandlerTest {
+ private OpenAIModelParameters testParameters;
+
+ @Before
+ public void setUp() {
+ testParameters = OpenAIModelParameters.builder()
+ .apiKey("test-api-key")
+ .modelName("gpt-4")
+ .instructionPrompt("Test instruction")
+ .build();
+ }
+
+ /**
+ * Fake OpenAiModelHandler for testing.
+ */
+ static class FakeOpenAiModelHandler extends OpenAIModelHandler {
+
+ private boolean clientCreated = false;
+ private OpenAIModelParameters storedParameters;
+ private List responsesToReturn;
+ private RuntimeException exceptionToThrow;
+ private boolean shouldReturnNull = false;
+
+ public void setResponsesToReturn(List responses) {
+ this.responsesToReturn = responses;
+ }
+
+ public void setExceptionToThrow(RuntimeException exception) {
+ this.exceptionToThrow = exception;
+ }
+
+ public void setShouldReturnNull(boolean shouldReturnNull) {
+ this.shouldReturnNull = shouldReturnNull;
+ }
+
+ public boolean isClientCreated() {
+ return clientCreated;
+ }
+
+ public OpenAIModelParameters getStoredParameters() {
+ return storedParameters;
+ }
+
+ @Override
+ public void createClient(OpenAIModelParameters parameters) {
+ this.storedParameters = parameters;
+ this.clientCreated = true;
+
+ if (exceptionToThrow != null) {
+ throw exceptionToThrow;
+ }
+ }
+
+ @Override
+ public Iterable> request(
+ List input) {
+
+ if (!clientCreated) {
+ throw new IllegalStateException("Client not initialized");
+ }
+
+ if (exceptionToThrow != null) {
+ throw exceptionToThrow;
+ }
+
+ if (shouldReturnNull || responsesToReturn == null) {
+ throw new RuntimeException("Model returned no structured responses");
+ }
+
+ StructuredInputOutput structuredOutput = responsesToReturn.get(0);
+
+ if (structuredOutput == null || structuredOutput.responses == null) {
+ throw new RuntimeException("Model returned no structured responses");
+ }
+
+ return structuredOutput.responses.stream()
+ .map(response -> PredictionResult.create(
+ OpenAIModelInput.create(response.input),
+ OpenAIModelResponse.create(response.output)))
+ .collect(Collectors.toList());
+ }
+ }
+
+ @Test
+ public void testCreateClient() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ OpenAIModelParameters params = OpenAIModelParameters.builder()
+ .apiKey("test-key")
+ .modelName("gpt-4")
+ .instructionPrompt("test prompt")
+ .build();
+
+ handler.createClient(params);
+
+ assertTrue("Client should be created", handler.isClientCreated());
+ assertNotNull("Parameters should be stored", handler.getStoredParameters());
+ assertEquals("test-key", handler.getStoredParameters().getApiKey());
+ assertEquals("gpt-4", handler.getStoredParameters().getModelName());
+ }
+
+ @Test
+ public void testRequestWithSingleInput() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ List inputs = Collections.singletonList(
+ OpenAIModelInput.create("test input"));
+
+ StructuredInputOutput structuredOutput = new StructuredInputOutput();
+ Response response = new Response();
+ response.input = "test input";
+ response.output = "test output";
+ structuredOutput.responses = Collections.singletonList(response);
+
+ handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
+ handler.createClient(testParameters);
+
+ Iterable> results = handler.request(inputs);
+
+ assertNotNull("Results should not be null", results);
+
+ List> resultList = iterableToList(results);
+
+ assertEquals("Should have 1 result", 1, resultList.size());
+
+ PredictionResult result = resultList.get(0);
+ assertEquals("test input", result.getInput().getModelInput());
+ assertEquals("test output", result.getOutput().getModelResponse());
+ }
+
+ @Test
+ public void testRequestWithMultipleInputs() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ List inputs = Arrays.asList(
+ OpenAIModelInput.create("input1"),
+ OpenAIModelInput.create("input2"),
+ OpenAIModelInput.create("input3"));
+
+ StructuredInputOutput structuredOutput = new StructuredInputOutput();
+
+ Response response1 = new Response();
+ response1.input = "input1";
+ response1.output = "output1";
+
+ Response response2 = new Response();
+ response2.input = "input2";
+ response2.output = "output2";
+
+ Response response3 = new Response();
+ response3.input = "input3";
+ response3.output = "output3";
+
+ structuredOutput.responses = Arrays.asList(response1, response2, response3);
+
+ handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
+ handler.createClient(testParameters);
+
+ Iterable> results = handler.request(inputs);
+
+ List> resultList = iterableToList(results);
+
+ assertEquals("Should have 3 results", 3, resultList.size());
+
+ for (int i = 0; i < 3; i++) {
+ PredictionResult result = resultList.get(i);
+ assertEquals("input" + (i + 1), result.getInput().getModelInput());
+ assertEquals("output" + (i + 1), result.getOutput().getModelResponse());
+ }
+ }
+
+ @Test
+ public void testRequestWithEmptyInput() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ List inputs = Collections.emptyList();
+
+ StructuredInputOutput structuredOutput = new StructuredInputOutput();
+ structuredOutput.responses = Collections.emptyList();
+
+ handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
+ handler.createClient(testParameters);
+
+ Iterable> results = handler.request(inputs);
+
+ List> resultList = iterableToList(results);
+ assertEquals("Should have 0 results", 0, resultList.size());
+ }
+
+ @Test
+ public void testRequestWithNullStructuredOutput() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ List inputs = Collections.singletonList(
+ OpenAIModelInput.create("test input"));
+
+ handler.setShouldReturnNull(true);
+ handler.createClient(testParameters);
+
+ try {
+ handler.request(inputs);
+ fail("Expected RuntimeException when structured output is null");
+ } catch (RuntimeException e) {
+ assertTrue("Exception message should mention no structured responses",
+ e.getMessage().contains("Model returned no structured responses"));
+ }
+ }
+
+ @Test
+ public void testRequestWithNullResponsesList() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ List inputs = Collections.singletonList(
+ OpenAIModelInput.create("test input"));
+
+ StructuredInputOutput structuredOutput = new StructuredInputOutput();
+ structuredOutput.responses = null;
+
+ handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
+ handler.createClient(testParameters);
+
+ try {
+ handler.request(inputs);
+ fail("Expected RuntimeException when responses list is null");
+ } catch (RuntimeException e) {
+ assertTrue("Exception message should mention no structured responses",
+ e.getMessage().contains("Model returned no structured responses"));
+ }
+ }
+
+ @Test
+ public void testCreateClientFailure() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+ handler.setExceptionToThrow(new RuntimeException("Setup failed"));
+
+ try {
+ handler.createClient(testParameters);
+ fail("Expected RuntimeException during client creation");
+ } catch (RuntimeException e) {
+ assertEquals("Setup failed", e.getMessage());
+ }
+ }
+
+ @Test
+ public void testRequestApiFailure() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ List inputs = Collections.singletonList(
+ OpenAIModelInput.create("test input"));
+
+ handler.createClient(testParameters);
+ handler.setExceptionToThrow(new RuntimeException("API Error"));
+
+ try {
+ handler.request(inputs);
+ fail("Expected RuntimeException when API fails");
+ } catch (RuntimeException e) {
+ assertEquals("API Error", e.getMessage());
+ }
+ }
+
+ @Test
+ public void testRequestWithoutClientInitialization() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ List inputs = Collections.singletonList(
+ OpenAIModelInput.create("test input"));
+
+ StructuredInputOutput structuredOutput = new StructuredInputOutput();
+ Response response = new Response();
+ response.input = "test input";
+ response.output = "test output";
+ structuredOutput.responses = Collections.singletonList(response);
+
+ handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
+
+ // Don't call createClient
+ try {
+ handler.request(inputs);
+ fail("Expected IllegalStateException when client not initialized");
+ } catch (IllegalStateException e) {
+ assertTrue("Exception should mention client not initialized",
+ e.getMessage().contains("Client not initialized"));
+ }
+ }
+
+ @Test
+ public void testInputOutputMapping() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ List inputs = Arrays.asList(
+ OpenAIModelInput.create("alpha"),
+ OpenAIModelInput.create("beta"));
+
+ StructuredInputOutput structuredOutput = new StructuredInputOutput();
+
+ Response response1 = new Response();
+ response1.input = "alpha";
+ response1.output = "ALPHA";
+
+ Response response2 = new Response();
+ response2.input = "beta";
+ response2.output = "BETA";
+
+ structuredOutput.responses = Arrays.asList(response1, response2);
+
+ handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
+ handler.createClient(testParameters);
+
+ Iterable> results = handler.request(inputs);
+
+ List> resultList = iterableToList(results);
+
+ assertEquals(2, resultList.size());
+ assertEquals("alpha", resultList.get(0).getInput().getModelInput());
+ assertEquals("ALPHA", resultList.get(0).getOutput().getModelResponse());
+
+ assertEquals("beta", resultList.get(1).getInput().getModelInput());
+ assertEquals("BETA", resultList.get(1).getOutput().getModelResponse());
+ }
+
+ @Test
+ public void testParametersBuilder() {
+ OpenAIModelParameters params = OpenAIModelParameters.builder()
+ .apiKey("my-api-key")
+ .modelName("gpt-4-turbo")
+ .instructionPrompt("Custom instruction")
+ .build();
+
+ assertEquals("my-api-key", params.getApiKey());
+ assertEquals("gpt-4-turbo", params.getModelName());
+ assertEquals("Custom instruction", params.getInstructionPrompt());
+ }
+
+ @Test
+ public void testOpenAIModelInputCreate() {
+ OpenAIModelInput input = OpenAIModelInput.create("test value");
+
+ assertNotNull("Input should not be null", input);
+ assertEquals("test value", input.getModelInput());
+ }
+
+ @Test
+ public void testOpenAIModelResponseCreate() {
+ OpenAIModelResponse response = OpenAIModelResponse.create("test output");
+
+ assertNotNull("Response should not be null", response);
+ assertEquals("test output", response.getModelResponse());
+ }
+
+ @Test
+ public void testStructuredInputOutputStructure() {
+ Response response = new Response();
+ response.input = "test-input";
+ response.output = "test-output";
+
+ assertEquals("test-input", response.input);
+ assertEquals("test-output", response.output);
+
+ StructuredInputOutput structured = new StructuredInputOutput();
+ structured.responses = Collections.singletonList(response);
+
+ assertNotNull("Responses should not be null", structured.responses);
+ assertEquals("Should have 1 response", 1, structured.responses.size());
+ assertEquals("test-input", structured.responses.get(0).input);
+ }
+
+ @Test
+ public void testMultipleRequestsWithSameHandler() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+ handler.createClient(testParameters);
+
+ // First request
+ StructuredInputOutput output1 = new StructuredInputOutput();
+ Response response1 = new Response();
+ response1.input = "first";
+ response1.output = "FIRST";
+ output1.responses = Collections.singletonList(response1);
+ handler.setResponsesToReturn(Collections.singletonList(output1));
+
+ List inputs1 = Collections.singletonList(
+ OpenAIModelInput.create("first"));
+ Iterable> results1 = handler.request(inputs1);
+
+ List> resultList1 = iterableToList(results1);
+ assertEquals("FIRST", resultList1.get(0).getOutput().getModelResponse());
+
+ // Second request with different data
+ StructuredInputOutput output2 = new StructuredInputOutput();
+ Response response2 = new Response();
+ response2.input = "second";
+ response2.output = "SECOND";
+ output2.responses = Collections.singletonList(response2);
+ handler.setResponsesToReturn(Collections.singletonList(output2));
+
+ List inputs2 = Collections.singletonList(
+ OpenAIModelInput.create("second"));
+ Iterable> results2 = handler.request(inputs2);
+
+ List> resultList2 = iterableToList(results2);
+ assertEquals("SECOND", resultList2.get(0).getOutput().getModelResponse());
+ }
+
+ // Helper method to convert Iterable to List
+ private List iterableToList(Iterable iterable) {
+ List list = new java.util.ArrayList<>();
+ iterable.forEach(list::add);
+ return list;
+ }
+}
diff --git a/settings.gradle.kts b/settings.gradle.kts
index 72c5194ec93d..fdfc5da6854c 100644
--- a/settings.gradle.kts
+++ b/settings.gradle.kts
@@ -383,3 +383,5 @@ include("sdks:java:extensions:sql:iceberg")
findProject(":sdks:java:extensions:sql:iceberg")?.name = "iceberg"
include("examples:java:iceberg")
findProject(":examples:java:iceberg")?.name = "iceberg"
+
+include("sdks:java:ml:remoteinference")
\ No newline at end of file