Skip to content

Commit cdef621

Browse files
authored
fix: do not warn if requestid middleware errors due to ErrIllegalHeaderWrite (#2654)
1 parent e93dedf commit cdef621

File tree

5 files changed

+80
-64
lines changed

5 files changed

+80
-64
lines changed

internal/middleware/usagemetrics/usagemetrics.go

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package usagemetrics
33
import (
44
"context"
55
"strconv"
6+
"strings"
67
"time"
78

89
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
@@ -38,7 +39,7 @@ type reporter struct{}
3839

3940
func (r *reporter) ServerReporter(ctx context.Context, callMeta interceptors.CallMeta) (interceptors.Reporter, context.Context) {
4041
_, methodName := grpcutil.SplitMethodName(callMeta.FullMethod())
41-
ctx = ContextWithHandle(ctx)
42+
ctx = contextWithHandle(ctx)
4243
return &serverReporter{ctx: ctx, methodName: methodName}, ctx
4344
}
4445

@@ -48,19 +49,25 @@ type serverReporter struct {
4849
methodName string
4950
}
5051

52+
// PostCall is invoked after all PostMsgSend operations.
5153
func (r *serverReporter) PostCall(_ error, _ time.Duration) {
5254
responseMeta := FromContext(r.ctx)
5355
if responseMeta == nil {
5456
responseMeta = &dispatch.ResponseMeta{}
5557
}
5658

57-
err := annotateAndReportForMetadata(r.ctx, r.methodName, responseMeta)
58-
// if context is cancelled, the stream will be closed, and gRPC will return ErrIllegalHeaderWrite
59-
// this prevents logging unnecessary error messages
60-
if r.ctx.Err() != nil {
61-
return
62-
}
59+
DispatchedCountHistogram.WithLabelValues(r.methodName, "false").Observe(float64(responseMeta.DispatchCount))
60+
DispatchedCountHistogram.WithLabelValues(r.methodName, "true").Observe(float64(responseMeta.CachedDispatchCount))
61+
err := responsemeta.SetResponseTrailerMetadata(r.ctx, map[responsemeta.ResponseMetadataTrailerKey]string{
62+
responsemeta.DispatchedOperationsCount: strconv.Itoa(int(responseMeta.DispatchCount)),
63+
responsemeta.CachedOperationsCount: strconv.Itoa(int(responseMeta.CachedDispatchCount)),
64+
})
6365
if err != nil {
66+
// if context is cancelled, the stream will be closed, and gRPC will return ErrIllegalHeaderWrite (which is private)
67+
// this prevents logging unnecessary error messages
68+
if strings.Contains(err.Error(), "SendHeader called multiple times") {
69+
return
70+
}
6471
log.Ctx(r.ctx).Warn().Err(err).Msg("usagemetrics: could not report metadata")
6572
}
6673
}
@@ -79,16 +86,6 @@ func StreamServerInterceptor() grpc.StreamServerInterceptor {
7986
return interceptors.StreamServerInterceptor(&reporter{})
8087
}
8188

82-
func annotateAndReportForMetadata(ctx context.Context, methodName string, metadata *dispatch.ResponseMeta) error {
83-
DispatchedCountHistogram.WithLabelValues(methodName, "false").Observe(float64(metadata.DispatchCount))
84-
DispatchedCountHistogram.WithLabelValues(methodName, "true").Observe(float64(metadata.CachedDispatchCount))
85-
86-
return responsemeta.SetResponseTrailerMetadata(ctx, map[responsemeta.ResponseMetadataTrailerKey]string{
87-
responsemeta.DispatchedOperationsCount: strconv.Itoa(int(metadata.DispatchCount)),
88-
responsemeta.CachedOperationsCount: strconv.Itoa(int(metadata.CachedDispatchCount)),
89-
})
90-
}
91-
9289
// Create a new type to prevent context collisions
9390
type responseMetaKey string
9491

@@ -119,11 +116,11 @@ func FromContext(ctx context.Context) *dispatch.ResponseMeta {
119116
return possibleHandle.(*metaHandle).metadata
120117
}
121118

122-
// ContextWithHandle creates a new context with a location to store metadata
119+
// contextWithHandle creates a new context with a location to store metadata
123120
// returned from a dispatched request.
124121
//
125122
// This should only be called in middleware or testing functions.
126-
func ContextWithHandle(ctx context.Context) context.Context {
123+
func contextWithHandle(ctx context.Context) context.Context {
127124
var handle metaHandle
128125
return context.WithValue(ctx, metadataCtxKey, &handle)
129126
}

internal/middleware/usagemetrics/usagemetrics_test.go

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package usagemetrics
33
import (
44
"context"
55
"errors"
6-
"fmt"
76
"io"
87
"testing"
98

@@ -21,53 +20,28 @@ type testServer struct {
2120
testpb.UnimplementedTestServiceServer
2221
}
2322

24-
func (t testServer) PingEmpty(ctx context.Context, _ *testpb.PingEmptyRequest) (*testpb.PingEmptyResponse, error) {
25-
SetInContext(ctx, &dispatch.ResponseMeta{
26-
DispatchCount: 1,
27-
CachedDispatchCount: 1,
28-
})
29-
return &testpb.PingEmptyResponse{}, nil
30-
}
31-
32-
func (t testServer) Ping(ctx context.Context, _ *testpb.PingRequest) (*testpb.PingResponse, error) {
23+
func (t *testServer) Ping(ctx context.Context, _ *testpb.PingRequest) (*testpb.PingResponse, error) {
3324
SetInContext(ctx, &dispatch.ResponseMeta{
3425
DispatchCount: 1,
3526
CachedDispatchCount: 1,
3627
})
3728
return &testpb.PingResponse{Value: ""}, nil
3829
}
3930

40-
func (t testServer) PingError(ctx context.Context, _ *testpb.PingErrorRequest) (*testpb.PingErrorResponse, error) {
41-
SetInContext(ctx, &dispatch.ResponseMeta{
42-
DispatchCount: 1,
43-
CachedDispatchCount: 1,
44-
})
45-
return nil, fmt.Errorf("err")
31+
// PingError returns the context error
32+
func (t *testServer) PingError(ctx context.Context, _ *testpb.PingErrorRequest) (*testpb.PingErrorResponse, error) {
33+
<-ctx.Done()
34+
return nil, ctx.Err()
4635
}
4736

48-
func (t testServer) PingList(_ *testpb.PingListRequest, server testpb.TestService_PingListServer) error {
37+
func (t *testServer) PingList(_ *testpb.PingListRequest, server testpb.TestService_PingListServer) error {
4938
SetInContext(server.Context(), &dispatch.ResponseMeta{
5039
DispatchCount: 1,
5140
CachedDispatchCount: 1,
5241
})
5342
return nil
5443
}
5544

56-
func (t testServer) PingStream(stream testpb.TestService_PingStreamServer) error {
57-
count := int32(0)
58-
for {
59-
_, err := stream.Recv()
60-
if errors.Is(err, io.EOF) {
61-
break
62-
} else if err != nil {
63-
return err
64-
}
65-
_ = stream.Send(&testpb.PingStreamResponse{Value: "", Counter: count})
66-
count++
67-
}
68-
return nil
69-
}
70-
7145
type metricsMiddlewareTestSuite struct {
7246
*testpb.InterceptorTestSuite
7347
}
@@ -131,3 +105,29 @@ func (s *metricsMiddlewareTestSuite) TestTrailers_Stream() {
131105
s.Require().NoError(err)
132106
s.Require().Equal(1, cachedCount)
133107
}
108+
109+
func (s *metricsMiddlewareTestSuite) TestErrCtx() {
110+
var trailerMD metadata.MD
111+
112+
// SimpleCtx times out after two seconds
113+
_, err := s.Client.PingError(s.SimpleCtx(), &testpb.PingErrorRequest{}, grpc.Trailer(&trailerMD))
114+
s.Require().ErrorContains(err, context.DeadlineExceeded.Error())
115+
116+
// TODO ideally, this test would assert that no error log has been written
117+
// but right now we have no way of capturing the logs
118+
119+
// No metadata should have been sent
120+
dispatchCount, err := responsemeta.GetIntResponseTrailerMetadata(
121+
trailerMD,
122+
responsemeta.DispatchedOperationsCount,
123+
)
124+
s.Require().ErrorContains(err, "key `io.spicedb.respmeta.dispatchedoperationscount` not found in trailer")
125+
s.Require().Equal(0, dispatchCount)
126+
127+
cachedCount, err := responsemeta.GetIntResponseTrailerMetadata(
128+
trailerMD,
129+
responsemeta.CachedOperationsCount,
130+
)
131+
s.Require().ErrorContains(err, "key `io.spicedb.respmeta.cachedoperationscount` not found in trailer")
132+
s.Require().Equal(0, cachedCount)
133+
}

pkg/middleware/requestid/requestid.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package requestid
22

33
import (
44
"context"
5+
"strings"
56

67
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
78
"github.com/rs/xid"
@@ -54,6 +55,7 @@ func (r *handleRequestID) ClientReporter(ctx context.Context, meta interceptors.
5455
return interceptors.NoopReporter{}, ctx
5556
}
5657

58+
// ServerReporter is invoked before the request begins processing.
5759
func (r *handleRequestID) ServerReporter(ctx context.Context, _ interceptors.CallMeta) (interceptors.Reporter, context.Context) {
5860
haveRequestID, requestID, ctx := r.fromContextOrGenerate(ctx)
5961

@@ -67,7 +69,11 @@ func (r *handleRequestID) ServerReporter(ctx context.Context, _ interceptors.Cal
6769
responsemeta.ResponseMetadataTrailerKey(responsemeta.RequestID): requestID,
6870
})
6971
if err != nil {
70-
log.Ctx(ctx).Warn().Err(err).Msg("requestid: could not report metadata")
72+
// if context is cancelled, the stream will be closed, and gRPC will return ErrIllegalHeaderWrite (which is private)
73+
// this prevents logging unnecessary error messages
74+
if !strings.Contains(err.Error(), "SendHeader called multiple times") {
75+
log.Ctx(ctx).Warn().Err(err).Msg("requestid: could not report metadata")
76+
}
7177
}
7278
}
7379

pkg/middleware/requestid/requestid_test.go

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"io"
7+
"sync"
78
"testing"
89
"time"
910

@@ -358,17 +359,29 @@ func TestRequestIDMiddlewareTimeout(t *testing.T) {
358359
suite.Run(t, s)
359360
}
360361

362+
// If a stream receives a context that is already cancelled,
363+
// the middleware will not be invoked at all. Therefore, we try to induce
364+
// what happens when the context errors *while the middleware is running*.
365+
// However, right now, this test is only useful when ran by hand and inspecting logs manually,
366+
// because we don't have a way of asserting what logs were emmitted by the middleware.
361367
func (s *requestIDTimeoutTestSuite) TestContextTimeout() {
362-
// Create a context that's already cancelled
363-
ctx, cancel := context.WithTimeout(s.SimpleCtx(), time.Nanosecond)
364-
defer cancel()
365-
366-
// Wait for context to be cancelled
367-
time.Sleep(time.Millisecond)
368-
369-
var trailer metadata.MD
370-
_, err := s.Client.PingEmpty(ctx, &testpb.PingEmptyRequest{}, grpc.Trailer(&trailer))
368+
var wg sync.WaitGroup
369+
for i := 0; i < 10_000; i++ {
370+
wg.Add(1)
371+
wg.Go(func() {
372+
defer wg.Done()
373+
374+
// context will be cancelled in the middle of the middleware (if we are lucky)
375+
ctx, cancel := context.WithTimeout(s.T().Context(), 1*time.Millisecond)
376+
defer cancel()
377+
378+
var trailer metadata.MD
379+
_, err := s.Client.PingEmpty(ctx, &testpb.PingEmptyRequest{}, grpc.Trailer(&trailer))
380+
381+
// The RPC should fail due to context cancellation, but middleware should handle it gracefully
382+
require.Error(s.T(), err)
383+
})
384+
}
371385

372-
// The RPC should fail due to context cancellation, but middleware should handle it gracefully
373-
require.Error(s.T(), err)
386+
wg.Wait()
374387
}

pkg/middleware/serverversion/serverversion.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func (r *HandleServerVersion) ServerReporter(ctx context.Context, _ interceptors
3737
err = responsemeta.SetResponseTrailerMetadata(ctx, map[responsemeta.ResponseMetadataTrailerKey]string{
3838
responsemeta.ResponseMetadataTrailerKey(responsemeta.ServerVersion): version,
3939
})
40-
// if context is cancelled, the stream will be closed, and gRPC will return ErrIllegalHeaderWrite
40+
// if context is cancelled, the stream will be closed, and gRPC will return ErrIllegalHeaderWrite (which is private)
4141
// this prevents logging unnecessary error messages
4242
if err := ctx.Err(); err != nil {
4343
return interceptors.NoopReporter{}, ctx

0 commit comments

Comments
 (0)