From d511c3db068db3dfb451cb590b9d691104a4fc4f Mon Sep 17 00:00:00 2001 From: Ganeshsivakumar Date: Sat, 25 Oct 2025 13:04:33 +0530 Subject: [PATCH 01/12] ml module setup --- sdks/java/ml/build.gradle | 32 +++++++++++++++++++ sdks/java/ml/remoteinference/build.gradle.kts | 17 ++++++++++ .../ml/remoteinference/RemoteInference.java | 5 +++ .../base/BaseModelHandler.java | 11 +++++++ .../base/BaseModelParameters.java | 7 ++++ settings.gradle.kts | 2 ++ 6 files changed, 74 insertions(+) create mode 100644 sdks/java/ml/build.gradle create mode 100644 sdks/java/ml/remoteinference/build.gradle.kts create mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java create mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java create mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java diff --git a/sdks/java/ml/build.gradle b/sdks/java/ml/build.gradle new file mode 100644 index 000000000000..7b6b071fa2a1 --- /dev/null +++ b/sdks/java/ml/build.gradle @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +plugins { id 'org.apache.beam.module' } +applyJavaNature( + automaticModuleName: 'org.apache.beam.sdk.ml', +) +provideIntegrationTestingDependencies() +enableJavaPerformanceTesting() + +description = "Apache Beam :: SDKs :: Java :: ML" +ext.summary = "Java ML module" + +dependencies { + + +} diff --git a/sdks/java/ml/remoteinference/build.gradle.kts b/sdks/java/ml/remoteinference/build.gradle.kts new file mode 100644 index 000000000000..7b0cfd933eef --- /dev/null +++ b/sdks/java/ml/remoteinference/build.gradle.kts @@ -0,0 +1,17 @@ +plugins { + id("org.apache.beam.module") + id("java-library") +} + +description = "Apache Beam :: SDKs :: Java :: ML :: RemoteInference" + +dependencies { + // Core Beam SDK + implementation(project(":sdks:java:core")) + + + // testing + testImplementation("junit:junit:4.13.2") + testImplementation(project(":sdks:java:testing:test-utils")) +} + diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java new file mode 100644 index 000000000000..be0edf6d11eb --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java @@ -0,0 +1,5 @@ +package org.apache.beam.sdk.ml.remoteinference; + +public class RemoteInference { + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java new file mode 100644 index 000000000000..374e45b67078 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java @@ -0,0 +1,11 @@ +package org.apache.beam.sdk.ml.remoteinference.base; + +public interface BaseModelHandler { + + // initialize the model client with provided parameters + public void 
createClient(ParamT parameters); + + // Logic to invoke model provider + public String request(String input); + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java new file mode 100644 index 000000000000..35158e2e4da6 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java @@ -0,0 +1,7 @@ +package org.apache.beam.sdk.ml.remoteinference.base; + +import java.io.Serializable; + +public interface BaseModelParameters extends Serializable { + +} diff --git a/settings.gradle.kts b/settings.gradle.kts index 72c5194ec93d..fdfc5da6854c 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -383,3 +383,5 @@ include("sdks:java:extensions:sql:iceberg") findProject(":sdks:java:extensions:sql:iceberg")?.name = "iceberg" include("examples:java:iceberg") findProject(":examples:java:iceberg")?.name = "iceberg" + +include("sdks:java:ml:remoteinference") \ No newline at end of file From 20086a8e8cf9c7b04ed8cb25a0f1ceb3495303ac Mon Sep 17 00:00:00 2001 From: Ganeshsivakumar Date: Sat, 25 Oct 2025 20:27:10 +0530 Subject: [PATCH 02/12] openai handler, remoteinference impl --- sdks/java/ml/remoteinference/build.gradle.kts | 3 + .../ml/remoteinference/RemoteInference.java | 86 +++++++++++++++++++ .../ml/remoteinference/base/BaseInput.java | 7 ++ .../base/BaseModelHandler.java | 6 +- .../ml/remoteinference/base/BaseResponse.java | 7 ++ .../openai/OpenAiModelHandler.java | 44 ++++++++++ .../openai/OpenAiModelInput.java | 22 +++++ .../openai/OpenAiModelParameters.java | 48 +++++++++++ .../openai/OpenAiModelResponse.java | 26 ++++++ 9 files changed, 246 insertions(+), 3 deletions(-) create mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java create mode 100644 
sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java create mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelHandler.java create mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelInput.java create mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelParameters.java create mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelResponse.java diff --git a/sdks/java/ml/remoteinference/build.gradle.kts b/sdks/java/ml/remoteinference/build.gradle.kts index 7b0cfd933eef..f3329d864438 100644 --- a/sdks/java/ml/remoteinference/build.gradle.kts +++ b/sdks/java/ml/remoteinference/build.gradle.kts @@ -9,6 +9,9 @@ dependencies { // Core Beam SDK implementation(project(":sdks:java:core")) + implementation("com.openai:openai-java:4.3.0") + implementation("com.google.auto.value:auto-value:1.11.0") + implementation("com.google.auto.value:auto-value-annotations:1.11.0") // testing testImplementation("junit:junit:4.13.2") diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java index be0edf6d11eb..0d0f6611104c 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java @@ -1,5 +1,91 @@ package org.apache.beam.sdk.ml.remoteinference; +import org.checkerframework.checker.nullness.qual.Nullable; + +import org.apache.beam.sdk.ml.remoteinference.base.BaseInput; +import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler; +import 
org.apache.beam.sdk.ml.remoteinference.base.BaseModelParameters; +import org.apache.beam.sdk.ml.remoteinference.base.BaseResponse; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PCollection; + + +import com.google.auto.value.AutoValue; + +@SuppressWarnings({ "rawtypes", "unchecked" }) public class RemoteInference { + public static Invoke invoke() { + return new AutoValue_RemoteInference_Invoke.Builder().setParameters(null) + .build(); + } + + private RemoteInference() { + } + + @AutoValue + public abstract static class Invoke + extends PTransform, PCollection> { + + abstract @Nullable Class handler(); + + abstract @Nullable BaseModelParameters parameters(); + + abstract Builder builder(); + + @AutoValue.Builder + abstract static class Builder { + + abstract Builder setHandler(Class modelHandler); + + abstract Builder setParameters(BaseModelParameters modelParameters); + + abstract Invoke build(); + } + + public Invoke handler(Class modelHandler) { + return builder().setHandler(modelHandler).build(); + } + + public Invoke withParameters(BaseModelParameters modelParameters) { + return builder().setParameters(modelParameters).build(); + } + + @Override + public PCollection expand(PCollection input) { + return input.apply(ParDo.of(new RemoteInferenceFn<>(this))); + } + + static class RemoteInferenceFn + extends DoFn { + + private final Class handlerClass; + private final BaseModelParameters parameters; + private transient BaseModelHandler handler; + + RemoteInferenceFn(Invoke spec) { + this.handlerClass = spec.handler(); + this.parameters = spec.parameters(); + } + + @Setup + public void setupHandler() { + try { + this.handler = handlerClass.getDeclaredConstructor().newInstance(); + this.handler.createClient(parameters); + } catch (Exception e) { + throw new RuntimeException("Failed to instantiate handler: " + + handlerClass.getName(), e); + 
} + } + + @ProcessElement + public void processElement(ProcessContext c) { + OutputT response = (OutputT) this.handler.request(c.element()); + c.output(response); + } + } + } } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java new file mode 100644 index 000000000000..67d25208d3f7 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java @@ -0,0 +1,7 @@ +package org.apache.beam.sdk.ml.remoteinference.base; + +import java.io.Serializable; + +public abstract class BaseInput implements Serializable { + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java index 374e45b67078..d331491493fa 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java @@ -1,11 +1,11 @@ package org.apache.beam.sdk.ml.remoteinference.base; -public interface BaseModelHandler { +public interface BaseModelHandler { - // initialize the model client with provided parameters + // initialize the model with provided parameters public void createClient(ParamT parameters); // Logic to invoke model provider - public String request(String input); + public OutputT request(InputT input); } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java new file mode 100644 index 000000000000..7c8869a8cd00 --- /dev/null +++ 
b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java @@ -0,0 +1,7 @@ +package org.apache.beam.sdk.ml.remoteinference.base; + +import java.io.Serializable; + +public abstract class BaseResponse implements Serializable { + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelHandler.java new file mode 100644 index 000000000000..d9fd247c69c3 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelHandler.java @@ -0,0 +1,44 @@ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.responses.ResponseCreateParams; +import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler; + +import java.util.stream.Collectors; + +public class OpenAiModelHandler + implements BaseModelHandler { + + private transient OpenAIClient client; + private transient ResponseCreateParams clientParams; + private OpenAiModelParameters modelParameters; + + @Override + public void createClient(OpenAiModelParameters parameters) { + this.modelParameters = parameters; + this.client = OpenAIOkHttpClient.builder() + .apiKey(this.modelParameters.getApiKey()) + .build(); + } + + @Override + public OpenAiModelResponse request(OpenAiModelInput input) { + + this.clientParams = ResponseCreateParams.builder() + .model(this.modelParameters.getModelName()) + .input(input.getInput()) + .build(); + + String output = client.responses().create(clientParams).output().stream() + .flatMap(item -> item.message().stream()) + .flatMap(message -> message.content().stream()) + .flatMap(content -> content.outputText().stream()) + .map(outputText -> outputText.text()) + .collect(Collectors.joining()); + + 
OpenAiModelResponse res = OpenAiModelResponse.create(input.getInput(), output); + return res; + } + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelInput.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelInput.java new file mode 100644 index 000000000000..0a3a3a3d0a50 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelInput.java @@ -0,0 +1,22 @@ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import org.apache.beam.sdk.ml.remoteinference.base.BaseInput; + +public class OpenAiModelInput extends BaseInput { + + private final String input; + + private OpenAiModelInput(String input) { + + this.input = input; + } + + public String getInput() { + return input; + } + + public static OpenAiModelInput create(String input) { + return new OpenAiModelInput(input); + } + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelParameters.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelParameters.java new file mode 100644 index 000000000000..1c88799fae2e --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelParameters.java @@ -0,0 +1,48 @@ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import org.apache.beam.sdk.ml.remoteinference.base.BaseModelParameters; + +public class OpenAiModelParameters implements BaseModelParameters { + + private final String apiKey; + private final String modelName; + + private OpenAiModelParameters(Builder builder) { + this.apiKey = builder.apiKey; + this.modelName = builder.modelName; + } + + public String getApiKey() { + return apiKey; + } + + public String getModelName() { + return modelName; + } + + public static Builder builder() { + return new 
Builder(); + } + + public static class Builder { + private String apiKey; + private String modelName; + + private Builder() { + } + + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public OpenAiModelParameters build() { + return new OpenAiModelParameters(this); + } + } +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelResponse.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelResponse.java new file mode 100644 index 000000000000..fa727d59d291 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelResponse.java @@ -0,0 +1,26 @@ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import org.apache.beam.sdk.ml.remoteinference.base.BaseResponse; + +public class OpenAiModelResponse extends BaseResponse { + + private final String input; + private final String output; + + private OpenAiModelResponse(String input, String output) { + this.input = input; + this.output = output; + } + + public String getInput() { + return input; + } + + public String getOutput() { + return output; + } + + public static OpenAiModelResponse create(String input, String output) { + return new OpenAiModelResponse(input, output); + } +} From cf5d0a4063e88d538d4e699254c5344944571910 Mon Sep 17 00:00:00 2001 From: Ganeshsivakumar Date: Sat, 25 Oct 2025 23:34:42 +0530 Subject: [PATCH 03/12] prompt, example pipeline --- sdks/java/ml/remoteinference/build.gradle.kts | 1 + .../openai/OpenAiModelParameters.java | 12 ++++++ .../src/test/java/Example.java | 43 +++++++++++++++++++ 3 files changed, 56 insertions(+) create mode 100644 sdks/java/ml/remoteinference/src/test/java/Example.java diff --git a/sdks/java/ml/remoteinference/build.gradle.kts 
b/sdks/java/ml/remoteinference/build.gradle.kts index f3329d864438..214d7c5c8b5f 100644 --- a/sdks/java/ml/remoteinference/build.gradle.kts +++ b/sdks/java/ml/remoteinference/build.gradle.kts @@ -14,6 +14,7 @@ dependencies { implementation("com.google.auto.value:auto-value-annotations:1.11.0") // testing + testImplementation(project(":runners:direct-java")) testImplementation("junit:junit:4.13.2") testImplementation(project(":sdks:java:testing:test-utils")) } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelParameters.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelParameters.java index 1c88799fae2e..b67e75afba1e 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelParameters.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelParameters.java @@ -6,10 +6,12 @@ public class OpenAiModelParameters implements BaseModelParameters { private final String apiKey; private final String modelName; + private final String instructionPrompt; private OpenAiModelParameters(Builder builder) { this.apiKey = builder.apiKey; this.modelName = builder.modelName; + this.instructionPrompt = builder.instructionPrompt; } public String getApiKey() { @@ -20,6 +22,10 @@ public String getModelName() { return modelName; } + public String getInstructionPrompt() { + return instructionPrompt; + } + public static Builder builder() { return new Builder(); } @@ -27,6 +33,7 @@ public static Builder builder() { public static class Builder { private String apiKey; private String modelName; + private String instructionPrompt; private Builder() { } @@ -41,6 +48,11 @@ public Builder modelName(String modelName) { return this; } + public Builder instructionPrompt(String prompt) { + this.instructionPrompt = prompt; + return this; + } + public OpenAiModelParameters build() 
{ return new OpenAiModelParameters(this); } diff --git a/sdks/java/ml/remoteinference/src/test/java/Example.java b/sdks/java/ml/remoteinference/src/test/java/Example.java new file mode 100644 index 000000000000..f0ba8f8c3864 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/test/java/Example.java @@ -0,0 +1,43 @@ +import org.apache.beam.runners.direct.DirectRunner; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.ml.remoteinference.RemoteInference; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAiModelHandler; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAiModelInput; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAiModelParameters; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAiModelResponse; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.TypeDescriptor; + +public class Example { + public static void main(String[] args) { + + PipelineOptions options = PipelineOptionsFactory.create(); + options.setRunner(DirectRunner.class); + Pipeline p = Pipeline.create(options); + + p.apply("text", Create.of( + "An excellent B2B SaaS solution that streamlines business processes efficiently. The platform is user-friendly and highly reliable. 
Overall, it delivers great value for enterprise teams.")) + .apply(MapElements.into(TypeDescriptor.of(OpenAiModelInput.class)) + .via(OpenAiModelInput::create)) + .apply("inference", RemoteInference.invoke() + .handler(OpenAiModelHandler.class) + .withParameters(OpenAiModelParameters.builder() + .apiKey("key") + .modelName("gpt-5-mini") + .instructionPrompt("Analyse sentiment as positive or negative") + .build())) + .apply("print output", ParDo.of(new DoFn() { + @ProcessElement + public void print(ProcessContext c) { + System.out.println("OUTPUT: " + c.element().getOutput()); + } + })); + + p.run(); + } +} From d3f4f18e6e016696bad9845fab0e3af12bd6bc0c Mon Sep 17 00:00:00 2001 From: Ganeshsivakumar Date: Wed, 29 Oct 2025 16:21:53 +0530 Subject: [PATCH 04/12] output format and comments --- sdks/java/ml/remoteinference/build.gradle.kts | 18 ++++++ .../ml/remoteinference/RemoteInference.java | 30 ++++++--- .../ml/remoteinference/base/BaseInput.java | 17 +++++ .../base/BaseModelHandler.java | 19 +++++- .../base/BaseModelParameters.java | 17 +++++ .../ml/remoteinference/base/BaseResponse.java | 17 +++++ .../base/PredictionResult.java | 44 +++++++++++++ .../openai/OpenAIModelHandler.java | 62 +++++++++++++++++++ .../openai/OpenAIModelInput.java | 39 ++++++++++++ ...meters.java => OpenAIModelParameters.java} | 25 ++++++-- .../openai/OpenAIModelResponse.java | 37 +++++++++++ .../openai/OpenAiModelHandler.java | 44 ------------- .../openai/OpenAiModelInput.java | 22 ------- .../openai/OpenAiModelResponse.java | 26 -------- .../src/test/java/Example.java | 24 +++---- 15 files changed, 324 insertions(+), 117 deletions(-) create mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java create mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java create mode 100644 
sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java rename sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/{OpenAiModelParameters.java => OpenAIModelParameters.java} (52%) create mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java delete mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelHandler.java delete mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelInput.java delete mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelResponse.java diff --git a/sdks/java/ml/remoteinference/build.gradle.kts b/sdks/java/ml/remoteinference/build.gradle.kts index 214d7c5c8b5f..8c801ca4cae4 100644 --- a/sdks/java/ml/remoteinference/build.gradle.kts +++ b/sdks/java/ml/remoteinference/build.gradle.kts @@ -1,3 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + plugins { id("org.apache.beam.module") id("java-library") diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java index 0d0f6611104c..6168a0864427 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java @@ -1,11 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package org.apache.beam.sdk.ml.remoteinference; +import org.apache.beam.sdk.ml.remoteinference.base.*; import org.checkerframework.checker.nullness.qual.Nullable; -import org.apache.beam.sdk.ml.remoteinference.base.BaseInput; -import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler; -import org.apache.beam.sdk.ml.remoteinference.base.BaseModelParameters; -import org.apache.beam.sdk.ml.remoteinference.base.BaseResponse; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; @@ -27,7 +41,7 @@ private RemoteInference() { @AutoValue public abstract static class Invoke - extends PTransform, PCollection> { + extends PTransform, PCollection>>> { abstract @Nullable Class handler(); @@ -54,12 +68,12 @@ public Invoke withParameters(BaseModelParameters modelParameter } @Override - public PCollection expand(PCollection input) { + public PCollection>> expand(PCollection input) { return input.apply(ParDo.of(new RemoteInferenceFn<>(this))); } static class RemoteInferenceFn - extends DoFn { + extends DoFn>> { private final Class handlerClass; private final BaseModelParameters parameters; @@ -83,7 +97,7 @@ public void setupHandler() { @ProcessElement public void processElement(ProcessContext c) { - OutputT response = (OutputT) this.handler.request(c.element()); + Iterable> response = this.handler.request(c.element()); c.output(response); } } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java index 67d25208d3f7..939406722b8f 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.beam.sdk.ml.remoteinference.base; import java.io.Serializable; diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java index d331491493fa..b5f1aa06d7cb 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.beam.sdk.ml.remoteinference.base; public interface BaseModelHandler { @@ -6,6 +23,6 @@ public interface BaseModelHandler> request(InputT input); } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java index 35158e2e4da6..180ad2c91220 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package org.apache.beam.sdk.ml.remoteinference.base; import java.io.Serializable; diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java index 7c8869a8cd00..2e7af0e7c2c1 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package org.apache.beam.sdk.ml.remoteinference.base; import java.io.Serializable; diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java new file mode 100644 index 000000000000..b19f64917479 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.ml.remoteinference.base; + +import java.io.Serializable; + +public class PredictionResult implements Serializable { + + private final InputT input; + private final OutputT output; + + private PredictionResult(InputT input, OutputT output) { + this.input = input; + this.output = output; + + } + + public InputT getInput() { + return input; + } + + public OutputT getOutput() { + return output; + } + + public static PredictionResult create(InputT input, OutputT output) { + return new PredictionResult<>(input, output); + } +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java new file mode 100644 index 000000000000..20498fe619df --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.responses.ResponseCreateParams; +import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler; +import org.apache.beam.sdk.ml.remoteinference.base.PredictionResult; + +import java.util.List; +import java.util.stream.Collectors; + +public class OpenAIModelHandler + implements BaseModelHandler { + + private transient OpenAIClient client; + private transient ResponseCreateParams clientParams; + private OpenAIModelParameters modelParameters; + + @Override + public void createClient(OpenAIModelParameters parameters) { + this.modelParameters = parameters; + this.client = OpenAIOkHttpClient.builder() + .apiKey(this.modelParameters.getApiKey()) + .build(); + } + + @Override + public Iterable> request(OpenAIModelInput input) { + + this.clientParams = ResponseCreateParams.builder() + .model(this.modelParameters.getModelName()) + .input(input.getInput()) + .build(); + + String output = client.responses().create(clientParams).output().stream() + .flatMap(item -> item.message().stream()) + .flatMap(message -> message.content().stream()) + .flatMap(content -> content.outputText().stream()) + .map(outputText -> outputText.text()) + .collect(Collectors.joining()); + + return List.of(PredictionResult.create(input, OpenAIModelResponse.create(output))); + } + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java new file mode 100644 index 000000000000..0500832def3f --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import org.apache.beam.sdk.ml.remoteinference.base.BaseInput; + +public class OpenAIModelInput extends BaseInput { + + private final String input; + + private OpenAIModelInput(String input) { + + this.input = input; + } + + public String getInput() { + return input; + } + + public static OpenAIModelInput create(String input) { + return new OpenAIModelInput(input); + } + +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelParameters.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java similarity index 52% rename from sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelParameters.java rename to sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java index b67e75afba1e..985c6ceddc7b 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelParameters.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java @@ -1,14 +1,31 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.beam.sdk.ml.remoteinference.openai; import org.apache.beam.sdk.ml.remoteinference.base.BaseModelParameters; -public class OpenAiModelParameters implements BaseModelParameters { +public class OpenAIModelParameters implements BaseModelParameters { private final String apiKey; private final String modelName; private final String instructionPrompt; - private OpenAiModelParameters(Builder builder) { + private OpenAIModelParameters(Builder builder) { this.apiKey = builder.apiKey; this.modelName = builder.modelName; this.instructionPrompt = builder.instructionPrompt; @@ -53,8 +70,8 @@ public Builder instructionPrompt(String prompt) { return this; } - public OpenAiModelParameters build() { - return new OpenAiModelParameters(this); + public OpenAIModelParameters build() { + return new OpenAIModelParameters(this); } } } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java new file mode 100644 index 000000000000..e65513851bc0 --- /dev/null +++ 
b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import org.apache.beam.sdk.ml.remoteinference.base.BaseResponse; + +public class OpenAIModelResponse extends BaseResponse { + + private final String output; + + private OpenAIModelResponse(String output) { + this.output = output; + } + + public String getOutput() { + return output; + } + + public static OpenAIModelResponse create(String output) { + return new OpenAIModelResponse(output); + } +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelHandler.java deleted file mode 100644 index d9fd247c69c3..000000000000 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelHandler.java +++ /dev/null @@ -1,44 +0,0 @@ -package org.apache.beam.sdk.ml.remoteinference.openai; - -import com.openai.client.OpenAIClient; -import 
com.openai.client.okhttp.OpenAIOkHttpClient; -import com.openai.models.responses.ResponseCreateParams; -import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler; - -import java.util.stream.Collectors; - -public class OpenAiModelHandler - implements BaseModelHandler { - - private transient OpenAIClient client; - private transient ResponseCreateParams clientParams; - private OpenAiModelParameters modelParameters; - - @Override - public void createClient(OpenAiModelParameters parameters) { - this.modelParameters = parameters; - this.client = OpenAIOkHttpClient.builder() - .apiKey(this.modelParameters.getApiKey()) - .build(); - } - - @Override - public OpenAiModelResponse request(OpenAiModelInput input) { - - this.clientParams = ResponseCreateParams.builder() - .model(this.modelParameters.getModelName()) - .input(input.getInput()) - .build(); - - String output = client.responses().create(clientParams).output().stream() - .flatMap(item -> item.message().stream()) - .flatMap(message -> message.content().stream()) - .flatMap(content -> content.outputText().stream()) - .map(outputText -> outputText.text()) - .collect(Collectors.joining()); - - OpenAiModelResponse res = OpenAiModelResponse.create(input.getInput(), output); - return res; - } - -} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelInput.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelInput.java deleted file mode 100644 index 0a3a3a3d0a50..000000000000 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelInput.java +++ /dev/null @@ -1,22 +0,0 @@ -package org.apache.beam.sdk.ml.remoteinference.openai; - -import org.apache.beam.sdk.ml.remoteinference.base.BaseInput; - -public class OpenAiModelInput extends BaseInput { - - private final String input; - - private OpenAiModelInput(String input) { - - this.input = input; - } - - 
public String getInput() { - return input; - } - - public static OpenAiModelInput create(String input) { - return new OpenAiModelInput(input); - } - -} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelResponse.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelResponse.java deleted file mode 100644 index fa727d59d291..000000000000 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelResponse.java +++ /dev/null @@ -1,26 +0,0 @@ -package org.apache.beam.sdk.ml.remoteinference.openai; - -import org.apache.beam.sdk.ml.remoteinference.base.BaseResponse; - -public class OpenAiModelResponse extends BaseResponse { - - private final String input; - private final String output; - - private OpenAiModelResponse(String input, String output) { - this.input = input; - this.output = output; - } - - public String getInput() { - return input; - } - - public String getOutput() { - return output; - } - - public static OpenAiModelResponse create(String input, String output) { - return new OpenAiModelResponse(input, output); - } -} diff --git a/sdks/java/ml/remoteinference/src/test/java/Example.java b/sdks/java/ml/remoteinference/src/test/java/Example.java index f0ba8f8c3864..97111231b379 100644 --- a/sdks/java/ml/remoteinference/src/test/java/Example.java +++ b/sdks/java/ml/remoteinference/src/test/java/Example.java @@ -1,10 +1,10 @@ import org.apache.beam.runners.direct.DirectRunner; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.ml.remoteinference.RemoteInference; -import org.apache.beam.sdk.ml.remoteinference.openai.OpenAiModelHandler; -import org.apache.beam.sdk.ml.remoteinference.openai.OpenAiModelInput; -import org.apache.beam.sdk.ml.remoteinference.openai.OpenAiModelParameters; -import org.apache.beam.sdk.ml.remoteinference.openai.OpenAiModelResponse; +import 
org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelInput; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelParameters; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelResponse; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.Create; @@ -16,28 +16,28 @@ public class Example { public static void main(String[] args) { - PipelineOptions options = PipelineOptionsFactory.create(); + /*PipelineOptions options = PipelineOptionsFactory.create(); options.setRunner(DirectRunner.class); Pipeline p = Pipeline.create(options); p.apply("text", Create.of( "An excellent B2B SaaS solution that streamlines business processes efficiently. The platform is user-friendly and highly reliable. Overall, it delivers great value for enterprise teams.")) - .apply(MapElements.into(TypeDescriptor.of(OpenAiModelInput.class)) - .via(OpenAiModelInput::create)) - .apply("inference", RemoteInference.invoke() - .handler(OpenAiModelHandler.class) - .withParameters(OpenAiModelParameters.builder() + .apply(MapElements.into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)) + .apply("inference", RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() .apiKey("key") .modelName("gpt-5-mini") .instructionPrompt("Analyse sentiment as positive or negative") .build())) - .apply("print output", ParDo.of(new DoFn() { + .apply("print output", ParDo.of(new DoFn() { @ProcessElement public void print(ProcessContext c) { System.out.println("OUTPUT: " + c.element().getOutput()); } })); - p.run(); + p.run();*/ } } From 3935a6dc92dc712b893fd1f928b3cc5da9a1f94b Mon Sep 17 00:00:00 2001 From: Ganeshsivakumar Date: Wed, 29 Oct 2025 16:52:19 +0530 Subject: [PATCH 05/12] add license header --- 
.../remoteinference/src/test/java/Example.java | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sdks/java/ml/remoteinference/src/test/java/Example.java b/sdks/java/ml/remoteinference/src/test/java/Example.java index 97111231b379..baf944342fe4 100644 --- a/sdks/java/ml/remoteinference/src/test/java/Example.java +++ b/sdks/java/ml/remoteinference/src/test/java/Example.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ import org.apache.beam.runners.direct.DirectRunner; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.ml.remoteinference.RemoteInference; From 768472c5748424630cbbbd67a50b6e390a9d04b7 Mon Sep 17 00:00:00 2001 From: Ganeshsivakumar Date: Fri, 31 Oct 2025 22:30:25 +0530 Subject: [PATCH 06/12] batching and structured outputs --- sdks/java/ml/remoteinference/build.gradle.kts | 1 + .../ml/remoteinference/RemoteInference.java | 51 ++++++++++++- .../base/BaseModelHandler.java | 4 +- .../base/BaseModelParameters.java | 1 + .../ml/remoteinference/base/BatchConfig.java | 45 +++++++++++ .../openai/OpenAIModelHandler.java | 75 +++++++++++++++---- .../openai/OpenAIModelParameters.java | 8 ++ 7 files changed, 169 insertions(+), 16 deletions(-) create mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BatchConfig.java diff --git a/sdks/java/ml/remoteinference/build.gradle.kts b/sdks/java/ml/remoteinference/build.gradle.kts index 8c801ca4cae4..f5eefd91514b 100644 --- a/sdks/java/ml/remoteinference/build.gradle.kts +++ b/sdks/java/ml/remoteinference/build.gradle.kts @@ -30,6 +30,7 @@ dependencies { implementation("com.openai:openai-java:4.3.0") implementation("com.google.auto.value:auto-value:1.11.0") implementation("com.google.auto.value:auto-value-annotations:1.11.0") + implementation("com.fasterxml.jackson.core:jackson-core:2.20.0") // testing testImplementation(project(":runners:direct-java")) diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java index 6168a0864427..d0a922f01d8f 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java @@ -28,6 +28,9 @@ import 
com.google.auto.value.AutoValue; +import java.util.ArrayList; +import java.util.List; + @SuppressWarnings({ "rawtypes", "unchecked" }) public class RemoteInference { @@ -47,6 +50,8 @@ public abstract static class Invoke builder(); @AutoValue.Builder @@ -56,6 +61,8 @@ abstract static class Builder setParameters(BaseModelParameters modelParameters); + abstract Builder setBatchConfig(BatchConfig config); + abstract Invoke build(); } @@ -67,13 +74,21 @@ public Invoke withParameters(BaseModelParameters modelParameter return builder().setParameters(modelParameters).build(); } + public Invoke withBatchConfig(BatchConfig config) { + return builder().setBatchConfig(config).build(); + } + @Override public PCollection>> expand(PCollection input) { - return input.apply(ParDo.of(new RemoteInferenceFn<>(this))); + return input.apply(ParDo.of(new BatchElementsFn<>(this.batchConfig() != null ? this.batchConfig() + : this + .parameters() + .defaultBatchConfig()))) + .apply(ParDo.of(new RemoteInferenceFn<>(this))); } static class RemoteInferenceFn - extends DoFn>> { + extends DoFn, Iterable>> { private final Class handlerClass; private final BaseModelParameters parameters; @@ -101,5 +116,37 @@ public void processElement(ProcessContext c) { c.output(response); } } + + public static class BatchElementsFn extends DoFn> { + private final BatchConfig config; + private List batch; + + public BatchElementsFn(BatchConfig config) { + this.config = config; + } + + @StartBundle + public void startBundle() { + batch = new ArrayList<>(); + } + + @ProcessElement + public void processElement(ProcessContext c) { + batch.add(c.element()); + if (batch.size() >= config.getMaxBatchSize()) { + c.output(new ArrayList<>(batch)); + batch.clear(); + } + } + + @FinishBundle + public void finishBundle(FinishBundleContext c) { + if (!batch.isEmpty()) { + c.output(new ArrayList<>(batch), null, null); + batch.clear(); + } + } + + } } } diff --git 
a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java index b5f1aa06d7cb..3ad6cdce84e5 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java @@ -17,12 +17,14 @@ */ package org.apache.beam.sdk.ml.remoteinference.base; +import java.util.List; + public interface BaseModelHandler { // initialize the model with provided parameters public void createClient(ParamT parameters); // Logic to invoke model provider - public Iterable> request(InputT input); + public Iterable> request(List input); } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java index 180ad2c91220..bb56cb74a555 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java @@ -21,4 +21,5 @@ public interface BaseModelParameters extends Serializable { + public BatchConfig defaultBatchConfig(); } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BatchConfig.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BatchConfig.java new file mode 100644 index 000000000000..a4497c1181f1 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BatchConfig.java @@ -0,0 +1,45 @@ +package org.apache.beam.sdk.ml.remoteinference.base; + +import java.io.Serializable; + +public class 
BatchConfig implements Serializable { + + private final int minBatchSize; + private final int maxBatchSize; + + private BatchConfig(Builder builder) { + this.minBatchSize = builder.minBatchSize; + this.maxBatchSize = builder.maxBatchSize; + } + + public int getMinBatchSize() { + return minBatchSize; + } + + public int getMaxBatchSize() { + return maxBatchSize; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private int minBatchSize; + private int maxBatchSize; + + public Builder minBatchSize(int minBatchSize) { + this.minBatchSize = minBatchSize; + return this; + } + + public Builder maxBatchSize(int maxBatchSize) { + this.maxBatchSize = maxBatchSize; + return this; + } + + public BatchConfig build() { + return new BatchConfig(this); + } + } +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java index 20498fe619df..f30e820549ae 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java @@ -17,9 +17,15 @@ */ package org.apache.beam.sdk.ml.remoteinference.openai; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import com.openai.client.OpenAIClient; import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonSchemaLocalValidation; import com.openai.models.responses.ResponseCreateParams; +import com.openai.models.responses.StructuredResponseCreateParams; import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler; import 
org.apache.beam.sdk.ml.remoteinference.base.PredictionResult; @@ -30,7 +36,7 @@ public class OpenAIModelHandler implements BaseModelHandler { private transient OpenAIClient client; - private transient ResponseCreateParams clientParams; + private transient StructuredResponseCreateParams clientParams; private OpenAIModelParameters modelParameters; @Override @@ -42,21 +48,64 @@ public void createClient(OpenAIModelParameters parameters) { } @Override - public Iterable> request(OpenAIModelInput input) { + public Iterable> request(List input) { - this.clientParams = ResponseCreateParams.builder() - .model(this.modelParameters.getModelName()) - .input(input.getInput()) - .build(); + try { + // Convert input list to JSON string + String inputBatch = new ObjectMapper() + .writeValueAsString(input.stream().map(OpenAIModelInput::getInput).toList()); + + // Build structured response parameters + this.clientParams = ResponseCreateParams.builder() + .model(modelParameters.getModelName()) + .input(inputBatch) + .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO) + .instructions(modelParameters.getInstructionPrompt()) + .build(); + + // Get structured output from the model + StructuredInputOutput structuredOutput = client.responses() + .create(clientParams) + .output() + .stream() + .flatMap(item -> item.message().stream()) + .flatMap(message -> message.content().stream()) + .flatMap(content -> content.outputText().stream()) + .findFirst() + .orElse(null); + + if (structuredOutput == null || structuredOutput.responses == null) { + throw new RuntimeException("Model returned no structured responses"); + } + + // Map responses to PredictionResults + List> results = structuredOutput.responses.stream() + .map(response -> PredictionResult.create( + OpenAIModelInput.create(response.input), + OpenAIModelResponse.create(response.output))) + .collect(Collectors.toList()); - String output = client.responses().create(clientParams).output().stream() - .flatMap(item -> 
item.message().stream()) - .flatMap(message -> message.content().stream()) - .flatMap(content -> content.outputText().stream()) - .map(outputText -> outputText.text()) - .collect(Collectors.joining()); + return results; + + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to serialize input batch", e); + } + } + + public static class Response { + @JsonProperty(required = true) + @JsonPropertyDescription("The input string") + public String input; + + @JsonProperty(required = true) + @JsonPropertyDescription("The output string") + public String output; + } - return List.of(PredictionResult.create(input, OpenAIModelResponse.create(output))); + public static class StructuredInputOutput { + @JsonProperty(required = true) + @JsonPropertyDescription("Array of input-output pairs") + public List responses; } } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java index 985c6ceddc7b..e26308694828 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java @@ -18,17 +18,20 @@ package org.apache.beam.sdk.ml.remoteinference.openai; import org.apache.beam.sdk.ml.remoteinference.base.BaseModelParameters; +import org.apache.beam.sdk.ml.remoteinference.base.BatchConfig; public class OpenAIModelParameters implements BaseModelParameters { private final String apiKey; private final String modelName; private final String instructionPrompt; + private final BatchConfig batchConfig; private OpenAIModelParameters(Builder builder) { this.apiKey = builder.apiKey; this.modelName = builder.modelName; this.instructionPrompt = builder.instructionPrompt; + this.batchConfig = 
BatchConfig.builder().maxBatchSize(1).minBatchSize(1).build(); } public String getApiKey() { @@ -47,6 +50,11 @@ public static Builder builder() { return new Builder(); } + @Override + public BatchConfig defaultBatchConfig() { + return batchConfig; + } + public static class Builder { private String apiKey; private String modelName; From 6bf1e23a34376e2e221e664fee3b2fab35ff0816 Mon Sep 17 00:00:00 2001 From: Ganeshsivakumar Date: Thu, 6 Nov 2025 12:42:54 +0530 Subject: [PATCH 07/12] unit test and handler IT --- sdks/java/ml/remoteinference/build.gradle.kts | 11 +- .../ml/remoteinference/RemoteInference.java | 59 +- .../base/BaseModelParameters.java | 1 - .../ml/remoteinference/base/BatchConfig.java | 45 -- .../openai/OpenAIModelHandler.java | 2 +- .../openai/OpenAIModelInput.java | 2 +- .../openai/OpenAIModelParameters.java | 7 - .../openai/OpenAIModelResponse.java | 2 +- .../src/test/java/Example.java | 60 -- .../remoteinference/RemoteInferenceTest.java | 586 ++++++++++++++++++ .../openai/OpenAIModelHandlerIT.java | 366 +++++++++++ .../openai/OpenAIModelHandlerTest.java | 450 ++++++++++++++ 12 files changed, 1429 insertions(+), 162 deletions(-) delete mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BatchConfig.java delete mode 100644 sdks/java/ml/remoteinference/src/test/java/Example.java create mode 100644 sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/RemoteInferenceTest.java create mode 100644 sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java create mode 100644 sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerTest.java diff --git a/sdks/java/ml/remoteinference/build.gradle.kts b/sdks/java/ml/remoteinference/build.gradle.kts index f5eefd91514b..98e3ce620bb9 100644 --- a/sdks/java/ml/remoteinference/build.gradle.kts +++ 
b/sdks/java/ml/remoteinference/build.gradle.kts @@ -28,12 +28,19 @@ dependencies { implementation(project(":sdks:java:core")) implementation("com.openai:openai-java:4.3.0") - implementation("com.google.auto.value:auto-value:1.11.0") - implementation("com.google.auto.value:auto-value-annotations:1.11.0") + compileOnly("com.google.auto.value:auto-value-annotations:1.11.0") + compileOnly("org.checkerframework:checker-qual:3.42.0") + annotationProcessor("com.google.auto.value:auto-value:1.11.0") implementation("com.fasterxml.jackson.core:jackson-core:2.20.0") + implementation("org.apache.beam:beam-vendor-guava-32_1_2-jre:0.1") + implementation("org.slf4j:slf4j-api:2.0.9") + implementation("org.slf4j:slf4j-simple:2.0.9") // testing testImplementation(project(":runners:direct-java")) + testImplementation("org.hamcrest:hamcrest:2.2") + testImplementation("org.mockito:mockito-core:5.8.0") + testImplementation("org.mockito:mockito-junit-jupiter:5.8.0") testImplementation("junit:junit:4.13.2") testImplementation(project(":sdks:java:testing:test-utils")) } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java index d0a922f01d8f..4f05faef2b85 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java @@ -18,17 +18,18 @@ package org.apache.beam.sdk.ml.remoteinference; import org.apache.beam.sdk.ml.remoteinference.base.*; +import org.apache.beam.sdk.transforms.*; import org.checkerframework.checker.nullness.qual.Nullable; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; + -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.PTransform; -import 
org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.PCollection; import com.google.auto.value.AutoValue; import java.util.ArrayList; +import java.util.Collections; import java.util.List; @SuppressWarnings({ "rawtypes", "unchecked" }) @@ -50,7 +51,6 @@ public abstract static class Invoke builder(); @@ -61,7 +61,6 @@ abstract static class Builder setParameters(BaseModelParameters modelParameters); - abstract Builder setBatchConfig(BatchConfig config); abstract Invoke build(); } @@ -74,17 +73,20 @@ public Invoke withParameters(BaseModelParameters modelParameter return builder().setParameters(modelParameters).build(); } - public Invoke withBatchConfig(BatchConfig config) { - return builder().setBatchConfig(config).build(); - } @Override public PCollection>> expand(PCollection input) { - return input.apply(ParDo.of(new BatchElementsFn<>(this.batchConfig() != null ? this.batchConfig() - : this - .parameters() - .defaultBatchConfig()))) - .apply(ParDo.of(new RemoteInferenceFn<>(this))); + checkArgument(handler() != null, "handler() is required"); + checkArgument(parameters() != null, "withParameters() is required"); + return input + .apply("WrapInputInList", MapElements.via(new SimpleFunction>() { + @Override + public List apply(InputT element) { + return Collections.singletonList(element); + } + })) + // Pass the list to the inference function + .apply("RemoteInference", ParDo.of(new RemoteInferenceFn(this))); } static class RemoteInferenceFn @@ -117,36 +119,5 @@ public void processElement(ProcessContext c) { } } - public static class BatchElementsFn extends DoFn> { - private final BatchConfig config; - private List batch; - - public BatchElementsFn(BatchConfig config) { - this.config = config; - } - - @StartBundle - public void startBundle() { - batch = new ArrayList<>(); - } - - @ProcessElement - public void processElement(ProcessContext c) { - batch.add(c.element()); - if (batch.size() >= config.getMaxBatchSize()) { - c.output(new 
ArrayList<>(batch)); - batch.clear(); - } - } - - @FinishBundle - public void finishBundle(FinishBundleContext c) { - if (!batch.isEmpty()) { - c.output(new ArrayList<>(batch), null, null); - batch.clear(); - } - } - - } } } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java index bb56cb74a555..180ad2c91220 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java @@ -21,5 +21,4 @@ public interface BaseModelParameters extends Serializable { - public BatchConfig defaultBatchConfig(); } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BatchConfig.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BatchConfig.java deleted file mode 100644 index a4497c1181f1..000000000000 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BatchConfig.java +++ /dev/null @@ -1,45 +0,0 @@ -package org.apache.beam.sdk.ml.remoteinference.base; - -import java.io.Serializable; - -public class BatchConfig implements Serializable { - - private final int minBatchSize; - private final int maxBatchSize; - - private BatchConfig(Builder builder) { - this.minBatchSize = builder.minBatchSize; - this.maxBatchSize = builder.maxBatchSize; - } - - public int getMinBatchSize() { - return minBatchSize; - } - - public int getMaxBatchSize() { - return maxBatchSize; - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private int minBatchSize; - private int maxBatchSize; - - public Builder minBatchSize(int minBatchSize) { - this.minBatchSize = minBatchSize; - 
return this; - } - - public Builder maxBatchSize(int maxBatchSize) { - this.maxBatchSize = maxBatchSize; - return this; - } - - public BatchConfig build() { - return new BatchConfig(this); - } - } -} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java index f30e820549ae..77138a22d4a6 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java @@ -53,7 +53,7 @@ public Iterable> request try { // Convert input list to JSON string String inputBatch = new ObjectMapper() - .writeValueAsString(input.stream().map(OpenAIModelInput::getInput).toList()); + .writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList()); // Build structured response parameters this.clientParams = ResponseCreateParams.builder() diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java index 0500832def3f..fabf0d5dd23b 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java @@ -28,7 +28,7 @@ private OpenAIModelInput(String input) { this.input = input; } - public String getInput() { + public String getModelInput() { return input; } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java 
b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java index e26308694828..1c703437633d 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java @@ -18,20 +18,17 @@ package org.apache.beam.sdk.ml.remoteinference.openai; import org.apache.beam.sdk.ml.remoteinference.base.BaseModelParameters; -import org.apache.beam.sdk.ml.remoteinference.base.BatchConfig; public class OpenAIModelParameters implements BaseModelParameters { private final String apiKey; private final String modelName; private final String instructionPrompt; - private final BatchConfig batchConfig; private OpenAIModelParameters(Builder builder) { this.apiKey = builder.apiKey; this.modelName = builder.modelName; this.instructionPrompt = builder.instructionPrompt; - this.batchConfig = BatchConfig.builder().maxBatchSize(1).minBatchSize(1).build(); } public String getApiKey() { @@ -50,10 +47,6 @@ public static Builder builder() { return new Builder(); } - @Override - public BatchConfig defaultBatchConfig() { - return batchConfig; - } public static class Builder { private String apiKey; diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java index e65513851bc0..e37543ac1a10 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java @@ -27,7 +27,7 @@ private OpenAIModelResponse(String output) { this.output = output; } - public String getOutput() { + public String getModelResponse() { 
return output; } diff --git a/sdks/java/ml/remoteinference/src/test/java/Example.java b/sdks/java/ml/remoteinference/src/test/java/Example.java deleted file mode 100644 index baf944342fe4..000000000000 --- a/sdks/java/ml/remoteinference/src/test/java/Example.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -import org.apache.beam.runners.direct.DirectRunner; -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.ml.remoteinference.RemoteInference; -import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler; -import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelInput; -import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelParameters; -import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelResponse; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.MapElements; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.TypeDescriptor; - -public class Example { - public static void main(String[] args) { - - /*PipelineOptions options = PipelineOptionsFactory.create(); - options.setRunner(DirectRunner.class); - Pipeline p = Pipeline.create(options); - - p.apply("text", Create.of( - "An excellent B2B SaaS solution that streamlines business processes efficiently. The platform is user-friendly and highly reliable. 
Overall, it delivers great value for enterprise teams.")) - .apply(MapElements.into(TypeDescriptor.of(OpenAIModelInput.class)) - .via(OpenAIModelInput::create)) - .apply("inference", RemoteInference.invoke() - .handler(OpenAIModelHandler.class) - .withParameters(OpenAIModelParameters.builder() - .apiKey("key") - .modelName("gpt-5-mini") - .instructionPrompt("Analyse sentiment as positive or negative") - .build())) - .apply("print output", ParDo.of(new DoFn() { - @ProcessElement - public void print(ProcessContext c) { - System.out.println("OUTPUT: " + c.element().getOutput()); - } - })); - - p.run();*/ - } -} diff --git a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/RemoteInferenceTest.java b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/RemoteInferenceTest.java new file mode 100644 index 000000000000..3f351b9f88a3 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/RemoteInferenceTest.java @@ -0,0 +1,586 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.ml.remoteinference; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.ml.remoteinference.base.*; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + + + +@RunWith(JUnit4.class) +public class RemoteInferenceTest { + + @Rule + public final transient TestPipeline pipeline = TestPipeline.create(); + + // Test input class + public static class TestInput extends BaseInput { + private final String value; + + private TestInput(String value) { + this.value = value; + } + + public static TestInput create(String value) { + return new TestInput(value); + } + + public String getModelInput() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof TestInput)) + return false; + TestInput testInput = (TestInput) o; + return value.equals(testInput.value); + } + + @Override + public int hashCode() { + return value.hashCode(); + } + + @Override + public String toString() { + return "TestInput{value='" + value + "'}"; + } + } + + // Test output class + public static class TestOutput extends BaseResponse { + private final String result; + + private TestOutput(String result) { + this.result = result; + } + + public static TestOutput create(String result) { + return new TestOutput(result); + } + + public String getModelResponse() 
{ + return result; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof TestOutput)) + return false; + TestOutput that = (TestOutput) o; + return result.equals(that.result); + } + + @Override + public int hashCode() { + return result.hashCode(); + } + + @Override + public String toString() { + return "TestOutput{result='" + result + "'}"; + } + } + + // Test parameters class + public static class TestParameters implements BaseModelParameters { + private final String config; + + private TestParameters(Builder builder) { + this.config = builder.config; + } + + public String getConfig() { + return config; + } + + @Override + public String toString() { + return "TestParameters{config='" + config + "'}"; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof TestParameters)) + return false; + TestParameters that = (TestParameters) o; + return config.equals(that.config); + } + + @Override + public int hashCode() { + return config.hashCode(); + } + + // Builder + public static class Builder { + private String config; + + public Builder setConfig(String config) { + this.config = config; + return this; + } + + public TestParameters build() { + return new TestParameters(this); + } + } + + public static Builder builder() { + return new Builder(); + } + } + + // Mock handler for successful inference + public static class MockSuccessHandler + implements BaseModelHandler { + + private TestParameters parameters; + private boolean clientCreated = false; + + @Override + public void createClient(TestParameters parameters) { + this.parameters = parameters; + this.clientCreated = true; + } + + @Override + public Iterable> request(List input) { + if (!clientCreated) { + throw new IllegalStateException("Client not initialized"); + } + return input.stream() + .map(i -> PredictionResult.create( + i, + new TestOutput("processed-" + i.getModelInput()))) + .collect(Collectors.toList()); 
+ } + } + + // Mock handler that returns empty results + public static class MockEmptyResultHandler + implements BaseModelHandler { + + @Override + public void createClient(TestParameters parameters) { + // Setup succeeds + } + + @Override + public Iterable> request(List input) { + return Collections.emptyList(); + } + } + + // Mock handler that throws exception during setup + public static class MockFailingSetupHandler + implements BaseModelHandler { + + @Override + public void createClient(TestParameters parameters) { + throw new RuntimeException("Setup failed intentionally"); + } + + @Override + public Iterable> request(List input) { + return Collections.emptyList(); + } + } + + // Mock handler that throws exception during request + public static class MockFailingRequestHandler + implements BaseModelHandler { + + @Override + public void createClient(TestParameters parameters) { + // Setup succeeds + } + + @Override + public Iterable> request(List input) { + throw new RuntimeException("Request failed intentionally"); + } + } + + // Mock handler without default constructor (to test error handling) + public static class MockNoDefaultConstructorHandler + implements BaseModelHandler { + + private final String required; + + public MockNoDefaultConstructorHandler(String required) { + this.required = required; + } + + @Override + public void createClient(TestParameters parameters) { + } + + @Override + public Iterable> request(List input) { + return Collections.emptyList(); + } + } + + @Test + public void testInvokeWithSingleElement() { + TestInput input = TestInput.create("test-value"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline.apply(Create.of(input)); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + // Verify the output contains expected predictions + 
PAssert.thatSingleton(results).satisfies(batch -> { + List> resultList = StreamSupport.stream(batch.spliterator(), false) + .collect(Collectors.toList()); + + assertEquals("Expected exactly 1 result", 1, resultList.size()); + + PredictionResult result = resultList.get(0); + assertEquals("test-value", result.getInput().getModelInput()); + assertEquals("processed-test-value", result.getOutput().getModelResponse()); + + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testInvokeWithMultipleElements() { + List inputs = Arrays.asList( + new TestInput("input1"), + new TestInput("input2"), + new TestInput("input3")); + + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class))); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + // Count total results across all batches + PAssert.that(results).satisfies(batches -> { + int totalCount = 0; + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + totalCount++; + assertTrue("Output should start with 'processed-'", + result.getOutput().getModelResponse().startsWith("processed-")); + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + } + } + assertEquals("Expected 3 total results", 3, totalCount); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testInvokeWithEmptyCollection() { + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateEmptyInput", Create.empty(SerializableCoder.of(TestInput.class))); + + PCollection>> results = inputCollection + .apply("RemoteInference", + 
RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + // assertion for empty PCollection + PAssert.that(results).empty(); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testHandlerReturnsEmptyResults() { + TestInput input = new TestInput("test-value"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockEmptyResultHandler.class) + .withParameters(params)); + + // Verify we still get a result, but it's empty + PAssert.thatSingleton(results).satisfies(batch -> { + List> resultList = StreamSupport.stream(batch.spliterator(), false) + .collect(Collectors.toList()); + assertEquals("Expected empty result list", 0, resultList.size()); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testHandlerSetupFailure() { + TestInput input = new TestInput("test-value"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + inputCollection.apply("RemoteInference", + RemoteInference.invoke() + .handler(MockFailingSetupHandler.class) + .withParameters(params)); + + // Verify pipeline fails with expected error + try { + pipeline.run().waitUntilFinish(); + fail("Expected pipeline to fail due to handler setup failure"); + } catch (Exception e) { + String message = e.getMessage(); + assertTrue("Exception should mention setup failure or handler instantiation failure", + message != null && (message.contains("Setup failed intentionally") || + message.contains("Failed to instantiate handler"))); + } + } + + @Test + public void 
testHandlerRequestFailure() { + TestInput input = new TestInput("test-value"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + inputCollection.apply("RemoteInference", + RemoteInference.invoke() + .handler(MockFailingRequestHandler.class) + .withParameters(params)); + + // Verify pipeline fails with expected error + try { + pipeline.run().waitUntilFinish(); + fail("Expected pipeline to fail due to request failure"); + } catch (Exception e) { + String message = e.getMessage(); + assertTrue("Exception should mention request failure", + message != null && message.contains("Request failed intentionally")); + } + } + + @Test + public void testHandlerWithoutDefaultConstructor() { + TestInput input = new TestInput("test-value"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + inputCollection.apply("RemoteInference", + RemoteInference.invoke() + .handler(MockNoDefaultConstructorHandler.class) + .withParameters(params)); + + // Verify pipeline fails when handler cannot be instantiated + try { + pipeline.run().waitUntilFinish(); + fail("Expected pipeline to fail due to missing default constructor"); + } catch (Exception e) { + String message = e.getMessage(); + assertTrue("Exception should mention handler instantiation failure", + message != null && message.contains("Failed to instantiate handler")); + } + } + + @Test + public void testBuilderPattern() { + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + RemoteInference.Invoke transform = RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params); + + assertNotNull("Transform should not be null", 
transform); + } + + @Test + public void testPredictionResultMapping() { + TestInput input = new TestInput("mapping-test"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + PAssert.thatSingleton(results).satisfies(batch -> { + for (PredictionResult result : batch) { + // Verify that input is preserved in the result + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + assertEquals("mapping-test", result.getInput().getModelInput()); + assertTrue("Output should contain input value", + result.getOutput().getModelResponse().contains("mapping-test")); + } + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + // Temporary behaviour until we introduce java BatchElements transform + // to batch elements in RemoteInference + @Test + public void testMultipleInputsProduceSeparateBatches() { + List inputs = Arrays.asList( + new TestInput("input1"), + new TestInput("input2")); + + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class))); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + PAssert.that(results).satisfies(batches -> { + int batchCount = 0; + for (Iterable> batch : batches) { + batchCount++; + int elementCount = 0; + elementCount += StreamSupport.stream(batch.spliterator(), false).count(); + // Each batch should contain exactly 1 element + assertEquals("Each 
batch should contain 1 element", 1, elementCount); + } + assertEquals("Expected 2 batches", 2, batchCount); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testWithEmptyParameters() { + + pipeline.enableAbandonedNodeEnforcement(false); + + TestInput input = TestInput.create("test-value"); + PCollection inputCollection = pipeline.apply(Create.of(input)); + + IllegalArgumentException thrown = assertThrows( + IllegalArgumentException.class, + () -> inputCollection.apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class))); + + assertTrue( + "Expected message to contain 'withParameters() is required', but got: " + thrown.getMessage(), + thrown.getMessage().contains("withParameters() is required")); + } + + @Test + public void testWithEmptyHandler() { + + pipeline.enableAbandonedNodeEnforcement(false); + + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + TestInput input = TestInput.create("test-value"); + PCollection inputCollection = pipeline.apply(Create.of(input)); + + IllegalArgumentException thrown = assertThrows( + IllegalArgumentException.class, + () -> inputCollection.apply("RemoteInference", + RemoteInference.invoke() + .withParameters(params))); + + assertTrue( + "Expected message to contain 'handler() is required', but got: " + thrown.getMessage(), + thrown.getMessage().contains("handler() is required")); + } +} diff --git a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java new file mode 100644 index 000000000000..fb24b090cb68 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.junit.Assume.assumeNotNull; +import static org.junit.Assume.assumeTrue; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.beam.sdk.ml.remoteinference.base.*; +import org.apache.beam.sdk.ml.remoteinference.RemoteInference; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.StructuredInputOutput; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.Response; + +public class 
OpenAIModelHandlerIT { + private static final Logger LOG = LoggerFactory.getLogger(OpenAIModelHandlerIT.class); + + @Rule + public final transient TestPipeline pipeline = TestPipeline.create(); + + private String apiKey; + private static final String API_KEY_ENV = "OPENAI_API_KEY"; + private static final String DEFAULT_MODEL = "gpt-4o-mini"; + + + @Before + public void setUp() { + // Get API key + apiKey = System.getenv(API_KEY_ENV); + + // Skip tests if API key is not provided + assumeNotNull( + "OpenAI API key not found. Set " + API_KEY_ENV + + " environment variable to run integration tests.", + apiKey); + assumeTrue("OpenAI API key is empty. Set " + API_KEY_ENV + + " environment variable to run integration tests.", + !apiKey.trim().isEmpty()); + } + + @Test + public void testSentimentAnalysisWithSingleInput() { + String input = "This product is absolutely amazing! I love it!"; + + PCollection inputs = pipeline + .apply("CreateSingleInput", Create.of(input)) + .apply("MapToInput", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputs + .apply("SentimentInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Analyze the sentiment as 'positive' or 'negative'. 
Return only one word.") + .build())); + + // Verify results + PAssert.that(results).satisfies(batches -> { + int count = 0; + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + count++; + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + assertNotNull("Output text should not be null", + result.getOutput().getModelResponse()); + + String sentiment = result.getOutput().getModelResponse().toLowerCase(); + assertTrue("Sentiment should be positive or negative, got: " + sentiment, + sentiment.contains("positive") + || sentiment.contains("negative")); + } + } + assertEquals("Should have exactly 1 result", 1, count); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testSentimentAnalysisWithMultipleInputs() { + List inputs = Arrays.asList( + "An excellent B2B SaaS solution that streamlines business processes efficiently.", + "The customer support is terrible. I've been waiting for days without any response.", + "The application works as expected. Installation was straightforward.", + "Really impressed with the innovative features! The AI capabilities are groundbreaking!", + "Mediocre product with occasional glitches. 
Documentation could be better."); + + PCollection inputCollection = pipeline + .apply("CreateMultipleInputs", Create.of(inputs)) + .apply("MapToInputs", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputCollection + .apply("SentimentInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Analyze sentiment as positive or negative") + .build())); + + // Verify we get results for all inputs + PAssert.that(results).satisfies(batches -> { + int totalCount = 0; + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + totalCount++; + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + assertFalse("Output should not be empty", + result.getOutput().getModelResponse().trim().isEmpty()); + } + } + assertEquals("Should have results for all 5 inputs", 5, totalCount); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testTextClassification() { + List inputs = Arrays.asList( + "How do I reset my password?", + "Your product is broken and I want a refund!", + "Thank you for the excellent service!"); + + PCollection inputCollection = pipeline + .apply("CreateInputs", Create.of(inputs)) + .apply("MapToInputs", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputCollection + .apply("ClassificationInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Classify each text into one category: 'question', 'complaint', or 'praise'. 
Return only the category.") + .build())); + + PAssert.that(results).satisfies(batches -> { + List categories = new ArrayList<>(); + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + String category = result.getOutput().getModelResponse().toLowerCase(); + categories.add(category); + } + } + + assertEquals("Should have 3 categories", 3, categories.size()); + + // Verify expected categories + boolean hasQuestion = categories.stream().anyMatch(c -> c.contains("question")); + boolean hasComplaint = categories.stream().anyMatch(c -> c.contains("complaint")); + boolean hasPraise = categories.stream().anyMatch(c -> c.contains("praise")); + + assertTrue("Should have at least one recognized category", + hasQuestion || hasComplaint || hasPraise); + + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testInputOutputMapping() { + List inputs = Arrays.asList("apple", "banana", "cherry"); + + PCollection inputCollection = pipeline + .apply("CreateInputs", Create.of(inputs)) + .apply("MapToInputs", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputCollection + .apply("MappingInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Return the input word in uppercase") + .build())); + + // Verify input-output pairing is preserved + PAssert.that(results).satisfies(batches -> { + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + String input = result.getInput().getModelInput(); + String output = result.getOutput().getModelResponse().toLowerCase(); + + // Verify the output relates to the input + assertTrue("Output should relate to input '" + input + "', got: " + output, + output.contains(input.toLowerCase())); + } + } + return null; + }); + + pipeline.run().waitUntilFinish(); + } 
+ + @Test + public void testWithDifferentModel() { + // Test with a different model + String input = "Explain quantum computing in one sentence."; + + PCollection inputs = pipeline + .apply("CreateInput", Create.of(input)) + .apply("MapToInput", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputs + .apply("DifferentModelInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName("gpt-5") + .instructionPrompt("Respond concisely") + .build())); + + PAssert.that(results).satisfies(batches -> { + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + assertNotNull("Output should not be null", + result.getOutput().getModelResponse()); + assertFalse("Output should not be empty", + result.getOutput().getModelResponse().trim().isEmpty()); + } + } + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testWithInvalidApiKey() { + String input = "Test input"; + + PCollection inputs = pipeline + .apply("CreateInput", Create.of(input)) + .apply("MapToInput", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + inputs.apply("InvalidKeyInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey("invalid-api-key-12345") + .modelName(DEFAULT_MODEL) + .instructionPrompt("Test") + .build())); + + // Expect pipeline to fail with authentication error + try { + pipeline.run().waitUntilFinish(); + fail("Expected pipeline to fail with invalid API key"); + } catch (Exception e) { + // Expected - verify it's an authentication error + String message = e.getMessage().toLowerCase(); + assertTrue("Exception should mention authentication or API key issue, got: " + message, + message.contains("auth") || + message.contains("api") || + 
message.contains("key") || + message.contains("401")); + } + } + + /** + * Test with custom instruction formats + */ + @Test + public void testWithJsonOutputFormat() { + String input = "Paris is the capital of France"; + + PCollection inputs = pipeline + .apply("CreateInput", Create.of(input)) + .apply("MapToInput", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputs + .apply("JsonFormatInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Extract the city and country. Return as: City: [city], Country: [country]") + .build())); + + PAssert.that(results).satisfies(batches -> { + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + String output = result.getOutput().getModelResponse(); + LOG.info("Structured output: " + output); + + // Verify output contains expected information + assertTrue("Output should mention Paris: " + output, + output.toLowerCase().contains("paris")); + assertTrue("Output should mention France: " + output, + output.toLowerCase().contains("france")); + } + } + return null; + }); + + pipeline.run().waitUntilFinish(); + } +} diff --git a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerTest.java b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerTest.java new file mode 100644 index 000000000000..3bff5600aa45 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerTest.java @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference.openai; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import org.apache.beam.sdk.ml.remoteinference.base.*; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.StructuredInputOutput; +import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.Response; + + + +@RunWith(JUnit4.class) +public class OpenAIModelHandlerTest { + private OpenAIModelParameters testParameters; + + @Before + public void setUp() { + testParameters = OpenAIModelParameters.builder() + .apiKey("test-api-key") + .modelName("gpt-4") + .instructionPrompt("Test instruction") + .build(); + } + + /** + * Fake OpenAiModelHandler for testing. 
+ */ + static class FakeOpenAiModelHandler extends OpenAIModelHandler { + + private boolean clientCreated = false; + private OpenAIModelParameters storedParameters; + private List responsesToReturn; + private RuntimeException exceptionToThrow; + private boolean shouldReturnNull = false; + + public void setResponsesToReturn(List responses) { + this.responsesToReturn = responses; + } + + public void setExceptionToThrow(RuntimeException exception) { + this.exceptionToThrow = exception; + } + + public void setShouldReturnNull(boolean shouldReturnNull) { + this.shouldReturnNull = shouldReturnNull; + } + + public boolean isClientCreated() { + return clientCreated; + } + + public OpenAIModelParameters getStoredParameters() { + return storedParameters; + } + + @Override + public void createClient(OpenAIModelParameters parameters) { + this.storedParameters = parameters; + this.clientCreated = true; + + if (exceptionToThrow != null) { + throw exceptionToThrow; + } + } + + @Override + public Iterable> request( + List input) { + + if (!clientCreated) { + throw new IllegalStateException("Client not initialized"); + } + + if (exceptionToThrow != null) { + throw exceptionToThrow; + } + + if (shouldReturnNull || responsesToReturn == null) { + throw new RuntimeException("Model returned no structured responses"); + } + + StructuredInputOutput structuredOutput = responsesToReturn.get(0); + + if (structuredOutput == null || structuredOutput.responses == null) { + throw new RuntimeException("Model returned no structured responses"); + } + + return structuredOutput.responses.stream() + .map(response -> PredictionResult.create( + OpenAIModelInput.create(response.input), + OpenAIModelResponse.create(response.output))) + .collect(Collectors.toList()); + } + } + + @Test + public void testCreateClient() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + OpenAIModelParameters params = OpenAIModelParameters.builder() + .apiKey("test-key") + .modelName("gpt-4") + 
.instructionPrompt("test prompt") + .build(); + + handler.createClient(params); + + assertTrue("Client should be created", handler.isClientCreated()); + assertNotNull("Parameters should be stored", handler.getStoredParameters()); + assertEquals("test-key", handler.getStoredParameters().getApiKey()); + assertEquals("gpt-4", handler.getStoredParameters().getModelName()); + } + + @Test + public void testRequestWithSingleInput() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.singletonList( + OpenAIModelInput.create("test input")); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + Response response = new Response(); + response.input = "test input"; + response.output = "test output"; + structuredOutput.responses = Collections.singletonList(response); + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + handler.createClient(testParameters); + + Iterable> results = handler.request(inputs); + + assertNotNull("Results should not be null", results); + + List> resultList = iterableToList(results); + + assertEquals("Should have 1 result", 1, resultList.size()); + + PredictionResult result = resultList.get(0); + assertEquals("test input", result.getInput().getModelInput()); + assertEquals("test output", result.getOutput().getModelResponse()); + } + + @Test + public void testRequestWithMultipleInputs() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Arrays.asList( + OpenAIModelInput.create("input1"), + OpenAIModelInput.create("input2"), + OpenAIModelInput.create("input3")); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + + Response response1 = new Response(); + response1.input = "input1"; + response1.output = "output1"; + + Response response2 = new Response(); + response2.input = "input2"; + response2.output = "output2"; + + Response response3 = new Response(); + response3.input = "input3"; + response3.output = 
"output3"; + + structuredOutput.responses = Arrays.asList(response1, response2, response3); + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + handler.createClient(testParameters); + + Iterable> results = handler.request(inputs); + + List> resultList = iterableToList(results); + + assertEquals("Should have 3 results", 3, resultList.size()); + + for (int i = 0; i < 3; i++) { + PredictionResult result = resultList.get(i); + assertEquals("input" + (i + 1), result.getInput().getModelInput()); + assertEquals("output" + (i + 1), result.getOutput().getModelResponse()); + } + } + + @Test + public void testRequestWithEmptyInput() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.emptyList(); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + structuredOutput.responses = Collections.emptyList(); + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + handler.createClient(testParameters); + + Iterable> results = handler.request(inputs); + + List> resultList = iterableToList(results); + assertEquals("Should have 0 results", 0, resultList.size()); + } + + @Test + public void testRequestWithNullStructuredOutput() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.singletonList( + OpenAIModelInput.create("test input")); + + handler.setShouldReturnNull(true); + handler.createClient(testParameters); + + try { + handler.request(inputs); + fail("Expected RuntimeException when structured output is null"); + } catch (RuntimeException e) { + assertTrue("Exception message should mention no structured responses", + e.getMessage().contains("Model returned no structured responses")); + } + } + + @Test + public void testRequestWithNullResponsesList() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.singletonList( + OpenAIModelInput.create("test input")); + + 
StructuredInputOutput structuredOutput = new StructuredInputOutput(); + structuredOutput.responses = null; + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + handler.createClient(testParameters); + + try { + handler.request(inputs); + fail("Expected RuntimeException when responses list is null"); + } catch (RuntimeException e) { + assertTrue("Exception message should mention no structured responses", + e.getMessage().contains("Model returned no structured responses")); + } + } + + @Test + public void testCreateClientFailure() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + handler.setExceptionToThrow(new RuntimeException("Setup failed")); + + try { + handler.createClient(testParameters); + fail("Expected RuntimeException during client creation"); + } catch (RuntimeException e) { + assertEquals("Setup failed", e.getMessage()); + } + } + + @Test + public void testRequestApiFailure() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.singletonList( + OpenAIModelInput.create("test input")); + + handler.createClient(testParameters); + handler.setExceptionToThrow(new RuntimeException("API Error")); + + try { + handler.request(inputs); + fail("Expected RuntimeException when API fails"); + } catch (RuntimeException e) { + assertEquals("API Error", e.getMessage()); + } + } + + @Test + public void testRequestWithoutClientInitialization() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.singletonList( + OpenAIModelInput.create("test input")); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + Response response = new Response(); + response.input = "test input"; + response.output = "test output"; + structuredOutput.responses = Collections.singletonList(response); + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + + // Don't call createClient + try { + handler.request(inputs); + 
fail("Expected IllegalStateException when client not initialized"); + } catch (IllegalStateException e) { + assertTrue("Exception should mention client not initialized", + e.getMessage().contains("Client not initialized")); + } + } + + @Test + public void testInputOutputMapping() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Arrays.asList( + OpenAIModelInput.create("alpha"), + OpenAIModelInput.create("beta")); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + + Response response1 = new Response(); + response1.input = "alpha"; + response1.output = "ALPHA"; + + Response response2 = new Response(); + response2.input = "beta"; + response2.output = "BETA"; + + structuredOutput.responses = Arrays.asList(response1, response2); + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + handler.createClient(testParameters); + + Iterable> results = handler.request(inputs); + + List> resultList = iterableToList(results); + + assertEquals(2, resultList.size()); + assertEquals("alpha", resultList.get(0).getInput().getModelInput()); + assertEquals("ALPHA", resultList.get(0).getOutput().getModelResponse()); + + assertEquals("beta", resultList.get(1).getInput().getModelInput()); + assertEquals("BETA", resultList.get(1).getOutput().getModelResponse()); + } + + @Test + public void testParametersBuilder() { + OpenAIModelParameters params = OpenAIModelParameters.builder() + .apiKey("my-api-key") + .modelName("gpt-4-turbo") + .instructionPrompt("Custom instruction") + .build(); + + assertEquals("my-api-key", params.getApiKey()); + assertEquals("gpt-4-turbo", params.getModelName()); + assertEquals("Custom instruction", params.getInstructionPrompt()); + } + + @Test + public void testOpenAIModelInputCreate() { + OpenAIModelInput input = OpenAIModelInput.create("test value"); + + assertNotNull("Input should not be null", input); + assertEquals("test value", input.getModelInput()); + } + + @Test + public 
void testOpenAIModelResponseCreate() { + OpenAIModelResponse response = OpenAIModelResponse.create("test output"); + + assertNotNull("Response should not be null", response); + assertEquals("test output", response.getModelResponse()); + } + + @Test + public void testStructuredInputOutputStructure() { + Response response = new Response(); + response.input = "test-input"; + response.output = "test-output"; + + assertEquals("test-input", response.input); + assertEquals("test-output", response.output); + + StructuredInputOutput structured = new StructuredInputOutput(); + structured.responses = Collections.singletonList(response); + + assertNotNull("Responses should not be null", structured.responses); + assertEquals("Should have 1 response", 1, structured.responses.size()); + assertEquals("test-input", structured.responses.get(0).input); + } + + @Test + public void testMultipleRequestsWithSameHandler() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + handler.createClient(testParameters); + + // First request + StructuredInputOutput output1 = new StructuredInputOutput(); + Response response1 = new Response(); + response1.input = "first"; + response1.output = "FIRST"; + output1.responses = Collections.singletonList(response1); + handler.setResponsesToReturn(Collections.singletonList(output1)); + + List inputs1 = Collections.singletonList( + OpenAIModelInput.create("first")); + Iterable> results1 = handler.request(inputs1); + + List> resultList1 = iterableToList(results1); + assertEquals("FIRST", resultList1.get(0).getOutput().getModelResponse()); + + // Second request with different data + StructuredInputOutput output2 = new StructuredInputOutput(); + Response response2 = new Response(); + response2.input = "second"; + response2.output = "SECOND"; + output2.responses = Collections.singletonList(response2); + handler.setResponsesToReturn(Collections.singletonList(output2)); + + List inputs2 = Collections.singletonList( + 
OpenAIModelInput.create("second")); + Iterable> results2 = handler.request(inputs2); + + List> resultList2 = iterableToList(results2); + assertEquals("SECOND", resultList2.get(0).getOutput().getModelResponse()); + } + + // Helper method to convert Iterable to List + private List iterableToList(Iterable iterable) { + List list = new java.util.ArrayList<>(); + iterable.forEach(list::add); + return list; + } +} From af49dbcb1918974a61c9289c73a36697661bb10f Mon Sep 17 00:00:00 2001 From: Ganeshsivakumar Date: Fri, 7 Nov 2025 19:10:22 +0530 Subject: [PATCH 08/12] api docs --- .../ml/remoteinference/RemoteInference.java | 56 +++++++++++++++++-- .../ml/remoteinference/base/BaseInput.java | 4 ++ .../base/BaseModelHandler.java | 45 ++++++++++++++- .../base/BaseModelParameters.java | 16 ++++++ .../ml/remoteinference/base/BaseResponse.java | 10 ++++ .../base/PredictionResult.java | 9 +++ .../openai/OpenAIModelHandler.java | 54 ++++++++++++++++++ .../openai/OpenAIModelInput.java | 25 +++++++++ .../openai/OpenAIModelParameters.java | 38 ++++++++++++- .../openai/OpenAIModelResponse.java | 24 ++++++++ 10 files changed, 272 insertions(+), 9 deletions(-) diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java index 4f05faef2b85..6a8b9b656bdf 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java @@ -21,20 +21,49 @@ import org.apache.beam.sdk.transforms.*; import org.checkerframework.checker.nullness.qual.Nullable; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; - - import org.apache.beam.sdk.values.PCollection; - - import com.google.auto.value.AutoValue; import java.util.ArrayList; import 
java.util.Collections; import java.util.List; +/** + * A {@link PTransform} for making remote inference calls to external machine learning services. + * + *

{@code RemoteInference} provides a framework for integrating remote ML model + * inference into Apache Beam pipelines and handles the communication between pipelines + * and external inference APIs. + * + *

Example: OpenAI Model Inference

+ * + *
{@code
+ * // Create model parameters
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("your-api-key")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Analyse sentiment as positive or negative")
+ *     .build();
+ *
+ * // Apply remote inference transform
+ * PCollection<OpenAIModelInput> inputs = pipeline.apply(Create.of(
+ *     OpenAIModelInput.create("An excellent B2B SaaS solution that streamlines business processes efficiently."),
+ *     OpenAIModelInput.create("Really impressed with the innovative features!")
+ * ));
+ *
+ * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+ *     inputs.apply(
+ *         RemoteInference.invoke()
+ *             .handler(OpenAIModelHandler.class)
+ *             .withParameters(params)
+ *     );
+ * }
+ * + */ @SuppressWarnings({ "rawtypes", "unchecked" }) public class RemoteInference { + /** Invoke the model handler with model parameters */ public static Invoke invoke() { return new AutoValue_RemoteInference_Invoke.Builder().setParameters(null) .build(); @@ -65,10 +94,16 @@ abstract static class Builder build(); } + /** + * Model handler class for inference. + */ public Invoke handler(Class modelHandler) { return builder().setHandler(modelHandler).build(); } + /** + * Configures the parameters for model initialization. + */ public Invoke withParameters(BaseModelParameters modelParameters) { return builder().setParameters(modelParameters).build(); } @@ -89,6 +124,16 @@ public List apply(InputT element) { .apply("RemoteInference", ParDo.of(new RemoteInferenceFn(this))); } + /** + * A {@link DoFn} that performs remote inference operation. + * + *

This function manages the lifecycle of the model handler: + *

    + *
  • Instantiates the handler during {@link Setup}
  • + *
  • Initializes the remote client via {@link BaseModelHandler#createClient}
  • + *
  • Processes elements by calling {@link BaseModelHandler#request}
  • + *
+ */ static class RemoteInferenceFn extends DoFn, Iterable>> { @@ -101,6 +146,7 @@ static class RemoteInferenceFn> response = this.handler.request(c.element()); diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java index 939406722b8f..0a84a67fdb4c 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java @@ -19,6 +19,10 @@ import java.io.Serializable; +/** + * Base class for defining input types used with remote inference transforms. + * Implementations hold the data needed for inference (text, images, etc.). + */ public abstract class BaseInput implements Serializable { } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java index 3ad6cdce84e5..1128a287d927 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java @@ -19,12 +19,51 @@ import java.util.List; +/** + * Interface for model-specific handlers that perform remote inference operations. + * + *

Implementations of this interface encapsulate all logic for communicating with a + * specific remote inference service. Each handler is responsible for: + *

    + *
  • Initializing and managing client connections
  • + *
  • Converting Beam inputs to service-specific request formats
  • + *
  • Making inference API calls
  • + *
  • Converting service responses to Beam output types
  • + *
  • Handling errors and retries if applicable
  • + *
+ * + *

Lifecycle

+ * + *

Handler instances follow this lifecycle: + *

    + *
  1. Instantiation via no-argument constructor
  2. + *
  3. {@link #createClient} called with parameters during setup
  4. + *
  5. {@link #request} called for each batch of inputs
  6. + *
+ * + * + *

Handlers typically contain non-serializable client objects. + * Mark client fields as {@code transient} and initialize them in {@link #createClient}. + * + *

Batching Considerations

+ * + *

The {@link #request} method receives a list of inputs. Implementations should: + *

    + *
  • Batch inputs efficiently if the service supports batch inference
  • + *
  • Return results in the same order as inputs
  • + *
  • Maintain input-output correspondence in {@link PredictionResult}
  • + *
+ * + */ public interface BaseModelHandler { - - // initialize the model with provided parameters + /** + * Initializes the remote model client with the provided parameters. + */ public void createClient(ParamT parameters); - // Logic to invoke model provider + /** + * Performs inference on a batch of inputs and returns the results. + */ public Iterable> request(List input); } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java index 180ad2c91220..46ee72c73001 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java @@ -19,6 +19,22 @@ import java.io.Serializable; +/** + * Base interface for defining model-specific parameters used to configure remote inference clients. + * + *

Implementations of this interface encapsulate all configuration needed to initialize + * and communicate with a remote model inference service. This typically includes: + *

    + *
  • Authentication credentials (API keys, tokens)
  • + *
  • Model identifiers or names
  • + *
  • Endpoint URLs or connection settings
  • + *
  • Inference configuration (temperature, max tokens, timeout values, etc.)
  • + *
+ * + *

Parameters must be serializable. Consider using + * the builder pattern for complex parameter objects. + * + */ public interface BaseModelParameters extends Serializable { } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java index 2e7af0e7c2c1..b3c050c45c11 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java @@ -19,6 +19,16 @@ import java.io.Serializable; +/** + * Base class for defining response types returned from remote inference operations. + * + *

Implementations: + *

    + *
  • Contain the inference results (predictions, classifications, generated text, etc.)
  • + *
  • Include any relevant metadata
  • + *
+ * + */ public abstract class BaseResponse implements Serializable { } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java index b19f64917479..edc30fd11246 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java @@ -19,6 +19,12 @@ import java.io.Serializable; +/** + * Pairs an input with its corresponding inference output. + * + *

This class maintains the association between input data and its model's results + * for Downstream processing + */ public class PredictionResult implements Serializable { private final InputT input; @@ -30,14 +36,17 @@ private PredictionResult(InputT input, OutputT output) { } + /* Returns input to handler */ public InputT getInput() { return input; } + /* Returns model handler's response*/ public OutputT getOutput() { return output; } + /* Creates a PredictionResult instance of provided input, output and types */ public static PredictionResult create(InputT input, OutputT output) { return new PredictionResult<>(input, output); } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java index 77138a22d4a6..87616ee693d1 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java @@ -32,6 +32,31 @@ import java.util.List; import java.util.stream.Collectors; +/** + * Model handler for OpenAI API inference requests. + * + *

This handler manages communication with OpenAI's API, including client initialization, + * request formatting, and response parsing. It uses OpenAI's structured output feature to + * ensure reliable input-output pairing. + * + *

Usage

+ *
{@code
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("sk-...")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Classify the following text into one of the categories: {CATEGORIES}")
+ *     .build();
+ *
+ * PCollection<OpenAIModelInput> inputs = ...;
+ * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+ *     inputs.apply(
+ *         RemoteInference.invoke()
+ *             .handler(OpenAIModelHandler.class)
+ *             .withParameters(params)
+ *     );
+ * }
+ * + */ public class OpenAIModelHandler implements BaseModelHandler { @@ -39,6 +64,14 @@ public class OpenAIModelHandler private transient StructuredResponseCreateParams clientParams; private OpenAIModelParameters modelParameters; + /** + * Initializes the OpenAI client with the provided parameters. + * + *

This method is called once during setup. It creates an authenticated + * OpenAI client using the API key from the parameters. + * + * @param parameters the configuration parameters including API key and model name + */ @Override public void createClient(OpenAIModelParameters parameters) { this.modelParameters = parameters; @@ -47,6 +80,16 @@ public void createClient(OpenAIModelParameters parameters) { .build(); } + /** + * Performs inference on a batch of inputs using the OpenAI Client. + * + *

This method serializes the input batch to JSON string, sends it to OpenAI with structured + * output requirements, and parses the response into {@link PredictionResult} objects + * that pair each input with its corresponding output. + * + * @param input the list of inputs to process + * @return an iterable of model results and input pairs + */ @Override public Iterable> request(List input) { @@ -92,6 +135,11 @@ public Iterable> request } } + /** + * Schema class for structured output response. + * + *

Represents a single input-output pair returned by the OpenAI API. + */ public static class Response { @JsonProperty(required = true) @JsonPropertyDescription("The input string") @@ -102,6 +150,12 @@ public static class Response { public String output; } + /** + * Schema class for structured output containing multiple responses. + * + *

This class defines the expected JSON structure for OpenAI's structured output, + * ensuring reliable parsing of batched inference results. + */ public static class StructuredInputOutput { @JsonProperty(required = true) @JsonPropertyDescription("Array of input-output pairs") diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java index fabf0d5dd23b..1ef59c89da66 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java @@ -19,6 +19,20 @@ import org.apache.beam.sdk.ml.remoteinference.base.BaseInput; +/** + * Input for OpenAI model inference requests. + * + *

This class encapsulates text input to be sent to OpenAI models. + * + *

Example Usage

+ *
{@code
+ * OpenAIModelInput input = OpenAIModelInput.create("Translate to French: Hello");
+ * String text = input.getModelInput(); // "Translate to French: Hello"
+ * }
+ * + * @see OpenAIModelHandler + * @see OpenAIModelResponse + */ public class OpenAIModelInput extends BaseInput { private final String input; @@ -28,10 +42,21 @@ private OpenAIModelInput(String input) { this.input = input; } + /** + * Returns the text input for the model. + * + * @return the input text string + */ public String getModelInput() { return input; } + /** + * Creates a new input instance with the specified text. + * + * @param input the text to send to the model + * @return a new {@link OpenAIModelInput} instance + */ public static OpenAIModelInput create(String input) { return new OpenAIModelInput(input); } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java index 1c703437633d..3aac4112cba3 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java @@ -19,6 +19,23 @@ import org.apache.beam.sdk.ml.remoteinference.base.BaseModelParameters; +/** + * Configuration parameters required for OpenAI model inference. + * + *

This class encapsulates all configuration needed to initialize and communicate with + * OpenAI's API, including authentication credentials, model selection, and inference instructions. + * + *

Example Usage

+ *
{@code
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("sk-...")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Translate the following text to French:")
+ *     .build();
+ * }
+ * + * @see OpenAIModelHandler + */ public class OpenAIModelParameters implements BaseModelParameters { private final String apiKey; @@ -56,21 +73,40 @@ public static class Builder { private Builder() { } + /** + * Sets the OpenAI API key for authentication. + * + * @param apiKey the API key (required) + */ public Builder apiKey(String apiKey) { this.apiKey = apiKey; return this; } + /** + * Sets the name of the OpenAI model to use. + * + * @param modelName the model name, e.g., "gpt-4" (required) + */ public Builder modelName(String modelName) { this.modelName = modelName; return this; } - + /** + * Sets the instruction prompt for the model. + * This prompt provides context or instructions to the model about how to process + * the input text. + * + * @param prompt the instruction text (required) + */ public Builder instructionPrompt(String prompt) { this.instructionPrompt = prompt; return this; } + /** + * Builds the {@link OpenAIModelParameters} instance. + */ public OpenAIModelParameters build() { return new OpenAIModelParameters(this); } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java index e37543ac1a10..abfd9d7cb5c3 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java @@ -19,6 +19,19 @@ import org.apache.beam.sdk.ml.remoteinference.base.BaseResponse; +/** + * Response from OpenAI model inference results. + *

This class encapsulates the text output returned from OpenAI models. + * + *

Example Usage

+ *
{@code
+ * OpenAIModelResponse response = OpenAIModelResponse.create("Bonjour");
+ * String output = response.getModelResponse(); // "Bonjour"
+ * }
+ * + * @see OpenAIModelHandler + * @see OpenAIModelInput + */ public class OpenAIModelResponse extends BaseResponse { private final String output; @@ -27,10 +40,21 @@ private OpenAIModelResponse(String output) { this.output = output; } + /** + * Returns the text output from the model. + * + * @return the output text string + */ public String getModelResponse() { return output; } + /** + * Creates a new response instance with the specified output text. + * + * @param output the text returned by the model + * @return a new {@link OpenAIModelResponse} instance + */ public static OpenAIModelResponse create(String output) { return new OpenAIModelResponse(output); } From b3a2c6727f10f2b309e95549a9f8f54ae25dba3a Mon Sep 17 00:00:00 2001 From: Ganeshsivakumar Date: Tue, 11 Nov 2025 21:25:50 +0530 Subject: [PATCH 09/12] api docs --- sdks/java/ml/build.gradle | 2 -- 1 file changed, 2 deletions(-) diff --git a/sdks/java/ml/build.gradle b/sdks/java/ml/build.gradle index 7b6b071fa2a1..b18480559e14 100644 --- a/sdks/java/ml/build.gradle +++ b/sdks/java/ml/build.gradle @@ -27,6 +27,4 @@ description = "Apache Beam :: SDKs :: Java :: ML" ext.summary = "Java ML module" dependencies { - - } From 95e484135932f29df242aecdda781ef42f74c53e Mon Sep 17 00:00:00 2001 From: Ganeshsivakumar Date: Wed, 19 Nov 2025 15:57:17 +0530 Subject: [PATCH 10/12] retries --- sdks/java/ml/remoteinference/build.gradle.kts | 1 + .../ml/remoteinference/RemoteInference.java | 13 ++- .../sdk/ml/remoteinference/RetryHandler.java | 103 ++++++++++++++++++ .../openai/OpenAIModelHandlerIT.java | 58 ++++++++-- 4 files changed, 161 insertions(+), 14 deletions(-) create mode 100644 sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RetryHandler.java diff --git a/sdks/java/ml/remoteinference/build.gradle.kts b/sdks/java/ml/remoteinference/build.gradle.kts index 98e3ce620bb9..776be5becb63 100644 --- a/sdks/java/ml/remoteinference/build.gradle.kts +++ 
b/sdks/java/ml/remoteinference/build.gradle.kts @@ -35,6 +35,7 @@ dependencies { implementation("org.apache.beam:beam-vendor-guava-32_1_2-jre:0.1") implementation("org.slf4j:slf4j-api:2.0.9") implementation("org.slf4j:slf4j-simple:2.0.9") + implementation("joda-time:joda-time:2.11.1") // testing testImplementation(project(":runners:direct-java")) diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java index 6a8b9b656bdf..934850575ac7 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java @@ -139,19 +139,21 @@ static class RemoteInferenceFn handlerClass; private final BaseModelParameters parameters; - private transient BaseModelHandler handler; + private transient BaseModelHandler modelHandler; + private final RetryHandler retryHandler; RemoteInferenceFn(Invoke spec) { this.handlerClass = spec.handler(); this.parameters = spec.parameters(); + retryHandler = RetryHandler.withDefaults(); } /** Instantiate the model handler and client*/ @Setup public void setupHandler() { try { - this.handler = handlerClass.getDeclaredConstructor().newInstance(); - this.handler.createClient(parameters); + this.modelHandler = handlerClass.getDeclaredConstructor().newInstance(); + this.modelHandler.createClient(parameters); } catch (Exception e) { throw new RuntimeException("Failed to instantiate handler: " + handlerClass.getName(), e); @@ -159,8 +161,9 @@ public void setupHandler() { } /** Perform Inference */ @ProcessElement - public void processElement(ProcessContext c) { - Iterable> response = this.handler.request(c.element()); + public void processElement(ProcessContext c) throws Exception { + Iterable> response = retryHandler + .execute(() -> 
modelHandler.request(c.element())); c.output(response); } } diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RetryHandler.java b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RetryHandler.java new file mode 100644 index 000000000000..4371d81dd9d7 --- /dev/null +++ b/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RetryHandler.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.remoteinference; + +import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.sdk.util.FluentBackoff; +import org.apache.beam.sdk.util.Sleeper; +import org.joda.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Serializable; + +/** + * A utility for running request and handle failures and retries. 
+ */ +public class RetryHandler implements Serializable { + + private static final Logger LOG = LoggerFactory.getLogger(RetryHandler.class); + + private final int maxRetries; + private final Duration initialBackoff; + private final Duration maxBackoff; + private final Duration maxCumulativeBackoff; + + private RetryHandler( + int maxRetries, + Duration initialBackoff, + Duration maxBackoff, + Duration maxCumulativeBackoff) { + this.maxRetries = maxRetries; + this.initialBackoff = initialBackoff; + this.maxBackoff = maxBackoff; + this.maxCumulativeBackoff = maxCumulativeBackoff; + } + + public static RetryHandler withDefaults() { + return new RetryHandler( + 3, // maxRetries + Duration.standardSeconds(1), // initialBackoff + Duration.standardSeconds(10), // maxBackoff per retry + Duration.standardMinutes(1) // maxCumulativeBackoff + ); + } + + public T execute(RetryableRequest request) throws Exception { + BackOff backoff = FluentBackoff.DEFAULT + .withMaxRetries(maxRetries) + .withInitialBackoff(initialBackoff) + .withMaxBackoff(maxBackoff) + .withMaxCumulativeBackoff(maxCumulativeBackoff) + .backoff(); + + Sleeper sleeper = Sleeper.DEFAULT; + Exception lastException; + int attempt = 1; + + while (true) { + try { + return request.call(); + + } catch (Exception e) { + lastException = e; + + long backoffMillis = backoff.nextBackOffMillis(); + + if (backoffMillis == BackOff.STOP) { + LOG.error("Request failed after {} retry attempts.", attempt); + throw new RuntimeException( + "Request failed after exhausting retries. " + + "Max retries: " + maxRetries + ", " , + lastException); + } + + LOG.warn("Retry request attempt {} failed with: {}. 
Retrying in {} ms", attempt, e.getMessage(), backoffMillis); + + attempt++; + sleeper.sleep(backoffMillis); + } + } + } + + @FunctionalInterface + public interface RetryableRequest { + + T call() throws Exception; + } +} diff --git a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java index fb24b090cb68..805ff5ad4909 100644 --- a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java +++ b/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java @@ -306,18 +306,20 @@ public void testWithInvalidApiKey() { .instructionPrompt("Test") .build())); - // Expect pipeline to fail with authentication error try { pipeline.run().waitUntilFinish(); - fail("Expected pipeline to fail with invalid API key"); + fail("Expected pipeline failure due to invalid API key"); } catch (Exception e) { - // Expected - verify it's an authentication error - String message = e.getMessage().toLowerCase(); - assertTrue("Exception should mention authentication or API key issue, got: " + message, - message.contains("auth") || - message.contains("api") || - message.contains("key") || - message.contains("401")); + String msg = e.toString().toLowerCase(); + + assertTrue( + "Expected retry exhaustion or API key issue. 
Got: " + msg, + msg.contains("exhaust") || + msg.contains("max retries") || + msg.contains("401") || + msg.contains("api key") || + msg.contains("incorrect api key") + ); } } @@ -363,4 +365,42 @@ public void testWithJsonOutputFormat() { pipeline.run().waitUntilFinish(); } + + @Test + public void testRetryWithInvalidModel() { + + PCollection inputs = + pipeline + .apply("CreateInput", Create.of("Test input")) + .apply("MapToInput", + MapElements.into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + inputs.apply( + "FailingOpenAIRequest", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters( + OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName("fake-model") + .instructionPrompt("test retry") + .build())); + + try { + pipeline.run().waitUntilFinish(); + fail("Pipeline should fail after retry exhaustion."); + } catch (Exception e) { + String message = e.getMessage().toLowerCase(); + + assertTrue( + "Expected retry-exhaustion error. 
Actual: " + message, + message.contains("exhaust") || + message.contains("retry") || + message.contains("max retries") || + message.contains("request failed") || + message.contains("fake-model")); + } + } + } From 5eeff773f05aabd4220159ea74491506eed4608c Mon Sep 17 00:00:00 2001 From: Ganeshsivakumar Date: Fri, 21 Nov 2025 19:09:12 +0530 Subject: [PATCH 11/12] refactor inference module --- sdks/java/ml/inference/build.gradle | 31 ++++++++++++++++ .../java/ml/inference/openai/build.gradle.kts | 36 +++++++++++++++++++ .../inference}/openai/OpenAIModelHandler.java | 15 ++++---- .../inference}/openai/OpenAIModelInput.java | 5 ++- .../openai/OpenAIModelParameters.java | 4 +-- .../openai/OpenAIModelResponse.java | 4 +-- .../openai/OpenAIModelHandlerIT.java | 10 ++---- .../openai/OpenAIModelHandlerTest.java | 8 ++--- .../remote}/build.gradle.kts | 9 +---- .../sdk/ml/inference/remote}/BaseInput.java | 2 +- .../inference/remote}/BaseModelHandler.java | 2 +- .../remote}/BaseModelParameters.java | 2 +- .../ml/inference/remote}/BaseResponse.java | 2 +- .../inference/remote}/PredictionResult.java | 2 +- .../ml/inference/remote}/RemoteInference.java | 4 +-- .../ml/inference/remote}/RetryHandler.java | 6 ++-- .../remote}/RemoteInferenceTest.java | 22 +++++++++--- settings.gradle.kts | 5 ++- 18 files changed, 117 insertions(+), 52 deletions(-) create mode 100644 sdks/java/ml/inference/build.gradle create mode 100644 sdks/java/ml/inference/openai/build.gradle.kts rename sdks/java/ml/{remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference => inference/openai/src/main/java/org/apache/beam/sdk/ml/inference}/openai/OpenAIModelHandler.java (91%) rename sdks/java/ml/{remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference => inference/openai/src/main/java/org/apache/beam/sdk/ml/inference}/openai/OpenAIModelInput.java (93%) rename sdks/java/ml/{remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference => 
inference/openai/src/main/java/org/apache/beam/sdk/ml/inference}/openai/OpenAIModelParameters.java (96%) rename sdks/java/ml/{remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference => inference/openai/src/main/java/org/apache/beam/sdk/ml/inference}/openai/OpenAIModelResponse.java (93%) rename sdks/java/ml/{remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference => inference/openai/src/test/java/org/apache/beam/sdk/ml/inference}/openai/OpenAIModelHandlerIT.java (97%) rename sdks/java/ml/{remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference => inference/openai/src/test/java/org/apache/beam/sdk/ml/inference}/openai/OpenAIModelHandlerTest.java (98%) rename sdks/java/ml/{remoteinference => inference/remote}/build.gradle.kts (81%) rename sdks/java/ml/{remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base => inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote}/BaseInput.java (95%) rename sdks/java/ml/{remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base => inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote}/BaseModelHandler.java (98%) rename sdks/java/ml/{remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base => inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote}/BaseModelParameters.java (96%) rename sdks/java/ml/{remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base => inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote}/BaseResponse.java (95%) rename sdks/java/ml/{remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base => inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote}/PredictionResult.java (96%) rename sdks/java/ml/{remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference => inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote}/RemoteInference.java (98%) rename 
sdks/java/ml/{remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference => inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote}/RetryHandler.java (97%) rename sdks/java/ml/{remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference => inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote}/RemoteInferenceTest.java (97%) diff --git a/sdks/java/ml/inference/build.gradle b/sdks/java/ml/inference/build.gradle new file mode 100644 index 000000000000..e1431d472779 --- /dev/null +++ b/sdks/java/ml/inference/build.gradle @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +plugins { + id("org.apache.beam.module") +} +applyJavaNature( + automaticModuleName: 'org.apache.beam.sdk.ml.inference' +) +provideIntegrationTestingDependencies() +enableJavaPerformanceTesting() + +description = "Apache Beam :: SDKs :: Java :: ML :: Inference" +ext.summary = "Java ML inference module" + +dependencies { +} diff --git a/sdks/java/ml/inference/openai/build.gradle.kts b/sdks/java/ml/inference/openai/build.gradle.kts new file mode 100644 index 000000000000..a05dba16aea4 --- /dev/null +++ b/sdks/java/ml/inference/openai/build.gradle.kts @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +plugins { + id("org.apache.beam.module") + id("java") +} +description = "Apache Beam :: SDKs :: Java :: ML :: Inference :: OpenAI" + +dependencies { + implementation(project(":sdks:java:ml:inference:remote")) + implementation(project(":sdks:java:core")) + + implementation("com.openai:openai-java:4.3.0") + implementation("com.fasterxml.jackson.core:jackson-core:2.20.0") + + testImplementation(project(":runners:direct-java")) + testImplementation("org.slf4j:slf4j-simple:2.0.9") + testImplementation("org.slf4j:slf4j-api:2.0.9") + testImplementation("junit:junit:4.13.2") + testImplementation(project(":sdks:java:testing:test-utils")) +} diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java similarity index 91% rename from sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java rename to sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java index 87616ee693d1..5d37b43bde6c 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java +++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.beam.sdk.ml.remoteinference.openai; +package org.apache.beam.sdk.ml.inference.openai; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; @@ -26,8 +26,8 @@ import com.openai.core.JsonSchemaLocalValidation; import com.openai.models.responses.ResponseCreateParams; import com.openai.models.responses.StructuredResponseCreateParams; -import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler; -import org.apache.beam.sdk.ml.remoteinference.base.PredictionResult; +import org.apache.beam.sdk.ml.inference.remote.BaseModelHandler; +import org.apache.beam.sdk.ml.inference.remote.PredictionResult; import java.util.List; import java.util.stream.Collectors; @@ -61,7 +61,6 @@ public class OpenAIModelHandler implements BaseModelHandler { private transient OpenAIClient client; - private transient StructuredResponseCreateParams clientParams; private OpenAIModelParameters modelParameters; /** @@ -99,7 +98,7 @@ public Iterable> request .writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList()); // Build structured response parameters - this.clientParams = ResponseCreateParams.builder() + StructuredResponseCreateParams clientParams = ResponseCreateParams.builder() .model(modelParameters.getModelName()) .input(inputBatch) .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO) @@ -121,15 +120,13 @@ public Iterable> request throw new RuntimeException("Model returned no structured responses"); } - // Map responses to PredictionResults - List> results = structuredOutput.responses.stream() + // return PredictionResults + return structuredOutput.responses.stream() .map(response -> PredictionResult.create( OpenAIModelInput.create(response.input), OpenAIModelResponse.create(response.output))) .collect(Collectors.toList()); - return results; - } catch (JsonProcessingException e) { throw new RuntimeException("Failed to serialize input batch", e); } diff --git 
a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java similarity index 93% rename from sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java rename to sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java index 1ef59c89da66..2bb33f467598 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelInput.java +++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java @@ -15,10 +15,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.ml.remoteinference.openai; - -import org.apache.beam.sdk.ml.remoteinference.base.BaseInput; +package org.apache.beam.sdk.ml.inference.openai; +import org.apache.beam.sdk.ml.inference.remote.BaseInput; /** * Input for OpenAI model inference requests. 
* diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java similarity index 96% rename from sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java rename to sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java index 3aac4112cba3..2b2b04dfa94b 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java +++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java @@ -15,9 +15,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.ml.remoteinference.openai; +package org.apache.beam.sdk.ml.inference.openai; -import org.apache.beam.sdk.ml.remoteinference.base.BaseModelParameters; +import org.apache.beam.sdk.ml.inference.remote.BaseModelParameters; /** * Configuration parameters required for OpenAI model inference. 
diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java similarity index 93% rename from sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java rename to sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java index abfd9d7cb5c3..7bb3bc075ef4 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelResponse.java +++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java @@ -15,9 +15,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.ml.remoteinference.openai; +package org.apache.beam.sdk.ml.inference.openai; -import org.apache.beam.sdk.ml.remoteinference.base.BaseResponse; +import org.apache.beam.sdk.ml.inference.remote.BaseResponse; /** * Response from OpenAI model inference results. 
diff --git a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java similarity index 97% rename from sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java rename to sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java index 805ff5ad4909..ba03bce86988 100644 --- a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java +++ b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.ml.remoteinference.openai; +package org.apache.beam.sdk.ml.inference.openai; import java.util.ArrayList; import java.util.Arrays; @@ -37,15 +37,11 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.beam.sdk.ml.remoteinference.base.*; -import org.apache.beam.sdk.ml.remoteinference.RemoteInference; -import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.StructuredInputOutput; -import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.Response; +import org.apache.beam.sdk.ml.inference.remote.RemoteInference; +import org.apache.beam.sdk.ml.inference.remote.PredictionResult; public class OpenAIModelHandlerIT { private static final Logger LOG = LoggerFactory.getLogger(OpenAIModelHandlerIT.class); diff --git a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerTest.java 
b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java similarity index 98% rename from sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerTest.java rename to sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java index 3bff5600aa45..0250c559fe65 100644 --- a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerTest.java +++ b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.ml.remoteinference.openai; +package org.apache.beam.sdk.ml.inference.openai; import java.util.Arrays; import java.util.Collections; @@ -31,9 +31,9 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.apache.beam.sdk.ml.remoteinference.base.*; -import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.StructuredInputOutput; -import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.Response; +import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.StructuredInputOutput; +import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.Response; +import org.apache.beam.sdk.ml.inference.remote.PredictionResult; diff --git a/sdks/java/ml/remoteinference/build.gradle.kts b/sdks/java/ml/inference/remote/build.gradle.kts similarity index 81% rename from sdks/java/ml/remoteinference/build.gradle.kts rename to sdks/java/ml/inference/remote/build.gradle.kts index 776be5becb63..607f51819795 100644 --- a/sdks/java/ml/remoteinference/build.gradle.kts +++ b/sdks/java/ml/inference/remote/build.gradle.kts @@ -15,23 +15,20 @@ * See the License for the specific language governing permissions and * limitations 
under the License. */ - plugins { id("org.apache.beam.module") id("java-library") } -description = "Apache Beam :: SDKs :: Java :: ML :: RemoteInference" +description = "Apache Beam :: SDKs :: Java :: ML :: Inference :: Remote" dependencies { // Core Beam SDK implementation(project(":sdks:java:core")) - implementation("com.openai:openai-java:4.3.0") compileOnly("com.google.auto.value:auto-value-annotations:1.11.0") compileOnly("org.checkerframework:checker-qual:3.42.0") annotationProcessor("com.google.auto.value:auto-value:1.11.0") - implementation("com.fasterxml.jackson.core:jackson-core:2.20.0") implementation("org.apache.beam:beam-vendor-guava-32_1_2-jre:0.1") implementation("org.slf4j:slf4j-api:2.0.9") implementation("org.slf4j:slf4j-simple:2.0.9") @@ -39,10 +36,6 @@ dependencies { // testing testImplementation(project(":runners:direct-java")) - testImplementation("org.hamcrest:hamcrest:2.2") - testImplementation("org.mockito:mockito-core:5.8.0") - testImplementation("org.mockito:mockito-junit-jupiter:5.8.0") testImplementation("junit:junit:4.13.2") testImplementation(project(":sdks:java:testing:test-utils")) } - diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java similarity index 95% rename from sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java rename to sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java index 0a84a67fdb4c..a8503d8a89a8 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseInput.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.beam.sdk.ml.remoteinference.base; +package org.apache.beam.sdk.ml.inference.remote; import java.io.Serializable; diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java similarity index 98% rename from sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java rename to sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java index 1128a287d927..314aec34cf9b 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelHandler.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.beam.sdk.ml.remoteinference.base; +package org.apache.beam.sdk.ml.inference.remote; import java.util.List; diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java similarity index 96% rename from sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java rename to sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java index 46ee72c73001..f285377da977 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseModelParameters.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.beam.sdk.ml.remoteinference.base; +package org.apache.beam.sdk.ml.inference.remote; import java.io.Serializable; diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java similarity index 95% rename from sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java rename to sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java index b3c050c45c11..f47858cfebca 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BaseResponse.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.ml.remoteinference.base; +package org.apache.beam.sdk.ml.inference.remote; import java.io.Serializable; diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java similarity index 96% rename from sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java rename to sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java index edc30fd11246..bf1ae66127cf 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/PredictionResult.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.beam.sdk.ml.remoteinference.base; +package org.apache.beam.sdk.ml.inference.remote; import java.io.Serializable; diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java similarity index 98% rename from sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java rename to sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java index 934850575ac7..da9217bfd52e 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java @@ -15,16 +15,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.ml.remoteinference; +package org.apache.beam.sdk.ml.inference.remote; -import org.apache.beam.sdk.ml.remoteinference.base.*; import org.apache.beam.sdk.transforms.*; import org.checkerframework.checker.nullness.qual.Nullable; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import org.apache.beam.sdk.values.PCollection; import com.google.auto.value.AutoValue; -import java.util.ArrayList; import java.util.Collections; import java.util.List; diff --git a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RetryHandler.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java similarity index 97% rename from sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RetryHandler.java rename to sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java index 
4371d81dd9d7..27041d8cb237 100644 --- a/sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RetryHandler.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.ml.remoteinference; +package org.apache.beam.sdk.ml.inference.remote; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.sdk.util.FluentBackoff; @@ -68,7 +68,7 @@ public T execute(RetryableRequest request) throws Exception { Sleeper sleeper = Sleeper.DEFAULT; Exception lastException; - int attempt = 1; + int attempt = 0; while (true) { try { @@ -87,9 +87,9 @@ public T execute(RetryableRequest request) throws Exception { lastException); } + attempt++; LOG.warn("Retry request attempt {} failed with: {}. Retrying in {} ms", attempt, e.getMessage(), backoffMillis); - attempt++; sleeper.sleep(backoffMillis); } } diff --git a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/RemoteInferenceTest.java b/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java similarity index 97% rename from sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/RemoteInferenceTest.java rename to sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java index 3f351b9f88a3..5b66cfab7923 100644 --- a/sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/RemoteInferenceTest.java +++ b/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.beam.sdk.ml.remoteinference; +package org.apache.beam.sdk.ml.inference.remote; import java.util.Arrays; import java.util.Collections; @@ -24,11 +24,11 @@ import java.util.stream.StreamSupport; import org.apache.beam.sdk.coders.SerializableCoder; -import org.apache.beam.sdk.ml.remoteinference.base.*; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.PCollection; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; @@ -263,6 +263,17 @@ public Iterable> request(List } } + private static boolean containsMessage(Throwable e, String message) { + Throwable current = e; + while (current != null) { + if (current.getMessage() != null && current.getMessage().contains(message)) { + return true; + } + current = current.getCause(); + } + return false; + } + @Test public void testInvokeWithSingleElement() { TestInput input = TestInput.create("test-value"); @@ -429,9 +440,10 @@ public void testHandlerRequestFailure() { pipeline.run().waitUntilFinish(); fail("Expected pipeline to fail due to request failure"); } catch (Exception e) { - String message = e.getMessage(); - assertTrue("Exception should mention request failure", - message != null && message.contains("Request failed intentionally")); + + assertTrue( + "Expected 'Request failed intentionally' in exception chain", + containsMessage(e, "Request failed intentionally")); } } diff --git a/settings.gradle.kts b/settings.gradle.kts index fdfc5da6854c..92f340a7bd49 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -384,4 +384,7 @@ findProject(":sdks:java:extensions:sql:iceberg")?.name = "iceberg" include("examples:java:iceberg") findProject(":examples:java:iceberg")?.name = "iceberg" -include("sdks:java:ml:remoteinference") \ No newline at end of file +include("sdks:java:ml:remoteinference") 
+include("sdks:java:ml:inference") +include("sdks:java:ml:inference:remote") +include("sdks:java:ml:inference:openai") \ No newline at end of file From 5358fec73623113cfe329f159b7f1bcf9abab542 Mon Sep 17 00:00:00 2001 From: Ganeshsivakumar Date: Fri, 21 Nov 2025 19:25:26 +0530 Subject: [PATCH 12/12] handle json string --- .../beam/sdk/ml/inference/openai/OpenAIModelHandler.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java index 5d37b43bde6c..4c37df24a471 100644 --- a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java +++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java @@ -62,6 +62,7 @@ public class OpenAIModelHandler private transient OpenAIClient client; private OpenAIModelParameters modelParameters; + private transient ObjectMapper objectMapper; /** * Initializes the OpenAI client with the provided parameters. @@ -77,6 +78,7 @@ public void createClient(OpenAIModelParameters parameters) { this.client = OpenAIOkHttpClient.builder() .apiKey(this.modelParameters.getApiKey()) .build(); + this.objectMapper = new ObjectMapper(); } /** @@ -94,7 +96,7 @@ public Iterable> request try { // Convert input list to JSON string - String inputBatch = new ObjectMapper() + String inputBatch = objectMapper .writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList()); // Build structured response parameters