From 114044495fa10b51858d0f4e0ac3b46a17b979b6 Mon Sep 17 00:00:00 2001 From: raulk Date: Sun, 26 Oct 2025 00:01:10 +0100 Subject: [PATCH] add Announce API for GossipSub message announcement MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements a new Announce API that allows advertising messages via IHAVE without pushing them to the mesh. These messages are retained in the mcache until, at least, the specified deadline. We send the IHAVE immediately to all connected topic subscribers, whether in mesh, in gossip, in a cached fanout, or none of these active pubsub states. This enables pull-based message distribution, useful for scenarios outside of the app's critical path, such as backup availability. Subscribers can pull messages on-demand via IWANT requests. Further details: - Add Topic.Announce() method that sends IHAVE gossip to topic subscribers with expiry-based message retention - Refactor MessageCache to support dual storage model: - Sliding window for regular published messages - Time wheel for announced messages with TTL-based expiry - Add GossipSubAnnouncementMaxTTL parameter (default 60s) for sizing announcement storage - Rename MessageCache methods for clarity (Put→AppendWindow, Shift→ShiftWindow, GetGossipIDs→GossipForTopic). Add missing godocs. - Implement unified message storage with reference counting to handle messages in both window and announcement wheel - Add heartbeat cleanup for expired announcements via PruneAnns() - Add comprehensive test coverage for announcement functionality including storage, delivery, expiry, duplicates, and edge cases --- announce_test.go | 505 +++++++++++++++++++++++++++++++++++++++++++++++ gossipsub.go | 63 +++++- mcache.go | 186 +++++++++++++---- mcache_test.go | 137 +++++++++++-- topic.go | 54 +++++ 5 files changed, 882 insertions(+), 63 deletions(-) create mode 100644 announce_test.go diff --git a/announce_test.go b/announce_test.go new file mode 100644 index 00000000..b822b168 --- /dev/null +++ b/announce_test.go @@ -0,0 +1,505 @@ +package pubsub + +import ( + "bytes" + "context" + "testing" + "time" + + pb "github.com/libp2p/go-libp2p-pubsub/pb" + "github.com/libp2p/go-libp2p/core/peer" +) + +func TestAnnounceStorage(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-storage" + hosts := getDefaultHosts(t, 2) + psubs := getGossipsubs(ctx, hosts) + connectAll(t, hosts) + + topics := getTopics(psubs, topic) + + // Host 1 subscribes + _, err := topics[1].Subscribe() + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 500) + + // Host 0 announces + payload := []byte("test storage") + expiry := time.Now().Add(time.Second * 10) + err = topics[0].Announce(ctx, payload, expiry) + if err != nil { + t.Fatal(err) + } + + // Verify the message is stored in host 0's message cache announcements + gs0, ok := psubs[0].rt.(*GossipSubRouter) + if !ok { + t.Fatal("expected GossipSubRouter") + } + + resultChan := make(chan int, 1) + psubs[0].eval <- func() { + // Count total announcements across all buckets in the wheel + count := 0 + for _, bucket := range gs0.mcache.annWheel { + count += len(bucket) + } + resultChan <- count + } + + count := <-resultChan + if count != 1 { + t.Fatalf("expected 1 announcement stored, got %d", count) + } +} + +func TestAnnounceBasic(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce" + hosts := getDefaultHosts(t, 3) + psubs := getGossipsubs(ctx, hosts) + connectAll(t, hosts) + + // Get topics for all hosts + topics := getTopics(psubs, topic) + + // Subscribe on host 1 and 2 + sub1, err := topics[1].Subscribe() + if err != nil { + t.Fatal(err) + } + + sub2, err := topics[2].Subscribe() + if err != nil { + t.Fatal(err) + } + + // Wait for mesh to form and subscriptions to propagate + time.Sleep(time.Second * 2) + + // Host 0 announces a message (not subscribed) + payload := []byte("announced message") + expiry := time.Now().Add(time.Second * 5) + err = topics[0].Announce(ctx, payload, expiry) + if err != nil { + t.Fatal(err) + } + + // Subscribers should receive the message via IWANT + timeoutCtx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + msg1, err := sub1.Next(timeoutCtx) + if err != nil { + t.Fatalf("host 1 failed to receive message: %v", err) + } + if !bytes.Equal(msg1.Data, payload) { + t.Fatalf("received incorrect message: got %s, want %s", msg1.Data, payload) + } + + msg2, err := sub2.Next(timeoutCtx) + if err != nil { + t.Fatalf("host 2 failed to receive message: %v", err) + } + if !bytes.Equal(msg2.Data, payload) { + t.Fatalf("received incorrect message: got %s, want %s", msg2.Data, payload) + } +} + +func TestAnnounceWhenSubscribed(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-subscribed" + hosts := getDefaultHosts(t, 2) + psubs := getGossipsubs(ctx, hosts) + connectAll(t, hosts) + + topics := getTopics(psubs, topic) + + // Both hosts subscribe + sub0, err := topics[0].Subscribe() + if err != nil { + t.Fatal(err) + } + + sub1, err := topics[1].Subscribe() + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 500) + + // Host 0 announces while subscribed + payload := []byte("announced while subscribed") + expiry := time.Now().Add(time.Second * 5) + err = topics[0].Announce(ctx, payload, expiry) + if err != nil { + t.Fatal(err) + } + + // Host 0 should NOT receive its own announcement (marked as seen) + timeoutCtx, cancel := context.WithTimeout(ctx, time.Millisecond*200) + defer cancel() + msg, err := sub0.Next(timeoutCtx) + if err != context.DeadlineExceeded { + if msg != nil { + t.Fatal("announcer should not receive own announcement when subscribed") + } + t.Fatalf("expected timeout, got error: %v", err) + } + + // Host 1 should receive it + msg1, err := sub1.Next(ctx) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(msg1.Data, payload) { + t.Fatalf("received incorrect message: got %s, want %s", msg1.Data, payload) + } +} + +func TestAnnounceDuplicate(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-duplicate" + hosts := getDefaultHosts(t, 2) + psubs := getGossipsubs(ctx, hosts, WithMessageIdFn(func(msg *pb.Message) string { + // use a content addressed ID function + return string(msg.Data) + })) + connectAll(t, hosts) + + topics := getTopics(psubs, topic) + + // Host 0 subscribes + _, err := topics[0].Subscribe() + if err != nil { + t.Fatal(err) + } + + // Host 1 subscribes + sub1, err := topics[1].Subscribe() + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 500) + + payload := []byte("duplicate test") + expiry := time.Now().Add(time.Second * 5) + + // First announcement should succeed + err = topics[0].Announce(ctx, payload, expiry) + if err != nil { + t.Fatalf("first announce failed: %v", err) + } + + // Host 1 receives the message + msg1, err := sub1.Next(ctx) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(msg1.Data, payload) { + t.Fatal("received incorrect message") + } + + // Try announcing the exact same payload again - this is a duplicate + expiry = time.Now().Add(time.Second * 5) + err = topics[0].Announce(ctx, payload, expiry) + if err != nil { + t.Fatalf("second announce failed: %v", err) + } + + // Host 1 should NOT receive the duplicate message (it should be filtered) + timeoutCtx, cancel := context.WithTimeout(ctx, time.Millisecond*500) + defer cancel() + msg2, err := sub1.Next(timeoutCtx) + if err != context.DeadlineExceeded { + if msg2 != nil { + t.Fatal("host 1 should not receive duplicate announcement") + } + t.Fatalf("expected timeout for duplicate message, got error: %v", err) + } +} + +func TestAnnounceExpiry(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-expiry" + hosts := getDefaultHosts(t, 2) + psubs := getGossipsubs(ctx, hosts) + connectAll(t, hosts) + + topics := getTopics(psubs, topic) + + // Only host 1 subscribes + _, err := topics[1].Subscribe() + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 500) + + // Announce with very short expiry + payload := []byte("expires soon") + expiry := time.Now().Add(time.Millisecond * 100) + err = topics[0].Announce(ctx, payload, expiry) + if err != nil { + t.Fatal(err) + } + + // Wait for expiry plus heartbeat + time.Sleep(time.Millisecond*100 + time.Second*2) + + // Try to access the gossipsub router to verify cleanup + gs0, ok := psubs[0].rt.(*GossipSubRouter) + if !ok { + t.Fatal("expected GossipSubRouter") + } + + // Check that the announcement was cleaned up + resultChan := make(chan int, 1) + psubs[0].eval <- func() { + // Count total announcements across all buckets in the wheel + count := 0 + for _, bucket := range gs0.mcache.annWheel { + count += len(bucket) + } + resultChan <- count + } + + announcementCount := <-resultChan + if announcementCount != 0 { + t.Fatalf("expected 0 announcements after expiry, got %d", announcementCount) + } +} + +func TestAnnounceNoSubscribers(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-no-subs" + hosts := getDefaultHosts(t, 2) + psubs := getGossipsubs(ctx, hosts) + connectAll(t, hosts) + + topics := getTopics(psubs, topic) + + // No one subscribes + time.Sleep(time.Millisecond * 500) + + // Announce should succeed even without subscribers (it's a no-op) + payload := []byte("no subscribers") + expiry := time.Now().Add(time.Second * 5) + err := topics[0].Announce(ctx, payload, expiry) + if err != nil { + t.Fatal(err) + } + + // Since no one is subscribed, the message is not stored and no IHAVE is sent + _, err = topics[1].Subscribe() + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Millisecond * 500) + + // Now announce another message - this one should be received + payload2 := []byte("with subscriber") + expiry2 := time.Now().Add(time.Second * 5) + err = topics[0].Announce(ctx, payload2, expiry2) + if err != nil { + t.Fatal(err) + } + + // Verify the announcement was stored + gs0, ok := psubs[0].rt.(*GossipSubRouter) + if !ok { + t.Fatal("expected GossipSubRouter") + } + + resultChan := make(chan int, 1) + psubs[0].eval <- func() { + // Count total announcements across all buckets in the wheel + count := 0 + for _, bucket := range gs0.mcache.annWheel { + count += len(bucket) + } + resultChan <- count + } + + count := <-resultChan + if count != 1 { + t.Fatalf("expected 1 announcements stored, got %d", count) + } +} + +func TestAnnounceMultipleMessages(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-multiple" + hosts := getDefaultHosts(t, 3) + psubs := getGossipsubs(ctx, hosts) + connectAll(t, hosts) + + topics := getTopics(psubs, topic) + + // All hosts subscribe + subs := make([]*Subscription, 3) + for i := range 3 { + sub, err := topics[i].Subscribe() + if err != nil { + t.Fatal(err) + } + subs[i] = sub + } + + time.Sleep(time.Millisecond * 500) + + // Host 0 announces multiple messages + numMessages := 5 + payloads := make([][]byte, numMessages) + expiry := time.Now().Add(time.Second * 10) + + for i := range numMessages { + payloads[i] = []byte("message " + string(rune('0'+i))) + err := topics[0].Announce(ctx, payloads[i], expiry) + if err != nil { + t.Fatal(err) + } + time.Sleep(time.Millisecond * 50) + } + + // Host 1 and 2 should receive all messages + for hostIdx := 1; hostIdx < 3; hostIdx++ { + receivedCount := 0 + for receivedCount < numMessages { + timeoutCtx, cancel := context.WithTimeout(ctx, time.Second*2) + _, err := subs[hostIdx].Next(timeoutCtx) + cancel() + if err != nil { + t.Fatalf("host %d: failed to receive message %d: %v", hostIdx, receivedCount, err) + } + receivedCount++ + } + } +} + +func TestAnnounceWithClosedTopic(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-closed" + hosts := getDefaultHosts(t, 1) + psubs := getGossipsubs(ctx, hosts) + + topics := getTopics(psubs, topic) + + // Close the topic + err := topics[0].Close() + if err != nil { + t.Fatal(err) + } + + // Announce should fail with ErrTopicClosed + payload := []byte("should fail") + expiry := time.Now().Add(time.Second * 5) + err = topics[0].Announce(ctx, payload, expiry) + if err != ErrTopicClosed { + t.Fatalf("expected ErrTopicClosed, got %v", err) + } +} + +func TestAnnounceWithFloodsub(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-floodsub" + hosts := getDefaultHosts(t, 1) + + // Create a floodsub instance instead of gossipsub + psubs := getPubsubs(ctx, hosts) // This creates floodsub + + topics := getTopics(psubs, topic) + + // Announce should fail with non-GossipSub router + payload := []byte("floodsub test") + expiry := time.Now().Add(time.Second * 5) + err := topics[0].Announce(ctx, payload, expiry) + if err == nil { + t.Fatal("expected error with floodsub router, got nil") + } +} + +func TestAnnounceGossipThreshold(t *testing.T) { + ctx := t.Context() + + const topic = "test-announce-threshold" + hosts := getDefaultHosts(t, 3) + + // Setup peer scoring with gossip threshold + psubs := getGossipsubs(ctx, hosts, + WithPeerScore( + &PeerScoreParams{ + AppSpecificScore: func(p peer.ID) float64 { + // Give host 2 a very low score + if p == hosts[2].ID() { + return -1000 + } + return 0 + }, + AppSpecificWeight: 1.0, + DecayInterval: time.Second, + DecayToZero: 0.01, + }, + &PeerScoreThresholds{ + GossipThreshold: -500, + PublishThreshold: -1000, + GraylistThreshold: -2000, + }, + ), + ) + + connectAll(t, hosts) + topics := getTopics(psubs, topic) + + // All hosts subscribe + _, err := topics[0].Subscribe() + if err != nil { + t.Fatal(err) + } + + _, err = topics[1].Subscribe() + if err != nil { + t.Fatal(err) + } + + sub2, err := topics[2].Subscribe() + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Second * 1) + + // Host 0 announces + payload := []byte("threshold test") + expiry := time.Now().Add(time.Second * 5) + err = topics[0].Announce(ctx, payload, expiry) + if err != nil { + t.Fatal(err) + } + + // Host 2 with low score should not receive IHAVE + timeoutCtx, cancel := context.WithTimeout(ctx, time.Millisecond*500) + defer cancel() + msg, err := sub2.Next(timeoutCtx) + if err != context.DeadlineExceeded { + if msg != nil { + t.Fatal("host with low score should not receive announcement") + } + t.Fatalf("expected timeout for low-score peer, got error: %v", err) + } +} diff --git a/gossipsub.go b/gossipsub.go index c492ded9..498ba442 100644 --- a/gossipsub.go +++ b/gossipsub.go @@ -62,6 +62,7 @@ var ( GossipSubHeartbeatInitialDelay = 100 * time.Millisecond GossipSubHeartbeatInterval = 1 * time.Second GossipSubFanoutTTL = 60 * time.Second + GossipSubAnnouncementMaxTTL = 60 * time.Second GossipSubPrunePeers = 16 GossipSubPruneBackoff = time.Minute GossipSubUnsubscribeBackoff = 10 * time.Second @@ -168,6 +169,11 @@ type GossipSubParams struct { // we'll delete the fanout map for that topic. FanoutTTL time.Duration + // AnnouncementMaxTTL is the maximum possible time-to-live for a message announced + // via Announce. This is used to size internal data structures. Deadlines passed to + // Announce exceeding this value will be clamped, and a warning will be logged. + AnnouncementMaxTTL time.Duration + // PrunePeers controls the number of peers to include in prune Peer eXchange. // When we prune a peer that's eligible for PX (has a good score, etc), we will try to // send them signed peer records for up to PrunePeers other peers that we @@ -292,6 +298,7 @@ func NewGossipSubWithRouter(ctx context.Context, h host.Host, rt PubSubRouter, o // DefaultGossipSubRouter returns a new GossipSubRouter with default parameters. func DefaultGossipSubRouter(h host.Host) *GossipSubRouter { params := DefaultGossipSubParams() + mcache := NewMessageCache(params.HistoryGossip, params.HistoryLength, params.HeartbeatInterval, params.AnnouncementMaxTTL) rt := &GossipSubRouter{ peers: make(map[peer.ID]protocol.ID), mesh: make(map[string]map[peer.ID]struct{}), @@ -307,7 +314,7 @@ func DefaultGossipSubRouter(h host.Host) *GossipSubRouter { outbound: make(map[peer.ID]bool), connect: make(chan connectInfo, params.MaxPendingConnections), cab: pstoremem.NewAddrBook(), - mcache: NewMessageCache(params.HistoryGossip, params.HistoryLength), + mcache: mcache, protos: GossipSubDefaultProtocols, feature: GossipSubDefaultFeatures, tagTracer: newTagTracer(h.ConnManager()), @@ -341,6 +348,7 @@ func DefaultGossipSubParams() GossipSubParams { HeartbeatInitialDelay: GossipSubHeartbeatInitialDelay, HeartbeatInterval: GossipSubHeartbeatInterval, FanoutTTL: GossipSubFanoutTTL, + AnnouncementMaxTTL: GossipSubAnnouncementMaxTTL, PrunePeers: GossipSubPrunePeers, PruneBackoff: GossipSubPruneBackoff, UnsubscribeBackoff: GossipSubUnsubscribeBackoff, @@ -569,7 +577,7 @@ func WithGossipSubParams(cfg GossipSubParams) Option { // Overwrite current config and associated variables in the router. gs.params = cfg gs.connect = make(chan connectInfo, cfg.MaxPendingConnections) - gs.mcache = NewMessageCache(cfg.HistoryGossip, cfg.HistoryLength) + gs.mcache = NewMessageCache(cfg.HistoryGossip, cfg.HistoryLength, cfg.HeartbeatInterval, cfg.AnnouncementMaxTTL) return nil } @@ -1303,7 +1311,7 @@ func (gs *GossipSubRouter) Publish(msg *Message) { func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] { return func(yield func(peer.ID, *RPC) bool) { - gs.mcache.Put(msg) + gs.mcache.AppendWindow(msg) from := msg.ReceivedFrom topic := msg.GetTopic() @@ -1595,6 +1603,9 @@ func (gs *GossipSubRouter) heartbeat() { // clean up IDONTWANT counters gs.clearIDontWantCounters() + // clean up expired announcements + gs.purgeAnnouncements() + // apply IWANT request penalties gs.applyIwantPenalties() @@ -1832,7 +1843,7 @@ func (gs *GossipSubRouter) heartbeat() { gs.flush() // advance the message history window - gs.mcache.Shift() + gs.mcache.ShiftWindow() } func (gs *GossipSubRouter) clearIHaveCounters() { @@ -1864,6 +1875,10 @@ func (gs *GossipSubRouter) clearIDontWantCounters() { } } +func (gs *GossipSubRouter) purgeAnnouncements() { + gs.mcache.PruneAnns() +} + func (gs *GossipSubRouter) applyIwantPenalties() { for p, count := range gs.gossipTracer.GetBrokenPromises() { gs.logger.Info("peer didn't follow up in IWANT requests; adding penalty", "peer", p, "requestCount", count) @@ -1956,7 +1971,7 @@ func (gs *GossipSubRouter) sendGraftPrune(tograft, toprune map[peer.ID][]string, // emitGossip emits IHAVE gossip advertising items in the message cache window // of this topic. func (gs *GossipSubRouter) emitGossip(topic string, exclude map[peer.ID]struct{}) { - mids := gs.mcache.GetGossipIDs(topic) + mids := gs.mcache.GossipForTopic(topic) if len(mids) == 0 { return } @@ -2033,6 +2048,44 @@ func (gs *GossipSubRouter) enqueueGossip(p peer.ID, ihave *pb.ControlIHave) { gs.gossip[p] = gossip } +func (gs *GossipSubRouter) announceMessage(topic string, msg *Message, expiry time.Time) { + // Get all peers in topic + tmap, ok := gs.p.topics[topic] + if !ok { + return + } + + // Store message for IWANT retrieval in the message cache + msgID := gs.p.idGen.ID(msg) + + // Send IHAVE to all topic peers (excluding direct peers, applying score threshold) + // Match the filtering logic from emitGossip + var gossipQueued bool + for p := range tmap { + if !gs.feature(GossipSubFeatureMesh, gs.peers[p]) { + continue + } + if gs.score.Score(p) < gs.gossipThreshold { + continue + } + gs.enqueueGossip(p, &pb.ControlIHave{ + TopicID: &topic, + MessageIDs: []string{msgID}, + }) + gossipQueued = true + } + + if !gossipQueued { + return + } + + // Track announcement in message cache for IWANT retrieval + gs.mcache.TrackAnn(msg, expiry) + + // Flush gossip immediately + gs.flush() +} + func (gs *GossipSubRouter) piggybackGossip(p peer.ID, out *RPC, ihave []*pb.ControlIHave) { ctl := out.GetControl() if ctl == nil { diff --git a/mcache.go b/mcache.go index e4e82d90..12d9f3ba 100644 --- a/mcache.go +++ b/mcache.go @@ -2,69 +2,108 @@ package pubsub import ( "fmt" + "time" "github.com/libp2p/go-libp2p/core/peer" ) +type historyEntry struct { + mid string + topic string +} + +type messageRef struct { + *Message + refs int +} + +type MessageCache struct { + msgID func(*Message) string + + // All messages unified storage, indexed by message ID + // Messages can be in window, announcement wheel, or both + msgs map[string]*messageRef + + // Sliding window for all messages + history [][]historyEntry + gossipLen int + + // Time wheel for announcements with expiry-based cleanup + // Behaves like a circular buffer of time buckets containing message IDs + // Actual messages are stored in the unified storage `msgs` + annWheel [][]string + annWheelPos int + annWheelTick time.Duration + + // Per-peer transmission counters + peertx map[string]map[peer.ID]int +} + // NewMessageCache creates a sliding window cache that remembers messages for as -// long as `history` slots. +// long as `historyLen` slots. // -// When queried for messages to advertise, the cache only returns messages in -// the last `gossip` slots. +// When queried for messages to advertise via gossip, the cache only returns messages +// in the last `gossipLen` slots. // -// The `gossip` parameter must be smaller or equal to `history`, or this +// The `gossipLen` parameter must be smaller or equal to `historyLen`, or this // function will panic. // -// The slack between `gossip` and `history` accounts for the reaction time +// The slack between `gossipLen` and `historyLen` accounts for the reaction time // between when a message is advertised via IHAVE gossip, and the peer pulls it // via an IWANT command. -func NewMessageCache(gossip, history int) *MessageCache { - if gossip > history { +func NewMessageCache(gossipLen, historyLen int, heartbeatInterval, maxTTL time.Duration) *MessageCache { + if gossipLen > historyLen { err := fmt.Errorf("invalid parameters for message cache; gossip slots (%d) cannot be larger than history slots (%d)", - gossip, history) + gossipLen, historyLen) panic(err) } + + wheelLen := ceilDivDuration(maxTTL, heartbeatInterval) + wheel := make([][]string, wheelLen) + return &MessageCache{ - msgs: make(map[string]*Message), - peertx: make(map[string]map[peer.ID]int), - history: make([][]CacheEntry, history), - gossip: gossip, + msgs: make(map[string]*messageRef), + peertx: make(map[string]map[peer.ID]int), + history: make([][]historyEntry, historyLen), + gossipLen: gossipLen, + annWheel: wheel, + annWheelPos: 0, + annWheelTick: heartbeatInterval, msgID: func(msg *Message) string { return DefaultMsgIdFn(msg.Message) }, } } -type MessageCache struct { - msgs map[string]*Message - peertx map[string]map[peer.ID]int - history [][]CacheEntry - gossip int - msgID func(*Message) string -} - func (mc *MessageCache) SetMsgIdFn(msgID func(*Message) string) { mc.msgID = msgID } -type CacheEntry struct { - mid string - topic string -} - -func (mc *MessageCache) Put(msg *Message) { - mid := mc.msgID(msg) - mc.msgs[mid] = msg - mc.history[0] = append(mc.history[0], CacheEntry{mid: mid, topic: msg.GetTopic()}) +// AppendWindow adds a message to the sliding window cache. +// The message will be retained for the duration of the window. +// If the message already exists in the cache, its reference count is incremented. +func (mc *MessageCache) AppendWindow(msg *Message) { + mid := mc.upsertMessage(msg) + mc.history[0] = append(mc.history[0], historyEntry{mid: mid, topic: msg.GetTopic()}) } +// Get retrieves the message for the given message ID without modifying +// any transmission counts. +// It returns the message and a boolean indicating whether the message was found in the cache. func (mc *MessageCache) Get(mid string) (*Message, bool) { - m, ok := mc.msgs[mid] - return m, ok + ref, ok := mc.msgs[mid] + if !ok { + return nil, false + } + return ref.Message, true } +// GetForPeer retrieves the message for the given message ID and increments +// the transmission count for the specified peer. +// It returns the message, the updated transmission count, and a boolean indicating +// whether the message was found in the cache. func (mc *MessageCache) GetForPeer(mid string, p peer.ID) (*Message, int, bool) { - m, ok := mc.msgs[mid] + ref, ok := mc.msgs[mid] if !ok { return nil, 0, false } @@ -76,12 +115,13 @@ func (mc *MessageCache) GetForPeer(mid string, p peer.ID) (*Message, int, bool) } tx[p]++ - return m, tx[p], true + return ref.Message, tx[p], true } -func (mc *MessageCache) GetGossipIDs(topic string) []string { +// GossipForTopic returns the message IDs in the gossip window for the given topic. +func (mc *MessageCache) GossipForTopic(topic string) []string { var mids []string - for _, entries := range mc.history[:mc.gossip] { + for _, entries := range mc.history[:mc.gossipLen] { for _, entry := range entries { if entry.topic == topic { mids = append(mids, entry.mid) @@ -91,10 +131,13 @@ func (mc *MessageCache) GetGossipIDs(topic string) []string { return mids } -func (mc *MessageCache) Shift() { +// ShiftWindow advances the sliding window by one slot. +// Messages that fall out of the window have their reference counts decremented +// and are removed from the cache if they are no longer referenced. +func (mc *MessageCache) ShiftWindow() { last := mc.history[len(mc.history)-1] for _, entry := range last { - delete(mc.msgs, entry.mid) + mc.tryDropMessage(entry.mid) delete(mc.peertx, entry.mid) } for i := len(mc.history) - 2; i >= 0; i-- { @@ -102,3 +145,72 @@ func (mc *MessageCache) Shift() { } mc.history[0] = nil } + +// TrackAnn adds a message to the announcement cache with time-based expiry. +// Unlike AppendWindow, these messages are not part of the sliding window and expire at a specific time. +func (mc *MessageCache) TrackAnn(msg *Message, expiry time.Time) { + ttl := time.Until(expiry) + if ttl <= 0 { + return + } + + mid := mc.upsertMessage(msg) + + // Insert the message into the storage and the wheel + offset := ceilDivDuration(ttl, mc.annWheelTick) + bucket := (mc.annWheelPos + offset) % len(mc.annWheel) + mc.annWheel[bucket] = append(mc.annWheel[bucket], mid) +} + +// PruneAnns removes expired announcements from the cache. +// This should be called periodically (e.g., during heartbeat). +// Advances the time wheel by one tick and cleans up the current bucket. +func (mc *MessageCache) PruneAnns() { + bucket := mc.annWheel[mc.annWheelPos] + + // Drop all messages in the current bucket + for _, mid := range bucket { + mc.tryDropMessage(mid) + delete(mc.peertx, mid) + } + + // Clear the current bucket and advance the wheel position + mc.annWheel[mc.annWheelPos] = mc.annWheel[mc.annWheelPos][:0] + mc.annWheelPos = (mc.annWheelPos + 1) % len(mc.annWheel) +} + +// tryDropMessage decrements the reference count of the message with the given ID. +// If the reference count reaches zero, the message is removed from the cache. +// Returns true if the message was dropped, false otherwise. +func (mc *MessageCache) tryDropMessage(mid string) { + ref, ok := mc.msgs[mid] + if !ok { + return + } + if ref.refs--; ref.refs == 0 { + delete(mc.msgs, mid) + } +} + +func (mc *MessageCache) upsertMessage(msg *Message) string { + mid := mc.msgID(msg) + ref, exists := mc.msgs[mid] + if !exists { + ref = &messageRef{Message: msg} + mc.msgs[mid] = ref + } + ref.refs++ + return mid +} + +// ceilDivDuration performs ceiling division of two time.Duration values. +func ceilDivDuration(a, b time.Duration) int { + switch { + case b <= 0: + panic("b must be > 0") + case a <= 0: + return 0 + default: + return (int(a) + int(b) - 1) / int(b) + } +} diff --git a/mcache_test.go b/mcache_test.go index 93bcfdc6..41ef6e64 100644 --- a/mcache_test.go +++ b/mcache_test.go @@ -4,12 +4,13 @@ import ( "encoding/binary" "fmt" "testing" + "time" pb "github.com/libp2p/go-libp2p-pubsub/pb" ) func TestMessageCache(t *testing.T) { - mcache := NewMessageCache(3, 5) + mcache := NewMessageCache(3, 5, time.Second, 60*time.Second) // 3 gossip, 5 history, 1s heartbeat, 60s max TTL msgID := DefaultMsgIdFn msgs := make([]*pb.Message, 60) @@ -17,11 +18,11 @@ func TestMessageCache(t *testing.T) { msgs[i] = makeTestMessage(i) } - for i := 0; i < 10; i++ { - mcache.Put(&Message{Message: msgs[i]}) + for i := range 10 { + mcache.AppendWindow(&Message{Message: msgs[i]}) } - for i := 0; i < 10; i++ { + for i := range 10 { mid := msgID(msgs[i]) m, ok := mcache.Get(mid) if !ok { @@ -33,21 +34,21 @@ func TestMessageCache(t *testing.T) { } } - gids := mcache.GetGossipIDs("test") + gids := mcache.GossipForTopic("test") if len(gids) != 10 { t.Fatalf("Expected 10 gossip IDs; got %d", len(gids)) } - for i := 0; i < 10; i++ { + for i := range 10 { mid := msgID(msgs[i]) if mid != gids[i] { t.Fatalf("GossipID mismatch for message %d", i) } } - mcache.Shift() + mcache.ShiftWindow() for i := 10; i < 20; i++ { - mcache.Put(&Message{Message: msgs[i]}) + mcache.AppendWindow(&Message{Message: msgs[i]}) } for i := 0; i < 20; i++ { @@ -62,12 +63,12 @@ func TestMessageCache(t *testing.T) { } } - gids = mcache.GetGossipIDs("test") + gids = mcache.GossipForTopic("test") if len(gids) != 20 { t.Fatalf("Expected 20 gossip IDs; got %d", len(gids)) } - for i := 0; i < 10; i++ { + for i := range 10 { mid := msgID(msgs[i]) if mid != gids[10+i] { t.Fatalf("GossipID mismatch for message %d", i) @@ -81,31 +82,31 @@ func TestMessageCache(t *testing.T) { } } - mcache.Shift() + mcache.ShiftWindow() for i := 20; i < 30; i++ { - mcache.Put(&Message{Message: msgs[i]}) + mcache.AppendWindow(&Message{Message: msgs[i]}) } - mcache.Shift() + mcache.ShiftWindow() for i := 30; i < 40; i++ { - mcache.Put(&Message{Message: msgs[i]}) + mcache.AppendWindow(&Message{Message: msgs[i]}) } - mcache.Shift() + mcache.ShiftWindow() for i := 40; i < 50; i++ { - mcache.Put(&Message{Message: msgs[i]}) + mcache.AppendWindow(&Message{Message: msgs[i]}) } - mcache.Shift() + mcache.ShiftWindow() for i := 50; i < 60; i++ { - mcache.Put(&Message{Message: msgs[i]}) + mcache.AppendWindow(&Message{Message: msgs[i]}) } if len(mcache.msgs) != 50 { t.Fatalf("Expected 50 messages in the cache; got %d", len(mcache.msgs)) } - for i := 0; i < 10; i++ { + for i := range 10 { mid := msgID(msgs[i]) _, ok := mcache.Get(mid) if ok { @@ -125,12 +126,12 @@ func TestMessageCache(t *testing.T) { } } - gids = mcache.GetGossipIDs("test") + gids = mcache.GossipForTopic("test") if len(gids) != 30 { t.Fatalf("Expected 30 gossip IDs; got %d", len(gids)) } - for i := 0; i < 10; i++ { + for i := range 10 { mid := msgID(msgs[50+i]) if mid != gids[i] { t.Fatalf("GossipID mismatch for message %d", i) @@ -165,3 +166,97 @@ func makeTestMessage(n int) *pb.Message { Seqno: seqno, } } + +func TestAnnouncementTimeWheel(t *testing.T) { + // Create cache with 60 buckets for announcements (simulating 60 heartbeat intervals) + mcache := NewMessageCache(3, 5, time.Second, 60*time.Second) + msgID := DefaultMsgIdFn + + // Test basic insertion + msg1 := makeTestMessage(1) + expiry1 := time.Now().Add(5 * time.Second) + mcache.TrackAnn(&Message{Message: msg1}, expiry1) + + mid1 := msgID(msg1) + + // Verify message is in cache (announcements are stored in msgs) + if _, ok := mcache.Get(mid1); !ok { + t.Fatal("Message not in announcement cache") + } + + // Verify message can be retrieved + m, _, ok := mcache.GetForPeer(mid1, "peer1") + if !ok { + t.Fatal("Failed to retrieve announced message") + } + if m.Message != msg1 { + t.Fatal("Retrieved message doesn't match") + } + + // Test multiple messages with different expiries + msg2 := makeTestMessage(2) + msg3 := makeTestMessage(3) + expiry2 := time.Now().Add(10 * time.Second) + expiry3 := time.Now().Add(15 * time.Second) + + mcache.TrackAnn(&Message{Message: msg2}, expiry2) + mcache.TrackAnn(&Message{Message: msg3}, expiry3) + + mid2 := msgID(msg2) + mid3 := msgID(msg3) + + // Verify all messages are in cache + if _, ok := mcache.Get(mid1); !ok { + t.Fatal("Message 1 should be in cache") + } + if _, ok := mcache.Get(mid2); !ok { + t.Fatal("Message 2 should be in cache") + } + if _, ok := mcache.Get(mid3); !ok { + t.Fatal("Message 3 should be in cache") + } + + // Test wheel advancement (cleanup) + // Advance 6 ticks (6 seconds) - msg1 should be cleaned up + for i := 0; i < 6; i++ { + mcache.PruneAnns() + } + + // msg1 should be gone + if _, ok := mcache.Get(mid1); ok { + t.Fatal("Message 1 should have been cleaned up") + } + + // msg2 and msg3 should still exist + if _, ok := mcache.Get(mid2); !ok { + t.Fatal("Message 2 should still exist") + } + if _, ok := mcache.Get(mid3); !ok { + t.Fatal("Message 3 should still exist") + } + + // Test expired message insertion (shouldn't be added) + msg4 := makeTestMessage(4) + expiry4 := time.Now().Add(-1 * time.Second) // Already expired + mcache.TrackAnn(&Message{Message: msg4}, expiry4) + + mid4 := msgID(msg4) + if _, ok := mcache.Get(mid4); ok { + t.Fatal("Expired message should not have been added") + } + + // Test wraparound (TTL > wheel size) + msg5 := makeTestMessage(5) + expiry5 := time.Now().Add(70 * time.Second) // Exceeds 60s max + mcache.TrackAnn(&Message{Message: msg5}, expiry5) + + mid5 := msgID(msg5) + if _, ok := mcache.Get(mid5); !ok { + t.Fatal("Long TTL message should still be added (clamped to last bucket)") + } + + // Verify we still have at least 3 messages in cache (msg2, msg3, msg5) + if len(mcache.msgs) < 3 { + t.Fatalf("Expected at least 3 messages in cache, got %d", len(mcache.msgs)) + } +} diff --git a/topic.go b/topic.go index dd094eae..ee4ca117 100644 --- a/topic.go +++ b/topic.go @@ -261,6 +261,60 @@ func (t *Topic) AddToBatch(ctx context.Context, batch *MessageBatch, data []byte return nil } +// Announce sends IHAVE gossip for a message to all peers subscribed to the topic +// without publishing it through the mesh. The message is stored for IWANT retrieval +// until the expiry time. Works even if we're not subscribed to the topic - in that +// case, IHAVE is sent to all connected peers who are subscribed. If we are subscribed, +// the message is marked as seen to prevent duplicate processing. +func (t *Topic) Announce(ctx context.Context, data []byte, expiry time.Time, opts ...PubOpt) error { + t.mux.RLock() + defer t.mux.RUnlock() + + if t.closed { + return ErrTopicClosed + } + + // Validate and construct message (reuse existing validation logic) + msg, err := t.validate(ctx, data, opts...) + if err != nil { + if errors.Is(err, dupeErr{}) { + // If it was a duplicate, we return nil to indicate success. + // Semantically the message was published by us or someone else. + return nil + } + return err + } + + // Get GossipSubRouter + gs, ok := t.p.rt.(*GossipSubRouter) + if !ok { + return fmt.Errorf("announce only works with GossipSub router") + } + + // Execute in pubsub event loop + done := make(chan struct{}) + select { + case t.p.eval <- func() { + gs.announceMessage(t.topic, msg, expiry) + close(done) + }: + case <-t.p.ctx.Done(): + return t.p.ctx.Err() + case <-ctx.Done(): + return ctx.Err() + } + + // Wait for completion + select { + case <-done: + return nil + case <-t.p.ctx.Done(): + return t.p.ctx.Err() + case <-ctx.Done(): + return ctx.Err() + } +} + func (t *Topic) validate(ctx context.Context, data []byte, opts ...PubOpt) (*Message, error) { t.mux.RLock() defer t.mux.RUnlock()