diff --git a/sdks/java/ml/build.gradle b/sdks/java/ml/build.gradle new file mode 100644 index 000000000000..b18480559e14 --- /dev/null +++ b/sdks/java/ml/build.gradle @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { id 'org.apache.beam.module' } +applyJavaNature( + automaticModuleName: 'org.apache.beam.sdk.ml', +) +provideIntegrationTestingDependencies() +enableJavaPerformanceTesting() + +description = "Apache Beam :: SDKs :: Java :: ML" +ext.summary = "Java ML module" + +dependencies { +} diff --git a/sdks/java/ml/remoteinference/build.gradle.kts b/sdks/java/ml/remoteinference/build.gradle.kts new file mode 100644 index 000000000000..98e3ce620bb9 --- /dev/null +++ b/sdks/java/ml/remoteinference/build.gradle.kts @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id("org.apache.beam.module") + id("java-library") +} + +description = "Apache Beam :: SDKs :: Java :: ML :: RemoteInference" + +dependencies { + // Core Beam SDK + implementation(project(":sdks:java:core")) + + implementation("com.openai:openai-java:4.3.0") + compileOnly("com.google.auto.value:auto-value-annotations:1.11.0") + compileOnly("org.checkerframework:checker-qual:3.42.0") + annotationProcessor("com.google.auto.value:auto-value:1.11.0") + implementation("com.fasterxml.jackson.core:jackson-core:2.20.0") + implementation("org.apache.beam:beam-vendor-guava-32_1_2-jre:0.1") + implementation("org.slf4j:slf4j-api:2.0.9") + implementation("org.slf4j:slf4j-simple:2.0.9") + + // testing + testImplementation(project(":runners:direct-java")) + testImplementation("org.hamcrest:hamcrest:2.2") + testImplementation("org.mockito:mockito-core:5.8.0") + testImplementation("org.mockito:mockito-junit-jupiter:5.8.0") + testImplementation("junit:junit:4.13.2") + testImplementation(project(":sdks:java:testing:test-utils")) +} + diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java new file mode 100644 index 000000000000..6a8b9b656bdf --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference; + +import org.apache.beam.sdk.ml.remoteinference.base.*; +import org.apache.beam.sdk.transforms.*; +import org.checkerframework.checker.nullness.qual.Nullable; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import org.apache.beam.sdk.values.PCollection; +import com.google.auto.value.AutoValue; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * A {@link PTransform} for making remote inference calls to external machine learning services. + * + *

{@code RemoteInference} provides a framework for integrating remote ML model + * inference into Apache Beam pipelines and handles the communication between pipelines + * and external inference APIs. + * + *

Example: OpenAI Model Inference

+ * + *
{@code
+ * // Create model parameters
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("your-api-key")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Analyse sentiment as positive or negative")
+ *     .build();
+ *
+ * // Apply remote inference transform
+ * PCollection inputs = pipeline.apply(Create.of(
+ *     OpenAIModelInput.create("An excellent B2B SaaS solution that streamlines business processes efficiently."),
+ *     OpenAIModelInput.create("Really impressed with the innovative features!")
+ * ));
+ *
+ * PCollection>> results =
+ *     inputs.apply(
+ *         RemoteInference.invoke()
+ *             .handler(OpenAIModelHandler.class)
+ *             .withParameters(params)
+ *     );
+ * }
+ * + */ +@SuppressWarnings({ "rawtypes", "unchecked" }) +public class RemoteInference { + + /** Invoke the model handler with model parameters */ + public static Invoke invoke() { + return new AutoValue_RemoteInference_Invoke.Builder().setParameters(null) + .build(); + } + + private RemoteInference() { + } + + @AutoValue + public abstract static class Invoke + extends PTransform, PCollection>>> { + + abstract @Nullable Class handler(); + + abstract @Nullable BaseModelParameters parameters(); + + + abstract Builder builder(); + + @AutoValue.Builder + abstract static class Builder { + + abstract Builder setHandler(Class modelHandler); + + abstract Builder setParameters(BaseModelParameters modelParameters); + + + abstract Invoke build(); + } + + /** + * Model handler class for inference. + */ + public Invoke handler(Class modelHandler) { + return builder().setHandler(modelHandler).build(); + } + + /** + * Configures the parameters for model initialization. + */ + public Invoke withParameters(BaseModelParameters modelParameters) { + return builder().setParameters(modelParameters).build(); + } + + + @Override + public PCollection>> expand(PCollection input) { + checkArgument(handler() != null, "handler() is required"); + checkArgument(parameters() != null, "withParameters() is required"); + return input + .apply("WrapInputInList", MapElements.via(new SimpleFunction>() { + @Override + public List apply(InputT element) { + return Collections.singletonList(element); + } + })) + // Pass the list to the inference function + .apply("RemoteInference", ParDo.of(new RemoteInferenceFn(this))); + } + + /** + * A {@link DoFn} that performs remote inference operation. + * + *

This function manages the lifecycle of the model handler: + *

    + *
  • Instantiates the handler during {@link Setup}
  • + *
  • Initializes the remote client via {@link BaseModelHandler#createClient}
  • + *
  • Processes elements by calling {@link BaseModelHandler#request}
  • + *
+ */ + static class RemoteInferenceFn + extends DoFn, Iterable>> { + + private final Class handlerClass; + private final BaseModelParameters parameters; + private transient BaseModelHandler handler; + + RemoteInferenceFn(Invoke spec) { + this.handlerClass = spec.handler(); + this.parameters = spec.parameters(); + } + + /** Instantiate the model handler and client*/ + @Setup + public void setupHandler() { + try { + this.handler = handlerClass.getDeclaredConstructor().newInstance(); + this.handler.createClient(parameters); + } catch (Exception e) { + throw new RuntimeException("Failed to instantiate handler: " + + handlerClass.getName(), e); + } + } + /** Perform Inference */ + @ProcessElement + public void processElement(ProcessContext c) { + Iterable> response = this.handler.request(c.element()); + c.output(response); + } + } + + } +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java new file mode 100644 index 000000000000..0a84a67fdb4c --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.base; + +import java.io.Serializable; + +/** + * Base class for defining input types used with remote inference transforms. + *Implementations holds the data needed for inference (text, images, etc.) + */ +public abstract class BaseInput implements Serializable { + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java new file mode 100644 index 000000000000..1128a287d927 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.base; + +import java.util.List; + +/** + * Interface for model-specific handlers that perform remote inference operations. + * + *

Implementations of this interface encapsulate all logic for communicating with a + * specific remote inference service. Each handler is responsible for: + *

    + *
  • Initializing and managing client connections
  • + *
  • Converting Beam inputs to service-specific request formats
  • + *
  • Making inference API calls
  • + *
  • Converting service responses to Beam output types
  • + *
  • Handling errors and retries if applicable
  • + *
+ * + *

Lifecycle

+ * + *

Handler instances follow this lifecycle: + *

    + *
  1. Instantiation via no-argument constructor
  2. + *
  3. {@link #createClient} called with parameters during setup
  4. + *
  5. {@link #request} called for each batch of inputs
  6. + *
+ * + * + *

Handlers typically contain non-serializable client objects. + * Mark client fields as {@code transient} and initialize them in {@link #createClient} + * + *

Batching Considerations

+ * + *

The {@link #request} method receives a list of inputs. Implementations should: + *

    + *
  • Batch inputs efficiently if the service supports batch inference
  • + *
  • Return results in the same order as inputs
  • + *
  • Maintain input-output correspondence in {@link PredictionResult}
  • + *
+ * + */ +public interface BaseModelHandler { + /** + * Initializes the remote model client with the provided parameters. + */ + public void createClient(ParamT parameters); + + /** + * Performs inference on a batch of inputs and returns the results. + */ + public Iterable> request(List input); + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java new file mode 100644 index 000000000000..46ee72c73001 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.base; + +import java.io.Serializable; + +/** + * Base interface for defining model-specific parameters used to configure remote inference clients. + * + *

Implementations of this interface encapsulate all configuration needed to initialize + * and communicate with a remote model inference service. This typically includes: + *

    + *
  • Authentication credentials (API keys, tokens)
  • + *
  • Model identifiers or names
  • + *
  • Endpoint URLs or connection settings
  • + *
  • Inference configuration (temperature, max tokens, timeout values, etc.)
  • + *
+ * + *

Parameters must be serializable. Consider using + * the builder pattern for complex parameter objects. + * + */ +public interface BaseModelParameters extends Serializable { + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java new file mode 100644 index 000000000000..b3c050c45c11 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.base; + +import java.io.Serializable; + +/** + * Base class for defining response types returned from remote inference operations. + + *

Implementations: + *

    + *
  • Contain the inference results (predictions, classifications, generated text, etc.)
  • + *
  • Includes any relevant metadata
  • + *
+ * + */ +public abstract class BaseResponse implements Serializable { + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java new file mode 100644 index 000000000000..edc30fd11246 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.base; + +import java.io.Serializable; + +/** + * Pairs an input with its corresponding inference output. + * + *

This class maintains the association between input data and its model's results + * for Downstream processing + */ +public class PredictionResult implements Serializable { + + private final InputT input; + private final OutputT output; + + private PredictionResult(InputT input, OutputT output) { + this.input = input; + this.output = output; + + } + + /* Returns input to handler */ + public InputT getInput() { + return input; + } + + /* Returns model handler's response*/ + public OutputT getOutput() { + return output; + } + + /* Creates a PredictionResult instance of provided input, output and types */ + public static PredictionResult create(InputT input, OutputT output) { + return new PredictionResult<>(input, output); + } +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java new file mode 100644 index 000000000000..87616ee693d1 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonSchemaLocalValidation; +import com.openai.models.responses.ResponseCreateParams; +import com.openai.models.responses.StructuredResponseCreateParams; +import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler; +import org.apache.beam.sdk.ml.remoteinference.base.PredictionResult; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * Model handler for OpenAI API inference requests. + * + *

This handler manages communication with OpenAI's API, including client initialization, + * request formatting, and response parsing. It uses OpenAI's structured output feature to + * ensure reliable input-output pairing. + * + *

Usage

+ *
{@code
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("sk-...")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Classify the following text into one of the categories: {CATEGORIES}")
+ *     .build();
+ *
+ * PCollection inputs = ...;
+ * PCollection>> results =
+ *     inputs.apply(
+ *         RemoteInference.invoke()
+ *             .handler(OpenAIModelHandler.class)
+ *             .withParameters(params)
+ *     );
+ * }
+ * + */ +public class OpenAIModelHandler + implements BaseModelHandler { + + private transient OpenAIClient client; + private transient StructuredResponseCreateParams clientParams; + private OpenAIModelParameters modelParameters; + + /** + * Initializes the OpenAI client with the provided parameters. + * + *

This method is called once during setup. It creates an authenticated + * OpenAI client using the API key from the parameters. + * + * @param parameters the configuration parameters including API key and model name + */ + @Override + public void createClient(OpenAIModelParameters parameters) { + this.modelParameters = parameters; + this.client = OpenAIOkHttpClient.builder() + .apiKey(this.modelParameters.getApiKey()) + .build(); + } + + /** + * Performs inference on a batch of inputs using the OpenAI Client. + * + *

This method serializes the input batch to JSON string, sends it to OpenAI with structured + * output requirements, and parses the response into {@link PredictionResult} objects + * that pair each input with its corresponding output. + * + * @param input the list of inputs to process + * @return an iterable of model results and input pairs + */ + @Override + public Iterable> request(List input) { + + try { + // Convert input list to JSON string + String inputBatch = new ObjectMapper() + .writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList()); + + // Build structured response parameters + this.clientParams = ResponseCreateParams.builder() + .model(modelParameters.getModelName()) + .input(inputBatch) + .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO) + .instructions(modelParameters.getInstructionPrompt()) + .build(); + + // Get structured output from the model + StructuredInputOutput structuredOutput = client.responses() + .create(clientParams) + .output() + .stream() + .flatMap(item -> item.message().stream()) + .flatMap(message -> message.content().stream()) + .flatMap(content -> content.outputText().stream()) + .findFirst() + .orElse(null); + + if (structuredOutput == null || structuredOutput.responses == null) { + throw new RuntimeException("Model returned no structured responses"); + } + + // Map responses to PredictionResults + List> results = structuredOutput.responses.stream() + .map(response -> PredictionResult.create( + OpenAIModelInput.create(response.input), + OpenAIModelResponse.create(response.output))) + .collect(Collectors.toList()); + + return results; + + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to serialize input batch", e); + } + } + + /** + * Schema class for structured output response. + * + *

Represents a single input-output pair returned by the OpenAI API. + */ + public static class Response { + @JsonProperty(required = true) + @JsonPropertyDescription("The input string") + public String input; + + @JsonProperty(required = true) + @JsonPropertyDescription("The output string") + public String output; + } + + /** + * Schema class for structured output containing multiple responses. + * + *

This class defines the expected JSON structure for OpenAI's structured output, + * ensuring reliable parsing of batched inference results. + */ + public static class StructuredInputOutput { + @JsonProperty(required = true) + @JsonPropertyDescription("Array of input-output pairs") + public List responses; + } + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java new file mode 100644 index 000000000000..1ef59c89da66 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import org.apache.beam.sdk.ml.remoteinference.base.BaseInput; + +/** + * Input for OpenAI model inference requests. + * + *

This class encapsulates text input to be sent to OpenAI models. + * + *

Example Usage

+ *
{@code
+ * OpenAIModelInput input = OpenAIModelInput.create("Translate to French: Hello");
+ * String text = input.getModelInput(); // "Translate to French: Hello"
+ * }
+ * + * @see OpenAIModelHandler + * @see OpenAIModelResponse + */ +public class OpenAIModelInput extends BaseInput { + + private final String input; + + private OpenAIModelInput(String input) { + + this.input = input; + } + + /** + * Returns the text input for the model. + * + * @return the input text string + */ + public String getModelInput() { + return input; + } + + /** + * Creates a new input instance with the specified text. + * + * @param input the text to send to the model + * @return a new {@link OpenAIModelInput} instance + */ + public static OpenAIModelInput create(String input) { + return new OpenAIModelInput(input); + } + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java new file mode 100644 index 000000000000..3aac4112cba3 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import org.apache.beam.sdk.ml.remoteinference.base.BaseModelParameters; + +/** + * Configuration parameters required for OpenAI model inference. + * + *

This class encapsulates all configuration needed to initialize and communicate with + * OpenAI's API, including authentication credentials, model selection, and inference instructions. + * + *

Example Usage

+ *
{@code
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("sk-...")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Translate the following text to French:")
+ *     .build();
+ * }
+ * + * @see OpenAIModelHandler + */ +public class OpenAIModelParameters implements BaseModelParameters { + + private final String apiKey; + private final String modelName; + private final String instructionPrompt; + + private OpenAIModelParameters(Builder builder) { + this.apiKey = builder.apiKey; + this.modelName = builder.modelName; + this.instructionPrompt = builder.instructionPrompt; + } + + public String getApiKey() { + return apiKey; + } + + public String getModelName() { + return modelName; + } + + public String getInstructionPrompt() { + return instructionPrompt; + } + + public static Builder builder() { + return new Builder(); + } + + + public static class Builder { + private String apiKey; + private String modelName; + private String instructionPrompt; + + private Builder() { + } + + /** + * Sets the OpenAI API key for authentication. + * + * @param apiKey the API key (required) + */ + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + /** + * Sets the name of the OpenAI model to use. + * + * @param modelName the model name, e.g., "gpt-4" (required) + */ + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + /** + * Sets the instruction prompt for the model. + * This prompt provides context or instructions to the model about how to process + * the input text. + * + * @param prompt the instruction text (required) + */ + public Builder instructionPrompt(String prompt) { + this.instructionPrompt = prompt; + return this; + } + + /** + * Builds the {@link OpenAIModelParameters} instance. + */ + public OpenAIModelParameters build() { + return new OpenAIModelParameters(this); + } + } +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java new file mode 100644 index 000000000000..abfd9d7cb5c3 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import org.apache.beam.sdk.ml.remoteinference.base.BaseResponse; + +/** + * Response from OpenAI model inference results. + *

This class encapsulates the text output returned from OpenAI models.. + * + *

Example Usage

+ *
{@code
+ * OpenAIModelResponse response = OpenAIModelResponse.create("Bonjour");
+ * String output = response.getModelResponse(); // "Bonjour"
+ * }
+ * + * @see OpenAIModelHandler + * @see OpenAIModelInput + */ +public class OpenAIModelResponse extends BaseResponse { + + private final String output; + + private OpenAIModelResponse(String output) { + this.output = output; + } + + /** + * Returns the text output from the model. + * + * @return the output text string + */ + public String getModelResponse() { + return output; + } + + /** + * Creates a new response instance with the specified output text. + * + * @param output the text returned by the model + * @return a new {@link OpenAIModelResponse} instance + */ + public static OpenAIModelResponse create(String output) { + return new OpenAIModelResponse(output); + } +} diff --git a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/RemoteInferenceTest.java b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/RemoteInferenceTest.java new file mode 100644 index 000000000000..3f351b9f88a3 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/RemoteInferenceTest.java @@ -0,0 +1,586 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.ml.remoteinference.base.*; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + + + +@RunWith(JUnit4.class) +public class RemoteInferenceTest { + + @Rule + public final transient TestPipeline pipeline = TestPipeline.create(); + + // Test input class + public static class TestInput extends BaseInput { + private final String value; + + private TestInput(String value) { + this.value = value; + } + + public static TestInput create(String value) { + return new TestInput(value); + } + + public String getModelInput() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof TestInput)) + return false; + TestInput testInput = (TestInput) o; + return value.equals(testInput.value); + } + + @Override + public int hashCode() { + return value.hashCode(); + } + + @Override + public String toString() { + return "TestInput{value='" + value + "'}"; + } + } + + // Test output class + public static class TestOutput extends BaseResponse { + private final String result; + + private TestOutput(String result) { + this.result = result; + } + + public static TestOutput create(String result) { + return new TestOutput(result); + } + + public String getModelResponse() { + return result; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof TestOutput)) + return false; + TestOutput that = (TestOutput) o; + return result.equals(that.result); + } + + @Override + public int hashCode() { + return result.hashCode(); + } + + @Override + public String toString() { + return "TestOutput{result='" + result + "'}"; + } + } + + // Test parameters class + public static class TestParameters implements BaseModelParameters { + private final String config; + + private TestParameters(Builder builder) { + this.config = builder.config; + } + + public String getConfig() { + return config; + } + + @Override + public String toString() { + return "TestParameters{config='" + config + "'}"; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof TestParameters)) + return false; + TestParameters that = (TestParameters) o; + return config.equals(that.config); + } + + @Override + public int hashCode() { + return config.hashCode(); + } + + // Builder + public static class Builder { + private String config; + + public Builder setConfig(String config) { + this.config = config; + return this; + } + + public TestParameters build() { + return new TestParameters(this); + } + } + + public static Builder builder() { + return new Builder(); + } + } + + // Mock handler for successful inference + public static class MockSuccessHandler + implements BaseModelHandler { + + private TestParameters parameters; + private boolean clientCreated = false; + + @Override + public void createClient(TestParameters parameters) { + this.parameters = parameters; + this.clientCreated = true; + } + + @Override + public Iterable> request(List input) { + if (!clientCreated) { + throw new IllegalStateException("Client not initialized"); + } + return input.stream() + .map(i -> PredictionResult.create( + i, + new TestOutput("processed-" + i.getModelInput()))) + .collect(Collectors.toList()); + } + } + + // Mock handler that returns empty results + public static class MockEmptyResultHandler + implements BaseModelHandler { + + @Override + public void createClient(TestParameters parameters) { + // Setup succeeds + } + + @Override + public Iterable> request(List input) { + return Collections.emptyList(); + } + } + + // Mock handler that throws exception during setup + public static class MockFailingSetupHandler + implements BaseModelHandler { + + @Override + public void createClient(TestParameters parameters) { + throw new RuntimeException("Setup failed intentionally"); + } + + @Override + public Iterable> request(List input) { + return Collections.emptyList(); + } + } + + // Mock handler that throws exception during request + public static class MockFailingRequestHandler + implements BaseModelHandler { + + @Override + public void createClient(TestParameters parameters) { + // Setup succeeds + } + + @Override + public Iterable> request(List input) { + throw new RuntimeException("Request failed intentionally"); + } + } + + // Mock handler without default constructor (to test error handling) + public static class MockNoDefaultConstructorHandler + implements BaseModelHandler { + + private final String required; + + public MockNoDefaultConstructorHandler(String required) { + this.required = required; + } + + @Override + public void createClient(TestParameters parameters) { + } + + @Override + public Iterable> request(List input) { + return Collections.emptyList(); + } + } + + @Test + public void testInvokeWithSingleElement() { + TestInput input = TestInput.create("test-value"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline.apply(Create.of(input)); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + // Verify the output contains expected predictions + PAssert.thatSingleton(results).satisfies(batch -> { + List> resultList = StreamSupport.stream(batch.spliterator(), false) + .collect(Collectors.toList()); + + assertEquals("Expected exactly 1 result", 1, resultList.size()); + + PredictionResult result = resultList.get(0); + assertEquals("test-value", result.getInput().getModelInput()); + assertEquals("processed-test-value", result.getOutput().getModelResponse()); + + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testInvokeWithMultipleElements() { + List inputs = Arrays.asList( + new TestInput("input1"), + new TestInput("input2"), + new TestInput("input3")); + + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class))); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + // Count total results across all batches + PAssert.that(results).satisfies(batches -> { + int totalCount = 0; + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + totalCount++; + assertTrue("Output should start with 'processed-'", + result.getOutput().getModelResponse().startsWith("processed-")); + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + } + } + assertEquals("Expected 3 total results", 3, totalCount); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testInvokeWithEmptyCollection() { + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateEmptyInput", Create.empty(SerializableCoder.of(TestInput.class))); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + // assertion for empty PCollection + PAssert.that(results).empty(); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testHandlerReturnsEmptyResults() { + TestInput input = new TestInput("test-value"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockEmptyResultHandler.class) + .withParameters(params)); + + // Verify we still get a result, but it's empty + PAssert.thatSingleton(results).satisfies(batch -> { + List> resultList = StreamSupport.stream(batch.spliterator(), false) + .collect(Collectors.toList()); + assertEquals("Expected empty result list", 0, resultList.size()); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testHandlerSetupFailure() { + TestInput input = new TestInput("test-value"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + inputCollection.apply("RemoteInference", + RemoteInference.invoke() + .handler(MockFailingSetupHandler.class) + .withParameters(params)); + + // Verify pipeline fails with expected error + try { + pipeline.run().waitUntilFinish(); + fail("Expected pipeline to fail due to handler setup failure"); + } catch (Exception e) { + String message = e.getMessage(); + assertTrue("Exception should mention setup failure or handler instantiation failure", + message != null && (message.contains("Setup failed intentionally") || + message.contains("Failed to instantiate handler"))); + } + } + + @Test + public void testHandlerRequestFailure() { + TestInput input = new TestInput("test-value"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + inputCollection.apply("RemoteInference", + RemoteInference.invoke() + .handler(MockFailingRequestHandler.class) + .withParameters(params)); + + // Verify pipeline fails with expected error + try { + pipeline.run().waitUntilFinish(); + fail("Expected pipeline to fail due to request failure"); + } catch (Exception e) { + String message = e.getMessage(); + assertTrue("Exception should mention request failure", + message != null && message.contains("Request failed intentionally")); + } + } + + @Test + public void testHandlerWithoutDefaultConstructor() { + TestInput input = new TestInput("test-value"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + inputCollection.apply("RemoteInference", + RemoteInference.invoke() + .handler(MockNoDefaultConstructorHandler.class) + .withParameters(params)); + + // Verify pipeline fails when handler cannot be instantiated + try { + pipeline.run().waitUntilFinish(); + fail("Expected pipeline to fail due to missing default constructor"); + } catch (Exception e) { + String message = e.getMessage(); + assertTrue("Exception should mention handler instantiation failure", + message != null && message.contains("Failed to instantiate handler")); + } + } + + @Test + public void testBuilderPattern() { + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + RemoteInference.Invoke transform = RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params); + + assertNotNull("Transform should not be null", transform); + } + + @Test + public void testPredictionResultMapping() { + TestInput input = new TestInput("mapping-test"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + PAssert.thatSingleton(results).satisfies(batch -> { + for (PredictionResult result : batch) { + // Verify that input is preserved in the result + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + assertEquals("mapping-test", result.getInput().getModelInput()); + assertTrue("Output should contain input value", + result.getOutput().getModelResponse().contains("mapping-test")); + } + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + // Temporary behaviour until we introduce java BatchElements transform + // to batch elements in RemoteInference + @Test + public void testMultipleInputsProduceSeparateBatches() { + List inputs = Arrays.asList( + new TestInput("input1"), + new TestInput("input2")); + + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class))); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + PAssert.that(results).satisfies(batches -> { + int batchCount = 0; + for (Iterable> batch : batches) { + batchCount++; + int elementCount = 0; + elementCount += StreamSupport.stream(batch.spliterator(), false).count(); + // Each batch should contain exactly 1 element + assertEquals("Each batch should contain 1 element", 1, elementCount); + } + assertEquals("Expected 2 batches", 2, batchCount); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testWithEmptyParameters() { + + pipeline.enableAbandonedNodeEnforcement(false); + + TestInput input = TestInput.create("test-value"); + PCollection inputCollection = pipeline.apply(Create.of(input)); + + IllegalArgumentException thrown = assertThrows( + IllegalArgumentException.class, + () -> inputCollection.apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class))); + + assertTrue( + "Expected message to contain 'withParameters() is required', but got: " + thrown.getMessage(), + thrown.getMessage().contains("withParameters() is required")); + } + + @Test + public void testWithEmptyHandler() { + + pipeline.enableAbandonedNodeEnforcement(false); + + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + TestInput input = TestInput.create("test-value"); + PCollection inputCollection = pipeline.apply(Create.of(input)); + + IllegalArgumentException thrown = assertThrows( + IllegalArgumentException.class, + () -> inputCollection.apply("RemoteInference", + RemoteInference.invoke() + .withParameters(params))); + + assertTrue( + "Expected message to contain 'handler() is required', but got: " + thrown.getMessage(), + thrown.getMessage().contains("handler() is required")); + } +} diff --git a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java new file mode 100644 index 000000000000..fb24b090cb68 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.junit.Assume.assumeNotNull; +import static org.junit.Assume.assumeTrue; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.beam.sdk.ml.remoteinference.base.*; +import org.apache.beam.sdk.ml.remoteinference.RemoteInference; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.StructuredInputOutput; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.Response; + +public class OpenAIModelHandlerIT { + private static final Logger LOG = LoggerFactory.getLogger(OpenAIModelHandlerIT.class); + + @Rule + public final transient TestPipeline pipeline = TestPipeline.create(); + + private String apiKey; + private static final String API_KEY_ENV = "OPENAI_API_KEY"; + private static final String DEFAULT_MODEL = "gpt-4o-mini"; + + + @Before + public void setUp() { + // Get API key + apiKey = System.getenv(API_KEY_ENV); + + // Skip tests if API key is not provided + assumeNotNull( + "OpenAI API key not found. Set " + API_KEY_ENV + + " environment variable to run integration tests.", + apiKey); + assumeTrue("OpenAI API key is empty. Set " + API_KEY_ENV + + " environment variable to run integration tests.", + !apiKey.trim().isEmpty()); + } + + @Test + public void testSentimentAnalysisWithSingleInput() { + String input = "This product is absolutely amazing! I love it!"; + + PCollection inputs = pipeline + .apply("CreateSingleInput", Create.of(input)) + .apply("MapToInput", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputs + .apply("SentimentInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Analyze the sentiment as 'positive' or 'negative'. Return only one word.") + .build())); + + // Verify results + PAssert.that(results).satisfies(batches -> { + int count = 0; + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + count++; + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + assertNotNull("Output text should not be null", + result.getOutput().getModelResponse()); + + String sentiment = result.getOutput().getModelResponse().toLowerCase(); + assertTrue("Sentiment should be positive or negative, got: " + sentiment, + sentiment.contains("positive") + || sentiment.contains("negative")); + } + } + assertEquals("Should have exactly 1 result", 1, count); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testSentimentAnalysisWithMultipleInputs() { + List inputs = Arrays.asList( + "An excellent B2B SaaS solution that streamlines business processes efficiently.", + "The customer support is terrible. I've been waiting for days without any response.", + "The application works as expected. Installation was straightforward.", + "Really impressed with the innovative features! The AI capabilities are groundbreaking!", + "Mediocre product with occasional glitches. Documentation could be better."); + + PCollection inputCollection = pipeline + .apply("CreateMultipleInputs", Create.of(inputs)) + .apply("MapToInputs", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputCollection + .apply("SentimentInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Analyze sentiment as positive or negative") + .build())); + + // Verify we get results for all inputs + PAssert.that(results).satisfies(batches -> { + int totalCount = 0; + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + totalCount++; + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + assertFalse("Output should not be empty", + result.getOutput().getModelResponse().trim().isEmpty()); + } + } + assertEquals("Should have results for all 5 inputs", 5, totalCount); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testTextClassification() { + List inputs = Arrays.asList( + "How do I reset my password?", + "Your product is broken and I want a refund!", + "Thank you for the excellent service!"); + + PCollection inputCollection = pipeline + .apply("CreateInputs", Create.of(inputs)) + .apply("MapToInputs", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputCollection + .apply("ClassificationInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Classify each text into one category: 'question', 'complaint', or 'praise'. Return only the category.") + .build())); + + PAssert.that(results).satisfies(batches -> { + List categories = new ArrayList<>(); + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + String category = result.getOutput().getModelResponse().toLowerCase(); + categories.add(category); + } + } + + assertEquals("Should have 3 categories", 3, categories.size()); + + // Verify expected categories + boolean hasQuestion = categories.stream().anyMatch(c -> c.contains("question")); + boolean hasComplaint = categories.stream().anyMatch(c -> c.contains("complaint")); + boolean hasPraise = categories.stream().anyMatch(c -> c.contains("praise")); + + assertTrue("Should have at least one recognized category", + hasQuestion || hasComplaint || hasPraise); + + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testInputOutputMapping() { + List inputs = Arrays.asList("apple", "banana", "cherry"); + + PCollection inputCollection = pipeline + .apply("CreateInputs", Create.of(inputs)) + .apply("MapToInputs", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputCollection + .apply("MappingInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Return the input word in uppercase") + .build())); + + // Verify input-output pairing is preserved + PAssert.that(results).satisfies(batches -> { + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + String input = result.getInput().getModelInput(); + String output = result.getOutput().getModelResponse().toLowerCase(); + + // Verify the output relates to the input + assertTrue("Output should relate to input '" + input + "', got: " + output, + output.contains(input.toLowerCase())); + } + } + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testWithDifferentModel() { + // Test with a different model + String input = "Explain quantum computing in one sentence."; + + PCollection inputs = pipeline + .apply("CreateInput", Create.of(input)) + .apply("MapToInput", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputs + .apply("DifferentModelInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName("gpt-5") + .instructionPrompt("Respond concisely") + .build())); + + PAssert.that(results).satisfies(batches -> { + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + assertNotNull("Output should not be null", + result.getOutput().getModelResponse()); + assertFalse("Output should not be empty", + result.getOutput().getModelResponse().trim().isEmpty()); + } + } + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testWithInvalidApiKey() { + String input = "Test input"; + + PCollection inputs = pipeline + .apply("CreateInput", Create.of(input)) + .apply("MapToInput", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + inputs.apply("InvalidKeyInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey("invalid-api-key-12345") + .modelName(DEFAULT_MODEL) + .instructionPrompt("Test") + .build())); + + // Expect pipeline to fail with authentication error + try { + pipeline.run().waitUntilFinish(); + fail("Expected pipeline to fail with invalid API key"); + } catch (Exception e) { + // Expected - verify it's an authentication error + String message = e.getMessage().toLowerCase(); + assertTrue("Exception should mention authentication or API key issue, got: " + message, + message.contains("auth") || + message.contains("api") || + message.contains("key") || + message.contains("401")); + } + } + + /** + * Test with custom instruction formats + */ + @Test + public void testWithJsonOutputFormat() { + String input = "Paris is the capital of France"; + + PCollection inputs = pipeline + .apply("CreateInput", Create.of(input)) + .apply("MapToInput", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputs + .apply("JsonFormatInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Extract the city and country. Return as: City: [city], Country: [country]") + .build())); + + PAssert.that(results).satisfies(batches -> { + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + String output = result.getOutput().getModelResponse(); + LOG.info("Structured output: " + output); + + // Verify output contains expected information + assertTrue("Output should mention Paris: " + output, + output.toLowerCase().contains("paris")); + assertTrue("Output should mention France: " + output, + output.toLowerCase().contains("france")); + } + } + return null; + }); + + pipeline.run().waitUntilFinish(); + } +} diff --git a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerTest.java b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerTest.java new file mode 100644 index 000000000000..3bff5600aa45 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerTest.java @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import org.apache.beam.sdk.ml.remoteinference.base.*; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.StructuredInputOutput; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.Response; + + + +@RunWith(JUnit4.class) +public class OpenAIModelHandlerTest { + private OpenAIModelParameters testParameters; + + @Before + public void setUp() { + testParameters = OpenAIModelParameters.builder() + .apiKey("test-api-key") + .modelName("gpt-4") + .instructionPrompt("Test instruction") + .build(); + } + + /** + * Fake OpenAiModelHandler for testing. + */ + static class FakeOpenAiModelHandler extends OpenAIModelHandler { + + private boolean clientCreated = false; + private OpenAIModelParameters storedParameters; + private List responsesToReturn; + private RuntimeException exceptionToThrow; + private boolean shouldReturnNull = false; + + public void setResponsesToReturn(List responses) { + this.responsesToReturn = responses; + } + + public void setExceptionToThrow(RuntimeException exception) { + this.exceptionToThrow = exception; + } + + public void setShouldReturnNull(boolean shouldReturnNull) { + this.shouldReturnNull = shouldReturnNull; + } + + public boolean isClientCreated() { + return clientCreated; + } + + public OpenAIModelParameters getStoredParameters() { + return storedParameters; + } + + @Override + public void createClient(OpenAIModelParameters parameters) { + this.storedParameters = parameters; + this.clientCreated = true; + + if (exceptionToThrow != null) { + throw exceptionToThrow; + } + } + + @Override + public Iterable> request( + List input) { + + if (!clientCreated) { + throw new IllegalStateException("Client not initialized"); + } + + if (exceptionToThrow != null) { + throw exceptionToThrow; + } + + if (shouldReturnNull || responsesToReturn == null) { + throw new RuntimeException("Model returned no structured responses"); + } + + StructuredInputOutput structuredOutput = responsesToReturn.get(0); + + if (structuredOutput == null || structuredOutput.responses == null) { + throw new RuntimeException("Model returned no structured responses"); + } + + return structuredOutput.responses.stream() + .map(response -> PredictionResult.create( + OpenAIModelInput.create(response.input), + OpenAIModelResponse.create(response.output))) + .collect(Collectors.toList()); + } + } + + @Test + public void testCreateClient() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + OpenAIModelParameters params = OpenAIModelParameters.builder() + .apiKey("test-key") + .modelName("gpt-4") + .instructionPrompt("test prompt") + .build(); + + handler.createClient(params); + + assertTrue("Client should be created", handler.isClientCreated()); + assertNotNull("Parameters should be stored", handler.getStoredParameters()); + assertEquals("test-key", handler.getStoredParameters().getApiKey()); + assertEquals("gpt-4", handler.getStoredParameters().getModelName()); + } + + @Test + public void testRequestWithSingleInput() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.singletonList( + OpenAIModelInput.create("test input")); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + Response response = new Response(); + response.input = "test input"; + response.output = "test output"; + structuredOutput.responses = Collections.singletonList(response); + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + handler.createClient(testParameters); + + Iterable> results = handler.request(inputs); + + assertNotNull("Results should not be null", results); + + List> resultList = iterableToList(results); + + assertEquals("Should have 1 result", 1, resultList.size()); + + PredictionResult result = resultList.get(0); + assertEquals("test input", result.getInput().getModelInput()); + assertEquals("test output", result.getOutput().getModelResponse()); + } + + @Test + public void testRequestWithMultipleInputs() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Arrays.asList( + OpenAIModelInput.create("input1"), + OpenAIModelInput.create("input2"), + OpenAIModelInput.create("input3")); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + + Response response1 = new Response(); + response1.input = "input1"; + response1.output = "output1"; + + Response response2 = new Response(); + response2.input = "input2"; + response2.output = "output2"; + + Response response3 = new Response(); + response3.input = "input3"; + response3.output = "output3"; + + structuredOutput.responses = Arrays.asList(response1, response2, response3); + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + handler.createClient(testParameters); + + Iterable> results = handler.request(inputs); + + List> resultList = iterableToList(results); + + assertEquals("Should have 3 results", 3, resultList.size()); + + for (int i = 0; i < 3; i++) { + PredictionResult result = resultList.get(i); + assertEquals("input" + (i + 1), result.getInput().getModelInput()); + assertEquals("output" + (i + 1), result.getOutput().getModelResponse()); + } + } + + @Test + public void testRequestWithEmptyInput() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.emptyList(); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + structuredOutput.responses = Collections.emptyList(); + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + handler.createClient(testParameters); + + Iterable> results = handler.request(inputs); + + List> resultList = iterableToList(results); + assertEquals("Should have 0 results", 0, resultList.size()); + } + + @Test + public void testRequestWithNullStructuredOutput() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.singletonList( + OpenAIModelInput.create("test input")); + + handler.setShouldReturnNull(true); + handler.createClient(testParameters); + + try { + handler.request(inputs); + fail("Expected RuntimeException when structured output is null"); + } catch (RuntimeException e) { + assertTrue("Exception message should mention no structured responses", + e.getMessage().contains("Model returned no structured responses")); + } + } + + @Test + public void testRequestWithNullResponsesList() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.singletonList( + OpenAIModelInput.create("test input")); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + structuredOutput.responses = null; + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + handler.createClient(testParameters); + + try { + handler.request(inputs); + fail("Expected RuntimeException when responses list is null"); + } catch (RuntimeException e) { + assertTrue("Exception message should mention no structured responses", + e.getMessage().contains("Model returned no structured responses")); + } + } + + @Test + public void testCreateClientFailure() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + handler.setExceptionToThrow(new RuntimeException("Setup failed")); + + try { + handler.createClient(testParameters); + fail("Expected RuntimeException during client creation"); + } catch (RuntimeException e) { + assertEquals("Setup failed", e.getMessage()); + } + } + + @Test + public void testRequestApiFailure() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.singletonList( + OpenAIModelInput.create("test input")); + + handler.createClient(testParameters); + handler.setExceptionToThrow(new RuntimeException("API Error")); + + try { + handler.request(inputs); + fail("Expected RuntimeException when API fails"); + } catch (RuntimeException e) { + assertEquals("API Error", e.getMessage()); + } + } + + @Test + public void testRequestWithoutClientInitialization() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.singletonList( + OpenAIModelInput.create("test input")); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + Response response = new Response(); + response.input = "test input"; + response.output = "test output"; + structuredOutput.responses = Collections.singletonList(response); + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + + // Don't call createClient + try { + handler.request(inputs); + fail("Expected IllegalStateException when client not initialized"); + } catch (IllegalStateException e) { + assertTrue("Exception should mention client not initialized", + e.getMessage().contains("Client not initialized")); + } + } + + @Test + public void testInputOutputMapping() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Arrays.asList( + OpenAIModelInput.create("alpha"), + OpenAIModelInput.create("beta")); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + + Response response1 = new Response(); + response1.input = "alpha"; + response1.output = "ALPHA"; + + Response response2 = new Response(); + response2.input = "beta"; + response2.output = "BETA"; + + structuredOutput.responses = Arrays.asList(response1, response2); + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + handler.createClient(testParameters); + + Iterable> results = handler.request(inputs); + + List> resultList = iterableToList(results); + + assertEquals(2, resultList.size()); + assertEquals("alpha", resultList.get(0).getInput().getModelInput()); + assertEquals("ALPHA", resultList.get(0).getOutput().getModelResponse()); + + assertEquals("beta", resultList.get(1).getInput().getModelInput()); + assertEquals("BETA", resultList.get(1).getOutput().getModelResponse()); + } + + @Test + public void testParametersBuilder() { + OpenAIModelParameters params = OpenAIModelParameters.builder() + .apiKey("my-api-key") + .modelName("gpt-4-turbo") + .instructionPrompt("Custom instruction") + .build(); + + assertEquals("my-api-key", params.getApiKey()); + assertEquals("gpt-4-turbo", params.getModelName()); + assertEquals("Custom instruction", params.getInstructionPrompt()); + } + + @Test + public void testOpenAIModelInputCreate() { + OpenAIModelInput input = OpenAIModelInput.create("test value"); + + assertNotNull("Input should not be null", input); + assertEquals("test value", input.getModelInput()); + } + + @Test + public void testOpenAIModelResponseCreate() { + OpenAIModelResponse response = OpenAIModelResponse.create("test output"); + + assertNotNull("Response should not be null", response); + assertEquals("test output", response.getModelResponse()); + } + + @Test + public void testStructuredInputOutputStructure() { + Response response = new Response(); + response.input = "test-input"; + response.output = "test-output"; + + assertEquals("test-input", response.input); + assertEquals("test-output", response.output); + + StructuredInputOutput structured = new StructuredInputOutput(); + structured.responses = Collections.singletonList(response); + + assertNotNull("Responses should not be null", structured.responses); + assertEquals("Should have 1 response", 1, structured.responses.size()); + assertEquals("test-input", structured.responses.get(0).input); + } + + @Test + public void testMultipleRequestsWithSameHandler() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + handler.createClient(testParameters); + + // First request + StructuredInputOutput output1 = new StructuredInputOutput(); + Response response1 = new Response(); + response1.input = "first"; + response1.output = "FIRST"; + output1.responses = Collections.singletonList(response1); + handler.setResponsesToReturn(Collections.singletonList(output1)); + + List inputs1 = Collections.singletonList( + OpenAIModelInput.create("first")); + Iterable> results1 = handler.request(inputs1); + + List> resultList1 = iterableToList(results1); + assertEquals("FIRST", resultList1.get(0).getOutput().getModelResponse()); + + // Second request with different data + StructuredInputOutput output2 = new StructuredInputOutput(); + Response response2 = new Response(); + response2.input = "second"; + response2.output = "SECOND"; + output2.responses = Collections.singletonList(response2); + handler.setResponsesToReturn(Collections.singletonList(output2)); + + List inputs2 = Collections.singletonList( + OpenAIModelInput.create("second")); + Iterable> results2 = handler.request(inputs2); + + List> resultList2 = iterableToList(results2); + assertEquals("SECOND", resultList2.get(0).getOutput().getModelResponse()); + } + + // Helper method to convert Iterable to List + private List iterableToList(Iterable iterable) { + List list = new java.util.ArrayList<>(); + iterable.forEach(list::add); + return list; + } +} diff --git a/settings.gradle.kts b/settings.gradle.kts index 72c5194ec93d..fdfc5da6854c 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -383,3 +383,5 @@ include("sdks:java:extensions:sql:iceberg") findProject(":sdks:java:extensions:sql:iceberg")?.name = "iceberg" include("examples:java:iceberg") findProject(":examples:java:iceberg")?.name = "iceberg" + +include("sdks:java:ml:remoteinference") \ No newline at end of file