Skip to content

Commit 14087f8

Browse files
authored
core: Report marshaller error for uncompressed size too large back to the client 2 (#12477)
Code mostly as in #12360 but trying to use `handleInternalError()` / `cancel()` as suggested by @ejona86 Fixes #11246
1 parent 97953ca commit 14087f8

File tree

6 files changed

+129
-7
lines changed

6 files changed

+129
-7
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright 2025 The gRPC Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.grpc.internal;
18+
19+
import io.grpc.Status;
20+
21+
/**
22+
* Marker to be used for Status sent to {@link ServerStream#cancel(Status)} to signal that stream
23+
* should be closed by sending headers.
24+
*/
25+
public class CloseWithHeadersMarker extends Throwable {
26+
private static final long serialVersionUID = 0L;
27+
28+
@Override
29+
public synchronized Throwable fillInStackTrace() {
30+
return this;
31+
}
32+
}

core/src/main/java/io/grpc/internal/ServerCallImpl.java

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,17 @@ private void handleInternalError(Throwable internalError) {
279279
serverCallTracer.reportCallEnded(false); // error so always false
280280
}
281281

282+
/**
283+
* Close the {@link ServerStream} because parsing request message failed.
284+
* Similar to {@link #handleInternalError(Throwable)}.
285+
*/
286+
private void handleParseError(StatusRuntimeException parseError) {
287+
cancelled = true;
288+
log.log(Level.WARNING, "Cancelling the stream because of parse error", parseError);
289+
stream.cancel(parseError.getStatus().withCause(new CloseWithHeadersMarker()));
290+
serverCallTracer.reportCallEnded(false); // error so always false
291+
}
292+
282293
/**
283294
* All of these callbacks are assumed to called on an application thread, and the caller is
284295
* responsible for handling thrown exceptions.
@@ -327,18 +338,23 @@ private void messagesAvailableInternal(final MessageProducer producer) {
327338
return;
328339
}
329340

330-
InputStream message;
341+
InputStream message = null;
331342
try {
332343
while ((message = producer.next()) != null) {
344+
ReqT parsed;
333345
try {
334-
listener.onMessage(call.method.parseRequest(message));
335-
} catch (Throwable t) {
346+
parsed = call.method.parseRequest(message);
347+
} catch (StatusRuntimeException e) {
336348
GrpcUtil.closeQuietly(message);
337-
throw t;
349+
GrpcUtil.closeQuietly(producer);
350+
call.handleParseError(e);
351+
return;
338352
}
339353
message.close();
354+
listener.onMessage(parsed);
340355
}
341356
} catch (Throwable t) {
357+
GrpcUtil.closeQuietly(message);
342358
GrpcUtil.closeQuietly(producer);
343359
Throwables.throwIfUnchecked(t);
344360
throw new RuntimeException(t);

core/src/test/java/io/grpc/internal/ServerCallImplTest.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,11 @@
4848
import io.grpc.SecurityLevel;
4949
import io.grpc.ServerCall;
5050
import io.grpc.Status;
51+
import io.grpc.StatusRuntimeException;
5152
import io.grpc.internal.ServerCallImpl.ServerStreamListenerImpl;
5253
import io.perfmark.PerfMark;
5354
import java.io.ByteArrayInputStream;
55+
import java.io.IOException;
5456
import java.io.InputStream;
5557
import java.io.InputStreamReader;
5658
import org.junit.Before;
@@ -69,6 +71,8 @@ public class ServerCallImplTest {
6971

7072
@Mock private ServerStream stream;
7173
@Mock private ServerCall.Listener<Long> callListener;
74+
@Mock private StreamListener.MessageProducer messageProducer;
75+
@Mock private InputStream message;
7276

7377
private final CallTracer serverCallTracer = CallTracer.getDefaultFactory().create();
7478
private ServerCallImpl<Long, Long> call;
@@ -493,6 +497,44 @@ public void streamListener_unexpectedRuntimeException() {
493497
assertThat(e).hasMessageThat().isEqualTo("unexpected exception");
494498
}
495499

500+
@Test
501+
public void streamListener_statusRuntimeException() throws IOException {
502+
MethodDescriptor<Long, Long> failingParseMethod = MethodDescriptor.<Long, Long>newBuilder()
503+
.setType(MethodType.UNARY)
504+
.setFullMethodName("service/method")
505+
.setRequestMarshaller(new LongMarshaller() {
506+
@Override
507+
public Long parse(InputStream stream) {
508+
throw new StatusRuntimeException(Status.RESOURCE_EXHAUSTED
509+
.withDescription("Decompressed gRPC message exceeds maximum size"));
510+
}
511+
})
512+
.setResponseMarshaller(new LongMarshaller())
513+
.build();
514+
515+
call = new ServerCallImpl<>(stream, failingParseMethod, requestHeaders, context,
516+
DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance(),
517+
serverCallTracer, PerfMark.createTag());
518+
519+
ServerStreamListenerImpl<Long> streamListener =
520+
new ServerCallImpl.ServerStreamListenerImpl<>(call, callListener, context);
521+
522+
when(messageProducer.next()).thenReturn(message, (InputStream) null);
523+
streamListener.messagesAvailable(messageProducer);
524+
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
525+
verify(stream).cancel(statusCaptor.capture());
526+
Status status = statusCaptor.getValue();
527+
assertEquals(Status.Code.RESOURCE_EXHAUSTED, status.getCode());
528+
assertEquals("Decompressed gRPC message exceeds maximum size", status.getDescription());
529+
530+
streamListener.halfClosed();
531+
verify(callListener, never()).onHalfClose();
532+
533+
when(messageProducer.next()).thenReturn(message, (InputStream) null);
534+
streamListener.messagesAvailable(messageProducer);
535+
verify(callListener, never()).onMessage(any());
536+
}
537+
496538
private static class LongMarshaller implements Marshaller<Long> {
497539
@Override
498540
public InputStream stream(Long value) {

interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2024,7 +2024,7 @@ private void assertPayload(Payload expected, Payload actual) {
20242024
}
20252025
}
20262026

2027-
private static void assertCodeEquals(Status.Code expected, Status actual) {
2027+
protected static void assertCodeEquals(Status.Code expected, Status actual) {
20282028
assertWithMessage("Unexpected status: %s", actual).that(actual.getCode()).isEqualTo(expected);
20292029
}
20302030

interop-testing/src/test/java/io/grpc/testing/integration/TransportCompressionTest.java

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package io.grpc.testing.integration;
1818

1919
import static org.junit.Assert.assertEquals;
20+
import static org.junit.Assert.assertThrows;
2021
import static org.junit.Assert.assertTrue;
2122

2223
import com.google.protobuf.ByteString;
@@ -37,6 +38,8 @@
3738
import io.grpc.ServerCall.Listener;
3839
import io.grpc.ServerCallHandler;
3940
import io.grpc.ServerInterceptor;
41+
import io.grpc.Status.Code;
42+
import io.grpc.StatusRuntimeException;
4043
import io.grpc.internal.GrpcUtil;
4144
import io.grpc.netty.InternalNettyChannelBuilder;
4245
import io.grpc.netty.InternalNettyServerBuilder;
@@ -53,7 +56,9 @@
5356
import java.io.OutputStream;
5457
import org.junit.Before;
5558
import org.junit.BeforeClass;
59+
import org.junit.Rule;
5660
import org.junit.Test;
61+
import org.junit.rules.TestName;
5762
import org.junit.runner.RunWith;
5863
import org.junit.runners.JUnit4;
5964

@@ -84,10 +89,16 @@ public static void registerCompressors() {
8489
compressors.register(Codec.Identity.NONE);
8590
}
8691

92+
@Rule
93+
public final TestName currentTest = new TestName();
94+
8795
@Override
8896
protected ServerBuilder<?> getServerBuilder() {
8997
NettyServerBuilder builder = NettyServerBuilder.forPort(0, InsecureServerCredentials.create())
90-
.maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
98+
.maxInboundMessageSize(
99+
DECOMPRESSED_MESSAGE_TOO_LONG_METHOD_NAME.equals(currentTest.getMethodName())
100+
? 1000
101+
: AbstractInteropTest.MAX_MESSAGE_SIZE)
91102
.compressorRegistry(compressors)
92103
.decompressorRegistry(decompressors)
93104
.intercept(new ServerInterceptor() {
@@ -126,6 +137,22 @@ public void compresses() {
126137
assertTrue(FZIPPER.anyWritten);
127138
}
128139

140+
private static final String DECOMPRESSED_MESSAGE_TOO_LONG_METHOD_NAME =
141+
"decompressedMessageTooLong";
142+
143+
@Test
144+
public void decompressedMessageTooLong() {
145+
assertEquals(DECOMPRESSED_MESSAGE_TOO_LONG_METHOD_NAME, currentTest.getMethodName());
146+
final SimpleRequest bigRequest = SimpleRequest.newBuilder()
147+
.setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[10_000])))
148+
.build();
149+
StatusRuntimeException e = assertThrows(StatusRuntimeException.class,
150+
() -> blockingStub.withCompression("gzip").unaryCall(bigRequest));
151+
assertCodeEquals(Code.RESOURCE_EXHAUSTED, e.getStatus());
152+
assertEquals("Decompressed gRPC message exceeds maximum size 1000",
153+
e.getStatus().getDescription());
154+
}
155+
129156
@Override
130157
protected NettyChannelBuilder createChannelBuilder() {
131158
NettyChannelBuilder builder = NettyChannelBuilder.forAddress(getListenAddress())

netty/src/main/java/io/grpc/netty/NettyServerStream.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import io.grpc.Metadata;
2424
import io.grpc.Status;
2525
import io.grpc.internal.AbstractServerStream;
26+
import io.grpc.internal.CloseWithHeadersMarker;
2627
import io.grpc.internal.StatsTraceContext;
2728
import io.grpc.internal.TransportTracer;
2829
import io.grpc.internal.WritableBuffer;
@@ -130,7 +131,11 @@ public void writeTrailers(Metadata trailers, boolean headersSent, Status status)
130131
@Override
131132
public void cancel(Status status) {
132133
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.cancel")) {
133-
writeQueue.enqueue(CancelServerStreamCommand.withReset(transportState(), status), true);
134+
CancelServerStreamCommand cmd =
135+
status.getCause() instanceof CloseWithHeadersMarker
136+
? CancelServerStreamCommand.withReason(transportState(), status)
137+
: CancelServerStreamCommand.withReset(transportState(), status);
138+
writeQueue.enqueue(cmd, true);
134139
}
135140
}
136141
}

0 commit comments

Comments
 (0)