diff --git a/trie.go b/trie.go index f61fcef..a46f74d 100644 --- a/trie.go +++ b/trie.go @@ -83,6 +83,17 @@ func (pt *Trie) CoveredNetworks(network netip.Prefix) []netip.Prefix { return pt.coveredNetworks(network) } +// WalkFunc is the type of the function called for each network visited by Walk methods. +type WalkFunc func(network netip.Prefix, value any) error + +// CoveredNetworksWalk walks networks contained within the given network, calling walkFn. +// +// Note: Inserted addresses are normalized to IPv6, so the returned list will be IPv6 only. +func (pt *Trie) CoveredNetworksWalk(network netip.Prefix, walkFn WalkFunc) error { + network = normalizePrefix(network) + return pt.coveredNetworksWalk(network, walkFn) +} + // String returns string representation of trie. // // The result will contain implicit nodes which exist as parents for multiple entries, but can be distinguished by the @@ -197,6 +208,19 @@ func (pt *Trie) coveredNetworks(network netip.Prefix) []netip.Prefix { return results } +func (pt *Trie) coveredNetworksWalk(network netip.Prefix, walkFn WalkFunc) error { + if network.Bits() <= pt.network.Bits() && network.Contains(pt.network.Addr()) { + return pt.walkDepthFunc(walkFn) + } else if pt.network.Bits() < 128 { + bit := pt.discriminatorBitFromIP(network.Addr()) + child := pt.children[bit] + if child != nil { + return child.coveredNetworksWalk(network, walkFn) + } + } + return nil +} + // This is an unsafe, but faster version of netip.Prefix.Contains func netContains(pfx netip.Prefix, ip netip.Addr) bool { pfxAddr := addr128(pfx.Addr()) @@ -378,6 +402,24 @@ func (pt *Trie) walkDepth() <-chan netip.Prefix { return entries } +// walkDepthFunc walks the trie in depth order, calling walkFn for each network. +func (pt *Trie) walkDepthFunc(walkFn WalkFunc) error { + if pt.value != nil { + if err := walkFn(pt.network, pt.value); err != nil { + return err + } + } + for _, trie := range pt.children { + if trie == nil { + continue + } + if err := trie.walkDepthFunc(walkFn); err != nil { + return err + } + } + return nil +} + // TrieLoader can be used to improve the performance of bulk inserts to a Trie. It caches the node of the // last insert in the tree, using it as the starting point to start searching for the location of the next insert. This // is highly beneficial when the addresses are pre-sorted. @@ -452,6 +494,7 @@ func unempty(v any) any { func addr128(addr netip.Addr) uint128 { return *(*uint128)(unsafe.Pointer(&addr)) } + func init() { // Accessing the underlying data of a `netip.Addr` relies upon the data being // in a known format, which is not guaranteed to be stable. So this init() diff --git a/trie_test.go b/trie_test.go index e7b0d87..c9d8fb6 100644 --- a/trie_test.go +++ b/trie_test.go @@ -2,6 +2,7 @@ package iptrie import ( "encoding/binary" + "errors" "fmt" "math/rand" "net/netip" @@ -405,6 +406,9 @@ type coveredNetworkTest struct { inserts []string search string networks []string + walk []string + stopWalk string + error bool name string } @@ -413,36 +417,54 @@ var coveredNetworkTests = []coveredNetworkTest{ []string{"192.168.0.0/24"}, "192.168.0.0/16", []string{"192.168.0.0/24"}, + []string{"192.168.0.0/24"}, + "", + false, "basic covered networks", }, { []string{"192.168.0.0/24"}, "10.1.0.0/16", nil, + nil, + "", + false, "nothing", }, { []string{"192.168.0.0/24", "192.168.0.0/25"}, "192.168.0.0/16", []string{"192.168.0.0/24", "192.168.0.0/25"}, + []string{"192.168.0.0/24", "192.168.0.0/25"}, + "192.168.1.0/25", + false, "multiple networks", }, { []string{"192.168.0.0/24", "192.168.0.0/25", "192.168.0.1/32"}, "192.168.0.0/16", []string{"192.168.0.0/24", "192.168.0.0/25", "192.168.0.1/32"}, + []string{"192.168.0.0/24", "192.168.0.0/25"}, + "192.168.0.0/25", + true, "multiple networks 2", }, { []string{"192.168.1.1/32"}, "192.168.0.0/16", []string{"192.168.1.1/32"}, + []string{"192.168.1.1/32"}, + "", + false, "leaf", }, { []string{"0.0.0.0/0", "192.168.1.1/32"}, "192.168.0.0/16", []string{"192.168.1.1/32"}, + []string{"192.168.1.1/32"}, + "", + false, "leaf with root", }, { @@ -452,14 +474,32 @@ var coveredNetworkTests = []coveredNetworkTest{ }, "192.168.0.0/16", []string{"192.168.0.0/24", "192.168.1.1/32"}, + []string{"192.168.0.0/24", "192.168.1.1/32"}, + "10.1.0.0/16", + false, "path not taken", }, + { + []string{ + "0.0.0.0/0", "192.168.0.0/24", "192.168.1.1/32", + "10.1.0.0/16", "10.1.1.0/24", "192.168.2.2/32", + }, + "192.168.0.0/16", + []string{"192.168.0.0/24", "192.168.1.1/32", "192.168.2.2/32"}, + []string{"192.168.0.0/24", "192.168.1.1/32"}, + "192.168.1.1/32", + true, + "path not taken and stopped", + }, { []string{ "192.168.0.0/15", }, "192.168.0.0/16", nil, + nil, + "", + false, "only masks different", }, } @@ -485,6 +525,40 @@ func TestTrieCoveredNetworks(t *testing.T) { } } +func TestTrieCoveredNetworksWalk(t *testing.T) { + for _, tc := range coveredNetworkTests { + t.Run(tc.name, func(t *testing.T) { + trie := NewTrie() + for _, insert := range tc.inserts { + network := netip.MustParsePrefix(insert) + v := any(insert) + trie.Insert(network, v) + } + var expectedEntries []netip.Prefix + for _, network := range tc.walk { + expected := normalizePrefix(netip.MustParsePrefix(network)) + expectedEntries = append(expectedEntries, expected) + } + snet := netip.MustParsePrefix(tc.search) + var networks []netip.Prefix + walkFn := func(network netip.Prefix, v any) error { + networks = append(networks, network) + if stopWalk := v.(string); stopWalk == tc.stopWalk { + return errors.New(stopWalk) + } + return nil + } + err := trie.CoveredNetworksWalk(snet, walkFn) + assert.Equal(t, expectedEntries, networks) + if tc.error { + assert.EqualError(t, err, tc.stopWalk) + } else { + assert.Nil(t, err) + } + }) + } +} + func TestTrieMemUsage(t *testing.T) { if testing.Short() { t.Skip("Skipping memory test in `-short` mode")