Skip to content

Commit eb2f04a

Browse files
committed
partialmessages: add explicit MergePartsMetadata function
This allows a peer's local partsMetadata to merge with their remote partsMetadata. As long as the underlying PartsMetadata type supports CRDT-like addition, we can avoid sending the peer duplicates. Take the following example (credit Sukun):

> Say: We have parts 1, 2, 5, 6
>
> 1. Peer sends a message saying "I have parts 1, 2".
> 2. We send parts 5, 6 and update the peerstate in PublishMessage to (1, 2, 5, 6).
> 3. Concurrently with 2, we receive a message from the peer saying "I now have parts (1, 2, 3, 4)".

With this change we can now correctly store the fact that this peer should have parts 1, 2, 3, 4, 5, 6, and do so in a way that does not assume a specific representation of the PartsMetadata.
1 parent 8fe8b9f commit eb2f04a

File tree

4 files changed

+98
-66
lines changed

4 files changed

+98
-66
lines changed

gossipsub_test.go

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525

2626
"github.com/libp2p/go-libp2p-pubsub/internal/gologshim"
2727
"github.com/libp2p/go-libp2p-pubsub/partialmessages"
28+
"github.com/libp2p/go-libp2p-pubsub/partialmessages/bitmap"
2829
pb "github.com/libp2p/go-libp2p-pubsub/pb"
2930
"github.com/libp2p/go-libp2p/core/crypto"
3031
"github.com/libp2p/go-msgio"
@@ -4408,15 +4409,14 @@ func (m *minimalTestPartialMessage) complete() bool {
44084409
}
44094410

