diff --git a/comm.go b/comm.go index d38cce08..fd7cc8f6 100644 --- a/comm.go +++ b/comm.go @@ -158,7 +158,12 @@ func (p *PubSub) handlePeerDead(s network.Stream) { func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, outgoing *rpcQueue) { writeRpc := func(rpc *RPC) error { + rpc = rpc.filterUnwanted(s.Conn().RemotePeer()) size := uint64(rpc.Size()) + if size == 0 { + // Nothing to do, the peer cancelled all our messages + return nil + } buf := pool.Get(varint.UvarintSize(size) + int(size)) defer pool.Put(buf) @@ -198,6 +203,14 @@ func rpcWithSubs(subs ...*pb.RPC_SubOpts) *RPC { } } +func rpcWithMessagesAndChecksums(msgs []*pb.Message, checksums []checksum, unwanted *unwantedState) *RPC { + return &RPC{ + RPC: pb.RPC{Publish: msgs}, + messageChecksums: checksums, + unwanted: unwanted, + } +} + func rpcWithMessages(msgs ...*pb.Message) *RPC { return &RPC{RPC: pb.RPC{Publish: msgs}} } @@ -222,6 +235,7 @@ func rpcWithControl(msgs []*pb.Message, } } +// copyRPC shallow copies a RPC message. func copyRPC(rpc *RPC) *RPC { res := new(RPC) *res = *rpc diff --git a/gossipsub.go b/gossipsub.go index ecd4edaa..f6394585 100644 --- a/gossipsub.go +++ b/gossipsub.go @@ -7,6 +7,7 @@ import ( "io" "math/rand" "sort" + "sync" "time" pb "github.com/libp2p/go-libp2p-pubsub/pb" @@ -265,7 +266,7 @@ func DefaultGossipSubRouter(h host.Host) *GossipSubRouter { backoff: make(map[string]map[peer.ID]time.Time), peerhave: make(map[peer.ID]int), peerdontwant: make(map[peer.ID]int), - unwanted: make(map[peer.ID]map[checksum]int), + unwanted: newUnwantedState(), iasked: make(map[peer.ID]int), outbound: make(map[peer.ID]bool), connect: make(chan connectInfo, params.MaxPendingConnections), @@ -471,7 +472,7 @@ type GossipSubRouter struct { control map[peer.ID]*pb.ControlMessage // pending control messages peerhave map[peer.ID]int // number of IHAVEs received from peer in the last heartbeat peerdontwant map[peer.ID]int // number of IDONTWANTs received from peer in the last heartbeat - unwanted map[peer.ID]map[checksum]int // TTL of the message ids peers don't want + unwanted *unwantedState // TTL of the message ids peers don't want iasked map[peer.ID]int // number of messages we have asked from peer in the last heartbeat outbound map[peer.ID]bool // connection direction cache, marks peers with outbound connections backoff map[string]map[peer.ID]time.Time // prune backoff @@ -522,6 +523,48 @@ type GossipSubRouter struct { heartbeatTicks uint64 } +type unwantedState struct { + sync.RWMutex + m map[peer.ID]map[checksum]int // TTL of the message ids peers don't want +} + +func newUnwantedState() *unwantedState { + return &unwantedState{ + m: make(map[peer.ID]map[checksum]int), + } +} + +func (u *unwantedState) add(peer peer.ID, id checksum, ttl int) { + u.Lock() + defer u.Unlock() + if u.m[peer] == nil { + u.m[peer] = make(map[checksum]int) + } + u.m[peer][id] = ttl +} + +func (u *unwantedState) gc() { + u.Lock() + defer u.Unlock() + + // decrement TTLs of all the IDONTWANTs and delete it from the cache when it reaches zero + for _, mids := range u.m { + for mid := range mids { + mids[mid]-- + if mids[mid] == 0 { + delete(mids, mid) + } + } + } +} + +func (u *unwantedState) has(peer peer.ID, id checksum) bool { + u.RLock() + defer u.RUnlock() + _, ok := u.m[peer][id] + return ok +} + type connectInfo struct { p peer.ID spr *record.Envelope @@ -844,7 +887,7 @@ func (gs *GossipSubRouter) handleIWant(p peer.ID, ctl *pb.ControlMessage) []*pb. 
for _, iwant := range ctl.GetIwant() { for _, mid := range iwant.GetMessageIDs() { // Check if that peer has sent IDONTWANT before, if so don't send them the message - if _, ok := gs.unwanted[p][computeChecksum(mid)]; ok { + if gs.unwanted.has(p, computeChecksum(mid)) { continue } @@ -1013,10 +1056,6 @@ func (gs *GossipSubRouter) handlePrune(p peer.ID, ctl *pb.ControlMessage) { } func (gs *GossipSubRouter) handleIDontWant(p peer.ID, ctl *pb.ControlMessage) { - if gs.unwanted[p] == nil { - gs.unwanted[p] = make(map[checksum]int) - } - // IDONTWANT flood protection if gs.peerdontwant[p] >= gs.params.MaxIDontWantMessages { log.Debugf("IDONWANT: peer %s has advertised too many times (%d) within this heartbeat interval; ignoring", p, gs.peerdontwant[p]) @@ -1036,7 +1075,7 @@ mainIDWLoop: } totalUnwantedIds++ - gs.unwanted[p][computeChecksum(mid)] = gs.params.IDontWantMessageTTL + gs.unwanted.add(p, computeChecksum(mid), gs.params.IDontWantMessageTTL) } } } @@ -1156,6 +1195,7 @@ func (gs *GossipSubRouter) Publish(msg *Message) { if !ok { return } + messageChecksum := computeChecksum(gs.p.idGen.ID(msg)) if gs.floodPublish && from == gs.p.host.ID() { for p := range tmap { @@ -1200,18 +1240,17 @@ func (gs *GossipSubRouter) Publish(msg *Message) { gs.lastpub[topic] = time.Now().UnixNano() } - csum := computeChecksum(gs.p.idGen.ID(msg)) for p := range gmap { // Check if it has already received an IDONTWANT for the message. // If so, don't send it to the peer - if _, ok := gs.unwanted[p][csum]; ok { + if gs.unwanted.has(p, messageChecksum) { continue } tosend[p] = struct{}{} } } - out := rpcWithMessages(msg.Message) + out := rpcWithMessagesAndChecksums([]*pb.Message{msg.Message}, []checksum{messageChecksum}, gs.unwanted) for pid := range tosend { if pid == from || pid == peer.ID(msg.GetFrom()) { continue @@ -1348,14 +1387,14 @@ func (gs *GossipSubRouter) sendRPC(p peer.ID, out *RPC, urgent bool) { } // Potentially split the RPC into multiple RPCs that are below the max message size - outRPCs := appendOrMergeRPC(nil, gs.p.maxMessageSize, *out) + outRPCs := out.split(gs.p.maxMessageSize) for _, rpc := range outRPCs { if rpc.Size() > gs.p.maxMessageSize { // This should only happen if a single message/control is above the maxMessageSize. gs.doDropRPC(out, p, fmt.Sprintf("Dropping oversized RPC. Size: %d, limit: %d. (Over by %d bytes)", rpc.Size(), gs.p.maxMessageSize, rpc.Size()-gs.p.maxMessageSize)) continue } - gs.doSendRPC(rpc, p, q, urgent) + gs.doSendRPC(&rpc, p, q, urgent) } } @@ -1414,7 +1453,7 @@ func appendOrMergeRPC(slice []*RPC, limit int, elems ...RPC) []*RPC { // old behavior. In the future let's not merge messages. Since, // it may increase message latency. 
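The unwantedState introduced above is the same per-peer TTL map gossipsub already kept, but moved behind a sync.RWMutex: handleIDontWant still fills it and clearIDontWantCounters still ages it from the event loop, while filterUnwanted now reads it from the per-stream writer in handleSendingMessages. The following standalone sketch mirrors only the add/has/gc TTL behaviour; peerID and checksum here are simplified stand-ins rather than the library's own types.

```go
package main

import (
	"fmt"
	"sync"
)

// Simplified stand-ins for peer.ID and checksum; only the add/has/gc TTL
// behaviour of the new unwantedState is mirrored here.
type peerID string
type checksum uint64

type unwantedState struct {
	sync.RWMutex
	m map[peerID]map[checksum]int // TTL, in heartbeats, per unwanted message
}

func newUnwantedState() *unwantedState {
	return &unwantedState{m: make(map[peerID]map[checksum]int)}
}

func (u *unwantedState) add(p peerID, id checksum, ttl int) {
	u.Lock()
	defer u.Unlock()
	if u.m[p] == nil {
		u.m[p] = make(map[checksum]int)
	}
	u.m[p][id] = ttl
}

func (u *unwantedState) has(p peerID, id checksum) bool {
	u.RLock()
	defer u.RUnlock()
	_, ok := u.m[p][id]
	return ok
}

// gc runs once per cleanup pass: decrement every TTL and drop expired entries.
func (u *unwantedState) gc() {
	u.Lock()
	defer u.Unlock()
	for _, mids := range u.m {
		for mid := range mids {
			mids[mid]--
			if mids[mid] == 0 {
				delete(mids, mid)
			}
		}
	}
}

func main() {
	u := newUnwantedState()
	u.add("peerA", 42, 2)           // peerA sent IDONTWANT for message 42
	fmt.Println(u.has("peerA", 42)) // true
	u.gc()                          // first pass: TTL 2 -> 1
	fmt.Println(u.has("peerA", 42)) // true
	u.gc()                          // second pass: TTL 1 -> 0, entry deleted
	fmt.Println(u.has("peerA", 42)) // false
}
```

Since gc is driven by clearIDontWantCounters, IDontWantMessageTTL is effectively measured in heartbeat intervals.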
for _, msg := range elem.GetPublish() { - if lastRPC.Publish = append(lastRPC.Publish, msg); lastRPC.Size() > limit { + if lastRPC.Publish = append(lastRPC.Publish, msg); lastRPC.RPC.Size() > limit { lastRPC.Publish = lastRPC.Publish[:len(lastRPC.Publish)-1] lastRPC = &RPC{RPC: pb.RPC{}, from: elem.from} lastRPC.Publish = append(lastRPC.Publish, msg) @@ -1424,7 +1463,7 @@ func appendOrMergeRPC(slice []*RPC, limit int, elems ...RPC) []*RPC { // Merge/Append Subscriptions for _, sub := range elem.GetSubscriptions() { - if lastRPC.Subscriptions = append(lastRPC.Subscriptions, sub); lastRPC.Size() > limit { + if lastRPC.Subscriptions = append(lastRPC.Subscriptions, sub); lastRPC.RPC.Size() > limit { lastRPC.Subscriptions = lastRPC.Subscriptions[:len(lastRPC.Subscriptions)-1] lastRPC = &RPC{RPC: pb.RPC{}, from: elem.from} lastRPC.Subscriptions = append(lastRPC.Subscriptions, sub) @@ -1436,7 +1475,7 @@ func appendOrMergeRPC(slice []*RPC, limit int, elems ...RPC) []*RPC { if ctl := elem.GetControl(); ctl != nil { if lastRPC.Control == nil { lastRPC.Control = &pb.ControlMessage{} - if lastRPC.Size() > limit { + if lastRPC.RPC.Size() > limit { lastRPC.Control = nil lastRPC = &RPC{RPC: pb.RPC{Control: &pb.ControlMessage{}}, from: elem.from} out = append(out, lastRPC) @@ -1444,7 +1483,7 @@ func appendOrMergeRPC(slice []*RPC, limit int, elems ...RPC) []*RPC { } for _, graft := range ctl.GetGraft() { - if lastRPC.Control.Graft = append(lastRPC.Control.Graft, graft); lastRPC.Size() > limit { + if lastRPC.Control.Graft = append(lastRPC.Control.Graft, graft); lastRPC.RPC.Size() > limit { lastRPC.Control.Graft = lastRPC.Control.Graft[:len(lastRPC.Control.Graft)-1] lastRPC = &RPC{RPC: pb.RPC{Control: &pb.ControlMessage{}}, from: elem.from} lastRPC.Control.Graft = append(lastRPC.Control.Graft, graft) @@ -1453,7 +1492,7 @@ func appendOrMergeRPC(slice []*RPC, limit int, elems ...RPC) []*RPC { } for _, prune := range ctl.GetPrune() { - if lastRPC.Control.Prune = append(lastRPC.Control.Prune, prune); lastRPC.Size() > limit { + if lastRPC.Control.Prune = append(lastRPC.Control.Prune, prune); lastRPC.RPC.Size() > limit { lastRPC.Control.Prune = lastRPC.Control.Prune[:len(lastRPC.Control.Prune)-1] lastRPC = &RPC{RPC: pb.RPC{Control: &pb.ControlMessage{}}, from: elem.from} lastRPC.Control.Prune = append(lastRPC.Control.Prune, prune) @@ -1467,7 +1506,7 @@ func appendOrMergeRPC(slice []*RPC, limit int, elems ...RPC) []*RPC { // For IWANTs we don't need more than a single one, // since there are no topic IDs here. 
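appendOrMergeRPC keeps its existing strategy of appending an element, re-measuring the encoded size, and rolling the append back when the limit is exceeded; the hunks here only route the size check explicitly through the embedded pb.RPC (lastRPC.RPC.Size()). A small self-contained sketch of that append-then-roll-back pattern, with a toy msg type and a plain byte count standing in for *pb.Message and the generated protobuf Size method:

```go
package main

import "fmt"

// Toy stand-ins: msg plays the role of *pb.Message and size() the role of the
// generated protobuf Size() method. The point is the pattern, not the types.
type msg struct{ n int }

type rpc struct{ publish []msg }

func (r *rpc) size() int {
	total := 0
	for _, m := range r.publish {
		total += m.n
	}
	return total
}

// merge mimics the appendOrMergeRPC shape: append an element, re-measure the
// batch, and if the limit is exceeded undo the append and start a new RPC
// that begins with that element.
func merge(limit int, msgs ...msg) []*rpc {
	last := &rpc{}
	out := []*rpc{last}
	for _, m := range msgs {
		if last.publish = append(last.publish, m); last.size() > limit {
			last.publish = last.publish[:len(last.publish)-1]
			last = &rpc{publish: []msg{m}}
			out = append(out, last)
		}
	}
	return out
}

func main() {
	out := merge(100, msg{60}, msg{60}, msg{60})
	for i, r := range out {
		fmt.Printf("rpc %d: %d message(s), %d bytes\n", i, len(r.publish), r.size())
	}
	// rpc 0: 1 message(s), 60 bytes
	// rpc 1: 1 message(s), 60 bytes
	// rpc 2: 1 message(s), 60 bytes
}
```

The new RPC.split in pubsub.go avoids re-measuring the whole partially built RPC on every append, which is what the added BenchmarkSplitRPC/BenchmarkAppendOrMergeRPC pair compares.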
newIWant := &pb.ControlIWant{} - if lastRPC.Control.Iwant = append(lastRPC.Control.Iwant, newIWant); lastRPC.Size() > limit { + if lastRPC.Control.Iwant = append(lastRPC.Control.Iwant, newIWant); lastRPC.RPC.Size() > limit { lastRPC.Control.Iwant = lastRPC.Control.Iwant[:len(lastRPC.Control.Iwant)-1] lastRPC = &RPC{RPC: pb.RPC{Control: &pb.ControlMessage{ Iwant: []*pb.ControlIWant{newIWant}, @@ -1476,7 +1515,7 @@ func appendOrMergeRPC(slice []*RPC, limit int, elems ...RPC) []*RPC { } } for _, msgID := range iwant.GetMessageIDs() { - if lastRPC.Control.Iwant[0].MessageIDs = append(lastRPC.Control.Iwant[0].MessageIDs, msgID); lastRPC.Size() > limit { + if lastRPC.Control.Iwant[0].MessageIDs = append(lastRPC.Control.Iwant[0].MessageIDs, msgID); lastRPC.RPC.Size() > limit { lastRPC.Control.Iwant[0].MessageIDs = lastRPC.Control.Iwant[0].MessageIDs[:len(lastRPC.Control.Iwant[0].MessageIDs)-1] lastRPC = &RPC{RPC: pb.RPC{Control: &pb.ControlMessage{ Iwant: []*pb.ControlIWant{{MessageIDs: []string{msgID}}}, @@ -1824,16 +1863,7 @@ func (gs *GossipSubRouter) clearIDontWantCounters() { // throw away the old map and make a new one gs.peerdontwant = make(map[peer.ID]int) } - - // decrement TTLs of all the IDONTWANTs and delete it from the cache when it reaches zero - for _, mids := range gs.unwanted { - for mid := range mids { - mids[mid]-- - if mids[mid] == 0 { - delete(mids, mid) - } - } - } + gs.unwanted.gc() } func (gs *GossipSubRouter) applyIwantPenalties() { diff --git a/gossipsub_spam_test.go b/gossipsub_spam_test.go index 9f6f0f94..11bafd48 100644 --- a/gossipsub_spam_test.go +++ b/gossipsub_spam_test.go @@ -929,12 +929,12 @@ func TestGossipsubHandleIDontwantSpam(t *testing.T) { t.Errorf("Wanted message count of %d but received %d", 1, grt.peerdontwant[rPid]) } mid := fmt.Sprintf("idontwant-%d", GossipSubMaxIDontWantLength-1) - if _, ok := grt.unwanted[rPid][computeChecksum(mid)]; !ok { + if !grt.unwanted.has(rPid, computeChecksum(mid)) { t.Errorf("Desired message id was not stored in the unwanted map: %s", mid) } mid = fmt.Sprintf("idontwant-%d", GossipSubMaxIDontWantLength) - if _, ok := grt.unwanted[rPid][computeChecksum(mid)]; ok { + if grt.unwanted.has(rPid, computeChecksum(mid)) { t.Errorf("Unwanted message id was stored in the unwanted map: %s", mid) } } diff --git a/gossipsub_test.go b/gossipsub_test.go index abb347fd..ba0cc848 100644 --- a/gossipsub_test.go +++ b/gossipsub_test.go @@ -5,10 +5,13 @@ import ( "context" crand "crypto/rand" "encoding/base64" + "encoding/binary" "fmt" "io" mrand "math/rand" + mrand2 "math/rand/v2" "sort" + "strconv" "sync" "sync/atomic" "testing" @@ -2348,6 +2351,15 @@ func validRPCSizes(slice []*RPC, limit int) bool { return true } +func validRPCSizesStructSlice(slice []RPC, limit int) bool { + for _, rpc := range slice { + if rpc.Size() > limit { + return false + } + } + return true +} + func TestFragmentRPCFunction(t *testing.T) { fragmentRPC := func(rpc *RPC, limit int) ([]*RPC, error) { rpcs := appendOrMergeRPC(nil, limit, *rpc) @@ -2556,6 +2568,59 @@ func FuzzAppendOrMergeRPC(f *testing.F) { }) } +func FuzzRPCSplit(f *testing.F) { + minMaxMsgSize := 100 + maxMaxMsgSize := 2048 + f.Fuzz(func(t *testing.T, data []byte) { + maxSize := int(generateU16(&data)) % maxMaxMsgSize + if maxSize < minMaxMsgSize { + maxSize = minMaxMsgSize + } + rpc := generateRPC(data, maxSize) + rpcs := rpc.split(maxSize) + + if !validRPCSizesStructSlice(rpcs, maxSize) { + t.Fatalf("invalid RPC size") + } + }) +} + +func genNRpcs(tb testing.TB, n int, maxSize int) []*RPC { + r := 
mrand2.NewChaCha8([32]byte{}) + rpcs := make([]*RPC, n) + for i := range rpcs { + var data [64]byte + _, err := r.Read(data[:]) + if err != nil { + tb.Fatal(err) + } + rpcs[i] = generateRPC(data[:], maxSize) + } + return rpcs +} + +func BenchmarkSplitRPC(b *testing.B) { + maxSize := 2048 + rpcs := genNRpcs(b, 100, maxSize) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rpc := rpcs[i%len(rpcs)] + rpc.split(maxSize) + } +} + +func BenchmarkAppendOrMergeRPC(b *testing.B) { + maxSize := 2048 + rpcs := genNRpcs(b, 100, maxSize) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rpc := rpcs[i%len(rpcs)] + appendOrMergeRPC(nil, maxSize, *rpc) + } +} + func TestGossipsubManagesAnAddressBook(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -3406,3 +3471,169 @@ func BenchmarkAllocDoDropRPC(b *testing.B) { gs.doDropRPC(&RPC{}, "peerID", "reason") } } + +type blockableHost struct { + host.Host + streams []blockableStream +} + +func (bh *blockableHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { + s, err := bh.Host.NewStream(ctx, p, pids...) + if err != nil { + return nil, err + } + + bh.streams = append(bh.streams, blockableStream{Stream: s}) + return &bh.streams[len(bh.streams)-1], nil +} + +func (bh *blockableHost) BlockAll() { + for i := range bh.streams { + bh.streams[i].Block() + } +} + +func (bh *blockableHost) UnblockAll() { + for i := range bh.streams { + bh.streams[i].Unblock() + } +} + +type blockableStream struct { + network.Stream + blocked sync.Mutex +} + +func (bs *blockableStream) Block() { + bs.blocked.Lock() +} + +func (bs *blockableStream) Unblock() { + bs.blocked.Unlock() +} + +func (bs *blockableStream) Write(p []byte) (int, error) { + bs.blocked.Lock() + defer bs.blocked.Unlock() + return bs.Stream.Write(p) +} + +func TestGossipsubIDONTWANTCancelsQueuedRPC(t *testing.T) { + msgCount := 3 + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + hosts := getDefaultHosts(t, 2) + denseConnect(t, hosts) + + psubs := make([]*PubSub, 2) + + publisherHost := &blockableHost{Host: hosts[0]} + messageIDFn := func(msg *pb.Message) string { + return strconv.FormatUint(binary.BigEndian.Uint64(msg.Data), 10) + } + psubs[0] = getGossipsub(ctx, publisherHost, WithMessageIdFn(messageIDFn)) + msgIDsReceived := make(chan string, msgCount) + psubs[1] = getGossipsub(ctx, hosts[1], WithMessageIdFn(messageIDFn), WithRawTracer(&mockRawTracer{ + onRecvRPC: func(rpc *RPC) { + if len(rpc.GetPublish()) > 0 { + for _, msg := range rpc.GetPublish() { + msgIDsReceived <- messageIDFn(msg) + } + } + }, + })) + + topicString := "foobar" + var topics []*Topic + for _, ps := range psubs { + topic, err := ps.Join(topicString) + if err != nil { + t.Fatal(err) + } + topics = append(topics, topic) + + _, err = ps.Subscribe(topicString, WithBufferSize(msgCount+1)) + if err != nil { + t.Fatal(err) + } + } + + time.Sleep(2 * time.Second) + + publisherHost.BlockAll() + // Have the publisher queue up a bunch of mesages. 
The actual sending will
+	// be blocked on the call to stream.Write
+	for i := range msgCount {
+		msg := make([]byte, GossipSubIDontWantMessageThreshold+1)
+		binary.BigEndian.AppendUint64(msg[:0], uint64(i))
+		err := topics[0].Publish(ctx, msg)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	// Now have the receiver cancel these messages by sending IDONTWANTs
+	idontwantIDs := make([]string, msgCount)
+	for i := range idontwantIDs {
+		idontwantIDs[i] = strconv.Itoa(i)
+	}
+	idontwantRPC := &RPC{
+		RPC: pb.RPC{
+			Control: &pb.ControlMessage{
+				Idontwant: []*pb.ControlIDontWant{&pb.ControlIDontWant{
+					MessageIDs: idontwantIDs,
+				}},
+			},
+		},
+	}
+	q := psubs[1].peers[publisherHost.ID()]
+
+	// Call this via the eval func to run it in the event loop
+	psubs[1].eval <- func() {
+		psubs[1].rt.(*GossipSubRouter).doSendRPC(idontwantRPC, publisherHost.ID(), q, true)
+	}
+
+	// Wait for the RPCs to send
+	time.Sleep(time.Second)
+
+	// Unblock writes
+	publisherHost.UnblockAll()
+
+	// Have the publisher send one more message. We expect this one to make it
+	msg := make([]byte, GossipSubIDontWantMessageThreshold+1)
+	binary.BigEndian.AppendUint64(msg[:0], uint64(msgCount+1))
+	err := topics[0].Publish(ctx, msg)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// We should get the last message
+outerExpectMsg:
+	for {
+		select {
+		case msgID := <-msgIDsReceived:
+			if msgID == "0" {
+				// one early message that got sent before we could cancel it
+				continue
+			}
+			if msgID != strconv.FormatUint(uint64(msgCount)+1, 10) {
+				t.Fatal("received unexpected message: ", msgID)
+			}
+			break outerExpectMsg
+		case <-time.After(5 * time.Second):
+			t.Fatal("Should have received the last message")
+		}
+	}
+
+	// We should not get any more messages
+outer:
+	for {
+		select {
+		case <-msgIDsReceived:
+			t.Fatal("Should not have received a publish as the node sent IDONTWANT")
+		case <-time.After(5 * time.Second):
+			break outer
+		}
+	}
+}
diff --git a/pubsub.go b/pubsub.go
index 5c27c3e9..76115f5b 100644
--- a/pubsub.go
+++ b/pubsub.go
@@ -5,7 +5,10 @@ import (
 	"encoding/binary"
 	"errors"
 	"fmt"
+	"iter"
+	math_bits "math/bits"
 	"math/rand"
+	"slices"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -243,9 +246,203 @@ func (m *Message) GetFrom() peer.ID {
 type RPC struct {
 	pb.RPC
-	// unexported on purpose, not sending this over the wire
-	from peer.ID
+	from             peer.ID
+	messageChecksums []checksum
+	unwanted         *unwantedState
+}
+
+func (r *RPC) filterUnwanted(to peer.ID) *RPC {
+	if len(r.Publish) != len(r.messageChecksums) || r.unwanted == nil {
+		return r
+	}
+	// First check if there are any unwanted messages.
+	// If all messages are wanted, we can return the original RPC.
+	anyUnwanted := false
+	for i := range r.Publish {
+		csum := r.messageChecksums[i]
+		if r.unwanted.has(to, csum) {
+			anyUnwanted = true
+			break
+		}
+	}
+	if !anyUnwanted {
+		return r
+	}
+
+	// There are some unwanted messages, so we need to filter them out.
+	// We need to copy the RPC as other senders could be using the same RPC.
+	filtered := copyRPC(r)
+	filtered.Publish = slices.Clone(r.Publish)
+	// The filtered RPC will not need these checksums since it will not be
+	// filtered again. If we did need them, we'd have to clone the slice and
+	// synchronize the deletes below.
+	filtered.messageChecksums = nil
+
+	// Remove unwanted messages from the publish list.
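The compaction that follows finds the first unwanted entry with slices.IndexFunc, shifts the remaining wanted messages left over it, clears the tail, and truncates; it runs on the cloned Publish slice, so other senders holding the same *RPC are unaffected. A standalone sketch of that in-place filtering pattern, with plain strings standing in for *pb.Message and a set lookup standing in for the per-peer checksum test:

```go
package main

import (
	"fmt"
	"slices"
)

// filterInPlace removes every element for which drop returns true, without
// allocating a second slice: find the first element to drop, then shift the
// kept elements left over it and truncate.
func filterInPlace(msgs []string, drop func(string) bool) []string {
	i := slices.IndexFunc(msgs, drop)
	if i == -1 {
		return msgs // nothing to remove
	}
	for j := i + 1; j < len(msgs); j++ {
		if !drop(msgs[j]) {
			msgs[i] = msgs[j]
			i++
		}
	}
	// Zero the tail so stale references are released; this matters for
	// pointer elements such as *pb.Message.
	clear(msgs[i:])
	return msgs[:i]
}

func main() {
	unwanted := map[string]bool{"b": true, "d": true}
	msgs := []string{"a", "b", "c", "d"}
	msgs = filterInPlace(msgs, func(m string) bool { return unwanted[m] })
	fmt.Println(msgs) // [a c]
}
```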
+ i := slices.IndexFunc(r.messageChecksums, func(csum checksum) bool { + return r.unwanted.has(to, csum) + }) + if i == -1 { + return r + } + for j := i + 1; j < len(filtered.Publish); j++ { + if !r.unwanted.has(to, r.messageChecksums[j]) { + filtered.Publish[i] = filtered.Publish[j] + i++ + } + } + clear(filtered.Publish[i:]) + filtered.Publish = filtered.Publish[:i] + return filtered +} + +// split splits the RPC into multiple RPCs if the total size exceeds maxRPCSize. +// It may still return a single RPC that is larger than maxRPCSize in case it +// can't split the RPC up further. Caller should take care of handling oversized +// RPCs appropriately. +// +// A note for maintainers: +// The details of this are tied to Protobuf encoding. It is recommended to +// familiarize yourself with the following resource: https://protobuf.dev/programming-guides/encoding/ +// +// Also note that the +1 byte for the protobuf field number + wire type assumes the field +// number is <= 15. If the field number is larger, it will use more than one byte. The formula is: +// byteLengthOfVarint(varintEncode(fieldNumber << 3 | wireType)) or sovRpc(fieldNumber << 3 | wireType) +// +// Make sure to run the Fuzz test after any changes. It's very good at detecting issues. +func (r *RPC) split(maxRPCSize int) []RPC { + // Fast path: if the RPC is smaller than maxRPCSize, return it as is. + if r.Size() <= maxRPCSize { + return []RPC{*r} + } + + out := make([]RPC, 0, 1) + currentRPC := RPC{} + + // Split control messages. This is trickier than other fields because we are + // splitting one level deeper. + var ctrlSize int + for incrementalSize, mergeFn := range r.rpcControlComponents() { + nextSize := ctrlSize + incrementalSize + nextLenPrefixSize := sovRpc(uint64(nextSize)) + // +1 for the protobuf field number + wire type + if nextSize+nextLenPrefixSize+1 >= maxRPCSize && ctrlSize > 0 { + out = append(out, currentRPC) + currentRPC = RPC{} + ctrlSize = 0 + } + ctrlSize += incrementalSize + if currentRPC.Control == nil { + currentRPC.Control = &pb.ControlMessage{} + } + mergeFn(currentRPC.Control) + } + + var currentSize int + if ctrlSize > 0 { + currentSize = ctrlSize + 1 + sovRpc(uint64(ctrlSize)) + } + // Split subscriptions. + for _, rpc := range r.Subscriptions { + subSize := rpc.Size() + // +1 for the protobuf field number + wire type + incrementalSize := subSize + 1 + sovRpc(uint64(subSize)) + if currentSize+incrementalSize >= maxRPCSize && currentSize > 0 { + out = append(out, currentRPC) + currentRPC = RPC{} + currentSize = 0 + } + currentSize += incrementalSize + currentRPC.Subscriptions = append(currentRPC.Subscriptions, rpc) + } + + for _, msg := range r.Publish { + msgSize := msg.Size() + // +1 for the protobuf field number + wire type + incrementalSize := msgSize + 1 + sovRpc(uint64(msgSize)) + if currentSize+incrementalSize >= maxRPCSize && currentSize > 0 { + out = append(out, currentRPC) + currentRPC = RPC{} + currentSize = 0 + } + currentSize += incrementalSize + currentRPC.Publish = append(currentRPC.Publish, msg) + } + + if currentSize > 0 { + out = append(out, currentRPC) + } + + // Set common fields for all RPCs + for i := range out { + out[i].from = r.from + out[i].unwanted = r.unwanted + out[i].messageChecksums = r.messageChecksums + } + return out +} + +func sovRpc(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} + +// rpcControlComponents returns an iterator over the control messages of the RPC. 
+// The first item in the pair is the incremental size of adding this component to the control message. +// The second item is a function that merges the component into an existing ControlMessage pb. +// +// Read the comment in RPC.split() before modifying this function. +func (r *RPC) rpcControlComponents() iter.Seq2[int, func(*pb.ControlMessage)] { + return func(yield func(int, func(*pb.ControlMessage)) bool) { + if r.Control == nil { + return + } + for _, idontwant := range r.Control.Idontwant { + s := idontwant.Size() + incrementalSize := s + 1 + sovRpc(uint64(s)) + if !yield(incrementalSize, func(a *pb.ControlMessage) { + a.Idontwant = append(a.Idontwant, idontwant) + }) { + return + } + } + for _, graft := range r.Control.Graft { + s := graft.Size() + incrementalSize := s + 1 + sovRpc(uint64(s)) + if !yield(incrementalSize, func(a *pb.ControlMessage) { + a.Graft = append(a.Graft, graft) + }) { + return + } + } + for _, prune := range r.Control.Prune { + s := prune.Size() + incrementalSize := s + 1 + sovRpc(uint64(s)) + if !yield(incrementalSize, func(a *pb.ControlMessage) { + a.Prune = append(a.Prune, prune) + }) { + return + } + } + for _, iwant := range r.Control.Iwant { + s := iwant.Size() + incrementalSize := s + 1 + sovRpc(uint64(s)) + if !yield(incrementalSize, func(a *pb.ControlMessage) { + a.Iwant = append(a.Iwant, iwant) + }) { + return + } + } + for _, ihave := range r.Control.Ihave { + s := ihave.Size() + incrementalSize := s + 1 + sovRpc(uint64(s)) + if !yield(incrementalSize, func(a *pb.ControlMessage) { + a.Ihave = append(a.Ihave, ihave) + }) { + return + } + } + } } type Option func(*PubSub) error
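To make the size arithmetic in RPC.split and rpcControlComponents concrete: each embedded element contributes its encoded payload, one tag byte (valid while the field number is at most 15, as the comment above notes), and a varint length prefix whose width sovRpc computes. The sketch below copies sovRpc from this diff; the helper names, the toy component list, and the batching loop are illustrative only, but the loop has the same shape as split consuming the (incremental size, merge function) pairs yielded by rpcControlComponents.

```go
package main

import (
	"fmt"
	"iter"
	math_bits "math/bits"
)

// sovRpc is the varint-width helper from the generated protobuf code: the
// number of bytes needed to varint-encode x.
func sovRpc(x uint64) int {
	return (math_bits.Len64(x|1) + 6) / 7
}

// embeddedFieldSize is not part of the diff; it just names the formula used
// there: payload + 1 tag byte + varint length prefix.
func embeddedFieldSize(payload int) int {
	return payload + 1 + sovRpc(uint64(payload))
}

// components mirrors the shape of rpcControlComponents on toy data: an
// iter.Seq2 yielding (incremental size, merge function) pairs, so the caller
// can pack components into size-bounded batches in a single pass.
func components(sizes []int) iter.Seq2[int, func(*[]int)] {
	return func(yield func(int, func(*[]int)) bool) {
		for _, s := range sizes {
			if !yield(embeddedFieldSize(s), func(batch *[]int) { *batch = append(*batch, s) }) {
				return
			}
		}
	}
}

func main() {
	fmt.Println(sovRpc(127), sovRpc(128), sovRpc(16384)) // 1 2 3
	fmt.Println(embeddedFieldSize(300))                  // 300 + 1 + 2 = 303

	// Pack toy components into batches of at most 600 encoded bytes.
	const limit = 600
	var batches [][]int
	var cur []int
	curSize := 0
	for inc, merge := range components([]int{300, 200, 150, 100}) {
		if curSize+inc > limit && curSize > 0 {
			batches = append(batches, cur)
			cur, curSize = nil, 0
		}
		merge(&cur)
		curSize += inc
	}
	if curSize > 0 {
		batches = append(batches, cur)
	}
	fmt.Println(batches) // [[300 200] [150 100]]
}
```

Accumulating incremental sizes this way means split never re-encodes or re-measures the partially built RPC, unlike the appendOrMergeRPC path it replaces in sendRPC.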