Skip to content

Commit 1eea0ae

Browse files
authored
fix(rpcinfo): fix type conversion panic when chaining call FreezeRPCInfo twice (#1888)
1 parent 0ceb2a8 commit 1eea0ae

File tree

2 files changed

+48
-19
lines changed

2 files changed

+48
-19
lines changed

pkg/rpcinfo/copy.go

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package rpcinfo
1818

19+
import "github.com/cloudwego/kitex/pkg/kerrors"
20+
1921
type plainRPCInfo struct {
2022
// use anonymous structs to force objects to be read-only
2123
from struct{ EndpointInfo }
@@ -82,17 +84,34 @@ func copyInvocation(i Invocation) Invocation {
8284
if i == nil {
8385
return nil
8486
}
85-
ink := i.(*invocation)
86-
nink := &invocation{
87-
packageName: ink.PackageName(),
88-
serviceName: ink.ServiceName(),
89-
methodName: ink.MethodName(),
90-
methodInfo: ink.MethodInfo(),
91-
streamingMode: ink.StreamingMode(),
92-
seqID: ink.SeqID(),
93-
// ignore extra info to users
87+
var nink *invocation
88+
var bizErr kerrors.BizStatusErrorIface
89+
// fast-path, calling function of struct directly
90+
if ink, ok := i.(*invocation); ok {
91+
nink = &invocation{
92+
packageName: ink.PackageName(),
93+
serviceName: ink.ServiceName(),
94+
methodName: ink.MethodName(),
95+
methodInfo: ink.MethodInfo(),
96+
streamingMode: ink.StreamingMode(),
97+
seqID: ink.SeqID(),
98+
// ignore extra info to users
99+
}
100+
bizErr = ink.BizStatusErr()
101+
} else {
102+
nink = &invocation{
103+
packageName: i.PackageName(),
104+
serviceName: i.ServiceName(),
105+
methodName: i.MethodName(),
106+
methodInfo: i.MethodInfo(),
107+
streamingMode: i.StreamingMode(),
108+
seqID: i.SeqID(),
109+
// ignore extra info to users
110+
}
111+
bizErr = i.BizStatusErr()
94112
}
95-
nink.SetBizStatusErr(ink.BizStatusErr())
113+
114+
nink.SetBizStatusErr(bizErr)
96115
return nink
97116
}
98117

pkg/rpcinfo/copy_test.go

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,23 @@ func TestFreezeRPCInfo(t *testing.T) {
4343
ctx2 := FreezeRPCInfo(ctx)
4444
test.Assert(t, ctx2 != ctx)
4545
ri2 := GetRPCInfo(ctx2)
46-
test.Assert(t, ri2 != nil)
47-
test.Assert(t, ri2 != ri)
48-
49-
test.Assert(t, AsTaggable(ri2.From()) == nil)
50-
test.Assert(t, AsTaggable(ri2.To()) == nil)
51-
test.Assert(t, AsMutableEndpointInfo(ri2.From()) == nil)
52-
test.Assert(t, AsMutableEndpointInfo(ri2.To()) == nil)
53-
test.Assert(t, AsMutableRPCConfig(ri2.Config()) == nil)
54-
test.Assert(t, ri2.Stats() == nil)
46+
checkFreezeRPCInfo(t, ri2, ri)
47+
48+
// call FreezeRPCInfo continuously should not cause a panic
49+
ctx3 := FreezeRPCInfo(ctx2)
50+
test.Assert(t, ctx3 != ctx2)
51+
ri3 := GetRPCInfo(ctx3)
52+
checkFreezeRPCInfo(t, ri3, ri2)
53+
}
54+
55+
func checkFreezeRPCInfo(t *testing.T, ri, prevRI RPCInfo) {
56+
test.Assert(t, ri != nil)
57+
test.Assert(t, ri != prevRI)
58+
59+
test.Assert(t, AsTaggable(ri.From()) == nil)
60+
test.Assert(t, AsTaggable(ri.To()) == nil)
61+
test.Assert(t, AsMutableEndpointInfo(ri.From()) == nil)
62+
test.Assert(t, AsMutableEndpointInfo(ri.To()) == nil)
63+
test.Assert(t, AsMutableRPCConfig(ri.Config()) == nil)
64+
test.Assert(t, ri.Stats() == nil)
5565
}

0 commit comments

Comments
 (0)