44104411
// PartsMetadata implements partialmessages.PartialMessage.
4411-
func (m *minimalTestPartialMessage) PartsMetadata() []byte {
4412-
out := byte(0)
4413-
if len(m.Parts[0]) > 0 {
4414-
out |= 1
4415-
}
4416-
if len(m.Parts[1]) > 0 {
4417-
out |= 2
4412+
func (m *minimalTestPartialMessage) PartsMetadata() partialmessages.PartsMetadata {
4413+
out := make(bitmap.Bitmap, 1)
4414+
for i := range m.Parts {
4415+
if len(m.Parts[i]) > 0 {
4416+
out.Set(i)
4417+
}
44184418
}
4419-
return []byte{out}
4419+
return partialmessages.PartsMetadata(out)
44204420
}
44214421

44224422
func (m *minimalTestPartialMessage) extendFromEncodedPartialMessage(_ peer.ID, data []byte) (extended bool) {
@@ -4471,34 +4471,30 @@ func (m *minimalTestPartialMessage) GroupID() []byte {
44714471
return m.Group
44724472
}
44734473

4474-
func (m *minimalTestPartialMessage) PartialMessageBytes(peerPartsMetadata []byte) ([]byte, []byte, error) {
4474+
func (m *minimalTestPartialMessage) PartialMessageBytes(peerPartsMetadata partialmessages.PartsMetadata) ([]byte, error) {
44754475
if len(peerPartsMetadata) == 0 {
4476-
return nil, nil, errors.New("invalid metadata")
4476+
return nil, errors.New("invalid metadata")
44774477
}
4478-
peerWants := ^peerPartsMetadata[0]
4478+
peerHas := bitmap.Bitmap(peerPartsMetadata)
44794479

44804480
var temp minimalTestPartialMessage
44814481
temp.Group = m.Group
4482-
if peerWants&1 == 1 && m.Parts[0] != nil {
4483-
peerWants ^= 1
4482+
if !peerHas.Get(0) && m.Parts[0] != nil {
44844483
temp.Parts[0] = m.Parts[0]
44854484
}
4486-
if peerWants&2 == 2 && m.Parts[1] != nil {
4487-
peerWants ^= 2
4485+
if !peerHas.Get(1) && m.Parts[1] != nil {
44884486
temp.Parts[1] = m.Parts[1]
44894487
}
44904488

4491-
restMetadata := []byte{^peerWants}
4492-
44934489
if temp.Parts[0] == nil && temp.Parts[1] == nil {
4494-
return nil, restMetadata, nil
4490+
return nil, nil
44954491
}
44964492

44974493
b, err := json.Marshal(temp)
44984494
if err != nil {
4499-
return nil, nil, err
4495+
return nil, err
45004496
}
4501-
return b, restMetadata, nil
4497+
return b, nil
45024498
}
45034499

45044500
func (m *minimalTestPartialMessage) shouldRequest(_ peer.ID, peerHasMetadata []byte) bool {
@@ -4516,6 +4512,10 @@ type minimalTestPartialMessageChecker struct {
45164512
fullMessage *minimalTestPartialMessage
45174513
}
45184514

4515+
func (m *minimalTestPartialMessageChecker) MergePartsMetadata(left, right partialmessages.PartsMetadata) partialmessages.PartsMetadata {
4516+
return partialmessages.MergeBitmap(left, right)
4517+
}
4518+
45194519
// EmptyMessage implements partialmessages.InvariantChecker.
45204520
func (m *minimalTestPartialMessageChecker) EmptyMessage() *minimalTestPartialMessage {
45214521
return &minimalTestPartialMessage{
@@ -4606,6 +4606,9 @@ func TestPartialMessages(t *testing.T) {
46064606
// have some basic fast rules here.
46074607
return nil
46084608
},
4609+
MergePartsMetadata: func(_ string, left, right partialmessages.PartsMetadata) partialmessages.PartsMetadata {
4610+
return partialmessages.MergeBitmap(left, right)
4611+
},
46094612
OnIncomingRPC: func(from peer.ID, rpc *pb.PartialMessagesExtension) error {
46104613
groupID := rpc.GroupID
46114614
pm, ok := partialMessageStore[i][topic+string(groupID)]
@@ -4706,6 +4709,9 @@ func TestSkipPublishingToPeersWithPartialMessageSupport(t *testing.T) {
47064709
ValidateRPC: func(from peer.ID, rpc *pb.PartialMessagesExtension) error {
47074710
return nil
47084711
},
4712+
MergePartsMetadata: func(_ string, left, right partialmessages.PartsMetadata) partialmessages.PartsMetadata {
4713+
return partialmessages.MergeBitmap(left, right)
4714+
},
47094715
OnIncomingRPC: func(from peer.ID, rpc *pb.PartialMessagesExtension) error {
47104716
topicID := rpc.GetTopicID()
47114717
groupID := rpc.GetGroupID()

partialmessages/invariants.go

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@ type InvariantChecker[P Message] interface {
3030
// in (as determined from the parts metadata)
3131
ShouldRequest(a P, from peer.ID, partsMetadata []byte) bool
3232

33+
MergePartsMetadata(left, right PartsMetadata) PartsMetadata
34+
3335
Equal(a, b P) bool
3436
}
3537

3638
func TestPartialMessageInvariants[P Message](t *testing.T, checker InvariantChecker[P]) {
3739
extend := func(a, b P) (P, error) {
3840
emptyParts := checker.EmptyMessage().PartsMetadata()
39-
encodedB, _, err := b.PartialMessageBytes(emptyParts)
41+
encodedB, err := b.PartialMessageBytes(emptyParts)
4042
if err != nil {
4143
var out P
4244
return out, err
@@ -64,7 +66,7 @@ func TestPartialMessageInvariants[P Message](t *testing.T, checker InvariantChec
6466

6567
recombined := checker.EmptyMessage()
6668
for _, part := range parts {
67-
b, _, err := part.PartialMessageBytes(recombined.PartsMetadata())
69+
b, err := part.PartialMessageBytes(recombined.PartsMetadata())
6870
if err != nil {
6971
t.Fatal(err)
7072
}
@@ -86,10 +88,11 @@ func TestPartialMessageInvariants[P Message](t *testing.T, checker InvariantChec
8688
emptyMsgPartsMeta := emptyMessage.PartsMetadata()
8789

8890
// Empty message should not be able to fulfill any request
89-
response, rest, err := emptyMessage.PartialMessageBytes(emptyMsgPartsMeta)
91+
response, err := emptyMessage.PartialMessageBytes(emptyMsgPartsMeta)
9092
if err != nil {
9193
t.Fatal(err)
9294
}
95+
rest := checker.MergePartsMetadata(emptyMsgPartsMeta, emptyMessage.PartsMetadata())
9396

9497
if len(response) != 0 {
9598
t.Error("Empty message should return nil response when requesting parts it doesn't have")
@@ -122,10 +125,11 @@ func TestPartialMessageInvariants[P Message](t *testing.T, checker InvariantChec
122125
emptyMsgPartsMeta := emptyMessage.PartsMetadata()
123126

124127
// Request all parts from the partial message
125-
response1, rest1, err := parts[0].PartialMessageBytes(emptyMsgPartsMeta)
128+
response1, err := parts[0].PartialMessageBytes(emptyMsgPartsMeta)
126129
if err != nil {
127130
t.Fatal(err)
128131
}
132+
rest1 := checker.MergePartsMetadata(emptyMsgPartsMeta, parts[0].PartsMetadata())
129133

130134
// Should get some response since partial message has at least one part
131135
if len(response1) == 0 {
@@ -134,10 +138,10 @@ func TestPartialMessageInvariants[P Message](t *testing.T, checker InvariantChec
134138

135139
// Rest should be non-zero and different from original request since something was fulfilled
136140
if len(rest1) == 0 {
137-
t.Error("Rest should be non-zero when partial fulfillment occurred")
141+
t.Fatal("Rest should be non-zero when partial fulfillment occurred")
138142
}
139143
if bytes.Equal(rest1, emptyMsgPartsMeta) {
140-
t.Error("Rest should be different from original request since partial fulfillment occurred")
144+
t.Fatalf("Rest should be different from original request since partial fulfillment occurred")
141145
}
142146

143147
// Create another partial message with the remaining parts
@@ -150,10 +154,11 @@ func TestPartialMessageInvariants[P Message](t *testing.T, checker InvariantChec
150154
}
151155

152156
// The remaining partial message should be able to fulfill the "rest" request
153-
response2, rest2, err := remainingPartial.PartialMessageBytes(rest1)
157+
response2, err := remainingPartial.PartialMessageBytes(rest1)
154158
if err != nil {
155159
t.Fatal(err)
156160
}
161+
rest2 := checker.MergePartsMetadata(rest1, remainingPartial.PartsMetadata())
157162

158163
// response2 should be non-empty since we have remaining parts to fulfill
159164
if len(response2) == 0 {
@@ -192,7 +197,8 @@ func TestPartialMessageInvariants[P Message](t *testing.T, checker InvariantChec
192197

193198
// Request with empty metadata should return all available parts
194199
emptyMeta := checker.EmptyMessage().PartsMetadata()
195-
response, rest, err := fullMessage.PartialMessageBytes(emptyMeta)
200+
response, err := fullMessage.PartialMessageBytes(emptyMeta)
201+
rest := checker.MergePartsMetadata(emptyMeta, fullMessage.PartsMetadata())
196202
if err != nil {
197203
t.Fatal(err)
198204
}
@@ -244,7 +250,8 @@ func TestPartialMessageInvariants[P Message](t *testing.T, checker InvariantChec
244250
// Get the MissingParts() and have the full message fulfill the request
245251
msgPartsMeta := testMsg.PartsMetadata()
246252

247-
response, rest, err := fullMessage.PartialMessageBytes(msgPartsMeta)
253+
response, err := fullMessage.PartialMessageBytes(msgPartsMeta)
254+
rest := checker.MergePartsMetadata(msgPartsMeta, fullMessage.PartsMetadata())
248255
if err != nil {
249256
t.Fatal(err)
250257
}

partialmessages/partialmsgs.go

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"log/slog"
88
"slices"
99

10+
"github.com/libp2p/go-libp2p-pubsub/partialmessages/bitmap"
1011
pb "github.com/libp2p/go-libp2p-pubsub/pb"
1112
"github.com/libp2p/go-libp2p/core/peer"
1213
)
@@ -21,6 +22,10 @@ import (
2122

2223
const minGroupTTL = 3
2324

25+
// PartsMetadata describes the parts this partial message
26+
// contains and, possibly implicitly, the parts it wants.
27+
type PartsMetadata []byte
28+
2429
// Message is a message that can be broken up into parts. It can be
2530
// complete, partially complete, or empty. It is up to the application to define
2631
// how a message is split into parts and recombined, as well as how missing and
@@ -43,20 +48,18 @@ type Message interface {
4348
//
4449
// If the Partial Message is empty, the implementation MUST return:
4550
// nil, nil.
46-
PartialMessageBytes(partsMetadata []byte) (msg []byte, newPartsMetadata []byte, _ error)
51+
PartialMessageBytes(partsMetadata PartsMetadata) (msg []byte, _ error)
4752

48-
// PartsMetadata returns metadata about the parts this partial message
49-
// contains and, possibly implicitly, the parts it wants.
50-
PartsMetadata() []byte
53+
PartsMetadata() PartsMetadata
5154
}
5255

5356
// peerState is the state we keep per peer. Used to make Publish
5457
// Idempotent.
5558
type peerState struct {
5659
// The parts metadata the peer has sent us
57-
partsMetadata []byte
58-
// The parts metadata this endpoint has sent to the peer
59-
sentPartsMetadata []byte
60+
partsMetadata PartsMetadata
61+
// The parts metadata this node has sent to the peer
62+
sentPartsMetadata PartsMetadata
6063
}
6164

6265
func (ps *peerState) IsZero() bool {
@@ -84,9 +87,17 @@ func (s *partialMessageStatePerTopicGroup) clearPeerWants(peerID peer.ID) {
8487
}
8588
}
8689

90+
// MergeBitmap is a helper function for merging parts metadata if they are a
91+
// bitmap.
92+
func MergeBitmap(left, right PartsMetadata) PartsMetadata {
93+
return PartsMetadata(bitmap.Merge(bitmap.Bitmap(left), bitmap.Bitmap(right)))
94+
}
95+
8796
type PartialMessageExtension struct {
8897
Logger *slog.Logger
8998

99+
MergePartsMetadata func(topic string, left, right PartsMetadata) PartsMetadata
100+
90101
// OnIncomingRPC is called whenever we receive an encoded
91102
// partial message from a peer. This func MUST be fast and non-blocking.
92103
// If you need to do slow work, dispatch the work to your own goroutine.
@@ -147,6 +158,9 @@ func (e *PartialMessageExtension) Init(router Router) error {
147158
if e.OnIncomingRPC == nil {
148159
return errors.New("field OnIncomingRPC must be set")
149160
}
161+
if e.MergePartsMetadata == nil {
162+
return errors.New("field MergePartsMetadata must be set")
163+
}
150164
e.statePerTopicPerGroup = make(map[string]map[string]*partialMessageStatePerTopicGroup)
151165

152166
return nil
@@ -186,14 +200,14 @@ func (e *PartialMessageExtension) PublishPartial(topic string, partial Message,
186200
if pState.partsMetadata != nil {
187201
// This peer has previously asked for a certain part. We'll give
188202
// them what we can.
189-
pm, rest, err := partial.PartialMessageBytes(pState.partsMetadata)
203+
pm, err := partial.PartialMessageBytes(pState.partsMetadata)
190204
if err != nil {
191205
log.Warn("partial message extension failed to get partial message bytes", "error", err)
192206
// Possibly a bad request, we'll delete the request as we will likely error next time we try to handle it
193207
state.clearPeerWants(p)
194208
continue
195209
}
196-
pState.partsMetadata = rest
210+
pState.partsMetadata = e.MergePartsMetadata(topic, pState.partsMetadata, myPartsMeta)
197211
if len(pm) > 0 {
198212
log.Debug("Respond to peer's IWant")
199213
sendRPC = true
@@ -283,7 +297,7 @@ func (e *PartialMessageExtension) HandleRPC(from peer.ID, rpc *pb.PartialMessage
283297
pState = &peerState{}
284298
state.peerState[from] = pState
285299
}
286-
pState.partsMetadata = rpc.PartsMetadata
300+
pState.partsMetadata = e.MergePartsMetadata(rpc.GetTopicID(), pState.partsMetadata, rpc.PartsMetadata)
287301
}
288302

289303
return e.OnIncomingRPC(from, rpc)

0 commit comments

Comments
 (0)