From 815e6fdba45e90cfd89d301844fc21abf0b710c9 Mon Sep 17 00:00:00 2001 From: Hakan Sariman Date: Thu, 2 Oct 2025 19:00:20 +0700 Subject: [PATCH 1/7] [client] Implement DNS query caching in DNSForwarder --- client/internal/dnsfwd/cache.go | 69 +++++++++++++++++++++++++++++ client/internal/dnsfwd/forwarder.go | 65 ++++++++++++++++++++------- 2 files changed, 119 insertions(+), 15 deletions(-) create mode 100644 client/internal/dnsfwd/cache.go diff --git a/client/internal/dnsfwd/cache.go b/client/internal/dnsfwd/cache.go new file mode 100644 index 00000000000..87bb5a48f7f --- /dev/null +++ b/client/internal/dnsfwd/cache.go @@ -0,0 +1,69 @@ +package dnsfwd + +import ( + "net/netip" + "sync" + + "github.com/miekg/dns" +) + +type cache struct { + mu sync.RWMutex + records map[string]*cacheEntry +} + +type cacheEntry struct { + ip4Addrs []netip.Addr + ip6Addrs []netip.Addr +} + +func newCache() *cache { + return &cache{ + records: make(map[string]*cacheEntry), + } +} + +func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + entry, exists := c.records[domain] + if !exists { + return nil, false + } + + switch reqType { + case dns.TypeA: + return cloneAddrs(entry.ip4Addrs), entry.ip4Addrs != nil + case dns.TypeAAAA: + return cloneAddrs(entry.ip6Addrs), entry.ip6Addrs != nil + default: + return nil, false + } + +} + +func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) { + c.mu.Lock() + defer c.mu.Unlock() + entry, exists := c.records[domain] + if !exists { + entry = &cacheEntry{} + c.records[domain] = entry + } + + switch reqType { + case dns.TypeA: + entry.ip4Addrs = cloneAddrs(addrs) + case dns.TypeAAAA: + entry.ip6Addrs = cloneAddrs(addrs) + } +} + +func cloneAddrs(in []netip.Addr) []netip.Addr { + if in == nil { + return nil + } + out := make([]netip.Addr, len(in)) + copy(out, in) + return out +} diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index d912919a1f6..4a52c79baec 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -46,6 +46,7 @@ type DNSForwarder struct { fwdEntries []*ForwarderEntry firewall firewaller resolver resolver + cache *cache } func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { @@ -56,6 +57,7 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat firewall: firewall, statusRecorder: statusRecorder, resolver: net.DefaultResolver, + cache: newCache(), } } @@ -171,6 +173,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.addIPsToResponse(resp, domain, ips) + f.cache.set(domain, question.Qtype, ips) return resp } @@ -282,29 +285,61 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns resp.Rcode = dns.RcodeSuccess } -// handleDNSError processes DNS lookup errors and sends an appropriate error response -func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) { +// handleDNSError processes DNS lookup errors and sends an appropriate error response. +func (f *DNSForwarder) handleDNSError( + ctx context.Context, + w dns.ResponseWriter, + question dns.Question, + resp *dns.Msg, + domain string, + err error, +) { + // Default to SERVFAIL; override below when appropriate. + resp.Rcode = dns.RcodeServerFailure + + qType := question.Qtype + qTypeName := dns.TypeToString[qType] + + // Prefer typed DNS errors; fall back to generic logging otherwise. var dnsErr *net.DNSError + if !errors.As(err, &dnsErr) { + log.Warnf(errResolveFailed, domain, err) + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) + } + return + } - switch { - case errors.As(err, &dnsErr): - resp.Rcode = dns.RcodeServerFailure - if dnsErr.IsNotFound { - f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype) + // NotFound: set NXDOMAIN / appropriate code via helper. + if dnsErr.IsNotFound { + f.setResponseCodeForNotFound(ctx, resp, domain, qType) + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) } + return + } - if dnsErr.Server != "" { - log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err) - } else { - log.Warnf(errResolveFailed, domain, err) + // Upstream failed but we might have a cached answer—serve it if present. + if ips, ok := f.cache.get(domain, qType); ok && len(ips) > 0 { + log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) + f.addIPsToResponse(resp, domain, ips) + resp.Rcode = dns.RcodeSuccess + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write cached DNS response: %v", writeErr) } - default: - resp.Rcode = dns.RcodeServerFailure + return + } + + // No cache. Log with or without the server field for more context. + if dnsErr.Server != "" { + log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err) + } else { log.Warnf(errResolveFailed, domain, err) } - if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write failure DNS response: %v", err) + // Write final failure response. + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) } } From dd1ee59e5ba6fb7730969b7c2dbbc62d94817638 Mon Sep 17 00:00:00 2001 From: Hakan Sariman Date: Thu, 2 Oct 2025 19:03:09 +0700 Subject: [PATCH 2/7] Fix cache get method to check for non-empty address slices --- client/internal/dnsfwd/cache.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/internal/dnsfwd/cache.go b/client/internal/dnsfwd/cache.go index 87bb5a48f7f..3539f7a51b8 100644 --- a/client/internal/dnsfwd/cache.go +++ b/client/internal/dnsfwd/cache.go @@ -33,9 +33,9 @@ func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) { switch reqType { case dns.TypeA: - return cloneAddrs(entry.ip4Addrs), entry.ip4Addrs != nil + return cloneAddrs(entry.ip4Addrs), len(entry.ip4Addrs) > 0 case dns.TypeAAAA: - return cloneAddrs(entry.ip6Addrs), entry.ip6Addrs != nil + return cloneAddrs(entry.ip6Addrs), len(entry.ip6Addrs) > 0 default: return nil, false } From a928540b01a4693caa5fff190401acc27b1f78b2 Mon Sep 17 00:00:00 2001 From: Hakan Sariman Date: Thu, 2 Oct 2025 20:59:30 +0700 Subject: [PATCH 3/7] Normalize domain names in cache to ensure consistent casing and trailing dot usage --- client/internal/dnsfwd/cache.go | 16 ++++- client/internal/dnsfwd/cache_test.go | 86 +++++++++++++++++++++++ client/internal/dnsfwd/forwarder_test.go | 89 ++++++++++++++++++++++++ 3 files changed, 188 insertions(+), 3 deletions(-) create mode 100644 client/internal/dnsfwd/cache_test.go diff --git a/client/internal/dnsfwd/cache.go b/client/internal/dnsfwd/cache.go index 3539f7a51b8..8497b11975e 100644 --- a/client/internal/dnsfwd/cache.go +++ b/client/internal/dnsfwd/cache.go @@ -2,6 +2,7 @@ package dnsfwd import ( "net/netip" + "strings" "sync" "github.com/miekg/dns" @@ -26,7 +27,8 @@ func newCache() *cache { func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) { c.mu.RLock() defer c.mu.RUnlock() - entry, exists := c.records[domain] + + entry, exists := c.records[normalizeDomain(domain)] if !exists { return nil, false } @@ -45,10 +47,11 @@ func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) { func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) { c.mu.Lock() defer c.mu.Unlock() - entry, exists := c.records[domain] + norm := normalizeDomain(domain) + entry, exists := c.records[norm] if !exists { entry = &cacheEntry{} - c.records[domain] = entry + c.records[norm] = entry } switch reqType { @@ -59,6 +62,13 @@ func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) { } } +// normalizeDomain converts an input domain into a canonical form used as cache key: +// lowercase and fully-qualified (with trailing dot). +func normalizeDomain(domain string) string { + // dns.Fqdn ensures trailing dot; ToLower for consistent casing + return dns.Fqdn(strings.ToLower(domain)) +} + func cloneAddrs(in []netip.Addr) []netip.Addr { if in == nil { return nil diff --git a/client/internal/dnsfwd/cache_test.go b/client/internal/dnsfwd/cache_test.go new file mode 100644 index 00000000000..c23f0f31d6d --- /dev/null +++ b/client/internal/dnsfwd/cache_test.go @@ -0,0 +1,86 @@ +package dnsfwd + +import ( + "net/netip" + "testing" +) + +func mustAddr(t *testing.T, s string) netip.Addr { + t.Helper() + a, err := netip.ParseAddr(s) + if err != nil { + t.Fatalf("parse addr %s: %v", s, err) + } + return a +} + +func TestCacheNormalization(t *testing.T) { + c := newCache() + + // Mixed case, without trailing dot + domainInput := "ExAmPlE.CoM" + ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")} + c.set(domainInput, 1 /* dns.TypeA */, ipv4) + + // Lookup with lower, with trailing dot + if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" { + t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok) + } + + // Lookup with different casing again + if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" { + t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok) + } +} + +func TestCacheSeparateTypes(t *testing.T) { + c := newCache() + + domain := "test.local" + ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")} + ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")} + + c.set(domain, 1 /* A */, ipv4) + c.set(domain, 28 /* AAAA */, ipv6) + + got4, ok4 := c.get(domain, 1) + if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] { + t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4) + } + + got6, ok6 := c.get(domain, 28) + if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] { + t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6) + } +} + +func TestCacheCloneOnGetAndSet(t *testing.T) { + c := newCache() + domain := "clone.test" + + src := []netip.Addr{mustAddr(t, "8.8.8.8")} + c.set(domain, 1, src) + + // Mutate source slice; cache should be unaffected + src[0] = mustAddr(t, "9.9.9.9") + + got, ok := c.get(domain, 1) + if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" { + t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok) + } + + // Mutate returned slice; internal cache should remain unchanged + got[0] = mustAddr(t, "4.4.4.4") + got2, ok2 := c.get(domain, 1) + if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" { + t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2) + } +} + +func TestCacheMiss(t *testing.T) { + c := newCache() + if got, ok := c.get("missing.example", 1); ok || got != nil { + t.Fatalf("expected cache miss, got=%v ok=%v", got, ok) + } +} + diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index 57085e19a13..c1c95a2c1c6 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -648,6 +648,95 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) { assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size") } +// Ensures that when the first query succeeds and populates the cache, +// a subsequent upstream failure still returns a successful response from cache. +func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { + mockResolver := &MockResolver{} + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder.resolver = mockResolver + + d, err := domain.FromString("example.com") + require.NoError(t, err) + entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}} + forwarder.UpdateDomains(entries) + + ip := netip.MustParseAddr("1.2.3.4") + + // First call resolves successfully and populates cache + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")). + Return([]netip.Addr{ip}, nil).Once() + + // Second call fails upstream; forwarder should serve from cache + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")). + Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once() + + // First query: populate cache + q1 := &dns.Msg{} + q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) + w1 := &test.MockResponseWriter{} + resp1 := forwarder.handleDNSQuery(w1, q1) + require.NotNil(t, resp1) + require.Equal(t, dns.RcodeSuccess, resp1.Rcode) + require.Len(t, resp1.Answer, 1) + + // Second query: serve from cache after upstream failure + q2 := &dns.Msg{} + q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) + var writtenResp *dns.Msg + w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} + _ = forwarder.handleDNSQuery(w2, q2) + + require.NotNil(t, writtenResp, "expected response to be written") + require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) + require.Len(t, writtenResp.Answer, 1) + + mockResolver.AssertExpectations(t) +} + +// Verifies that cache normalization works across casing and trailing dot variations. +func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { + mockResolver := &MockResolver{} + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder.resolver = mockResolver + + d, err := domain.FromString("ExAmPlE.CoM") + require.NoError(t, err) + entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}} + forwarder.UpdateDomains(entries) + + ip := netip.MustParseAddr("9.8.7.6") + + // Initial resolution with mixed case to populate cache + mixedQuery := "ExAmPlE.CoM" + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))). + Return([]netip.Addr{ip}, nil).Once() + + q1 := &dns.Msg{} + q1.SetQuestion(mixedQuery+".", dns.TypeA) + w1 := &test.MockResponseWriter{} + resp1 := forwarder.handleDNSQuery(w1, q1) + require.NotNil(t, resp1) + require.Equal(t, dns.RcodeSuccess, resp1.Rcode) + require.Len(t, resp1.Answer, 1) + + // Subsequent query without dot and upper case should hit cache even if upstream fails + // Forwarder lowercases and uses the question name as-is (no trailing dot here) + mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")). + Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once() + + q2 := &dns.Msg{} + q2.SetQuestion("EXAMPLE.COM", dns.TypeA) + var writtenResp *dns.Msg + w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} + _ = forwarder.handleDNSQuery(w2, q2) + + require.NotNil(t, writtenResp) + require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) + require.Len(t, writtenResp.Answer, 1) + + mockResolver.AssertExpectations(t) +} + func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { // Test complex overlapping pattern scenarios mockFirewall := &MockFirewall{} From 3a40f6991b810fb7043266aa9d5da69cd91865fe Mon Sep 17 00:00:00 2001 From: Hakan Sariman Date: Wed, 8 Oct 2025 13:52:09 +0300 Subject: [PATCH 4/7] Refactor cache methods to use slices.Clone for address slices --- client/internal/dnsfwd/cache.go | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/client/internal/dnsfwd/cache.go b/client/internal/dnsfwd/cache.go index 8497b11975e..c8ee4e30a51 100644 --- a/client/internal/dnsfwd/cache.go +++ b/client/internal/dnsfwd/cache.go @@ -2,6 +2,7 @@ package dnsfwd import ( "net/netip" + "slices" "strings" "sync" @@ -35,9 +36,9 @@ func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) { switch reqType { case dns.TypeA: - return cloneAddrs(entry.ip4Addrs), len(entry.ip4Addrs) > 0 + return slices.Clone(entry.ip4Addrs), len(entry.ip4Addrs) > 0 case dns.TypeAAAA: - return cloneAddrs(entry.ip6Addrs), len(entry.ip6Addrs) > 0 + return slices.Clone(entry.ip6Addrs), len(entry.ip6Addrs) > 0 default: return nil, false } @@ -56,9 +57,9 @@ func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) { switch reqType { case dns.TypeA: - entry.ip4Addrs = cloneAddrs(addrs) + entry.ip4Addrs = slices.Clone(addrs) case dns.TypeAAAA: - entry.ip6Addrs = cloneAddrs(addrs) + entry.ip6Addrs = slices.Clone(addrs) } } @@ -68,12 +69,3 @@ func normalizeDomain(domain string) string { // dns.Fqdn ensures trailing dot; ToLower for consistent casing return dns.Fqdn(strings.ToLower(domain)) } - -func cloneAddrs(in []netip.Addr) []netip.Addr { - if in == nil { - return nil - } - out := make([]netip.Addr, len(in)) - copy(out, in) - return out -} From bdda31d7775dfe81315d94ea60b23ed841f0a9c4 Mon Sep 17 00:00:00 2001 From: Hakan Sariman Date: Wed, 8 Oct 2025 13:55:41 +0300 Subject: [PATCH 5/7] cache nxdomain result --- client/internal/dnsfwd/cache.go | 4 ++-- client/internal/dnsfwd/forwarder.go | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/client/internal/dnsfwd/cache.go b/client/internal/dnsfwd/cache.go index c8ee4e30a51..6bdec6e5f13 100644 --- a/client/internal/dnsfwd/cache.go +++ b/client/internal/dnsfwd/cache.go @@ -36,9 +36,9 @@ func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) { switch reqType { case dns.TypeA: - return slices.Clone(entry.ip4Addrs), len(entry.ip4Addrs) > 0 + return slices.Clone(entry.ip4Addrs), true case dns.TypeAAAA: - return slices.Clone(entry.ip6Addrs), len(entry.ip6Addrs) > 0 + return slices.Clone(entry.ip6Addrs), true default: return nil, false } diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 4a52c79baec..dc7d7507e0b 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -316,6 +316,7 @@ func (f *DNSForwarder) handleDNSError( if writeErr := w.WriteMsg(resp); writeErr != nil { log.Errorf("failed to write failure DNS response: %v", writeErr) } + f.cache.set(domain, question.Qtype, nil) return } From 2eefe14be06fe3ca7212eac8909328006b9be5ed Mon Sep 17 00:00:00 2001 From: Hakan Sariman Date: Wed, 8 Oct 2025 13:57:44 +0300 Subject: [PATCH 6/7] Enhance DNSForwarder to handle NXDOMAIN responses when cache is empty --- client/internal/dnsfwd/forwarder.go | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index dc7d7507e0b..03647e699be 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -321,12 +321,19 @@ func (f *DNSForwarder) handleDNSError( } // Upstream failed but we might have a cached answer—serve it if present. - if ips, ok := f.cache.get(domain, qType); ok && len(ips) > 0 { - log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) - f.addIPsToResponse(resp, domain, ips) - resp.Rcode = dns.RcodeSuccess - if writeErr := w.WriteMsg(resp); writeErr != nil { - log.Errorf("failed to write cached DNS response: %v", writeErr) + if ips, ok := f.cache.get(domain, qType); ok { + if len(ips) > 0 { + log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) + f.addIPsToResponse(resp, domain, ips) + resp.Rcode = dns.RcodeSuccess + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write cached DNS response: %v", writeErr) + } + } else { // send NXDOMAIN / appropriate code if cache is empty + f.setResponseCodeForNotFound(ctx, resp, domain, qType) + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) + } } return } From 7d69d65b3610da38490682230ab31c8bf7e6e102 Mon Sep 17 00:00:00 2001 From: Hakan Sariman Date: Wed, 8 Oct 2025 14:03:46 +0300 Subject: [PATCH 7/7] Add cache unset method and remove stale cache entries in DNSForwarder --- client/internal/dnsfwd/cache.go | 7 +++++++ client/internal/dnsfwd/forwarder.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/client/internal/dnsfwd/cache.go b/client/internal/dnsfwd/cache.go index 6bdec6e5f13..43fe2d0203e 100644 --- a/client/internal/dnsfwd/cache.go +++ b/client/internal/dnsfwd/cache.go @@ -63,6 +63,13 @@ func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) { } } +// unset removes cached entries for the given domain and request type. +func (c *cache) unset(domain string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.records, normalizeDomain(domain)) +} + // normalizeDomain converts an input domain into a canonical form used as cache key: // lowercase and fully-qualified (with trailing dot). func normalizeDomain(domain string) string { diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 03647e699be..7a262fa4c9e 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -105,10 +105,39 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { f.mutex.Lock() defer f.mutex.Unlock() + // remove cache entries for domains that no longer appear + f.removeStaleCacheEntries(f.fwdEntries, entries) + f.fwdEntries = entries log.Debugf("Updated DNS forwarder with %d domains", len(entries)) } +// removeStaleCacheEntries unsets cache items for domains that were present +// in the old list but not present in the new list. +func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) { + if f.cache == nil { + return + } + + newSet := make(map[string]struct{}, len(newEntries)) + for _, e := range newEntries { + if e == nil { + continue + } + newSet[e.Domain.PunycodeString()] = struct{}{} + } + + for _, e := range oldEntries { + if e == nil { + continue + } + pattern := e.Domain.PunycodeString() + if _, ok := newSet[pattern]; !ok { + f.cache.unset(pattern) + } + } +} + func (f *DNSForwarder) Close(ctx context.Context) error { var result *multierror.Error