diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index fbc39b74041..ec2d2c57f12 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -50,6 +50,12 @@ const ( var errNatNotSupported = errors.New("nat not supported with userspace firewall") +// serviceKey represents a protocol/port combination for netstack service registry +type serviceKey struct { + protocol gopacket.LayerType + port uint16 +} + // RuleSet is a set of rules grouped by a string key type RuleSet map[string]PeerRule @@ -113,6 +119,9 @@ type Manager struct { portDNATEnabled atomic.Bool portDNATRules []portDNATRule portDNATMutex sync.RWMutex + + netstackServices map[serviceKey]struct{} + netstackServiceMutex sync.RWMutex } // decoder for packages @@ -203,6 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe localForwarding: enableLocalForwarding, dnatMappings: make(map[netip.Addr]netip.Addr), portDNATRules: []portDNATRule{}, + netstackServices: make(map[serviceKey]struct{}), } m.routingEnabled.Store(false) @@ -838,9 +848,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet return true } - // If requested we pass local traffic to internal interfaces to the forwarder. - // netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder. - if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) { + if m.shouldForward(d, dstIP) { return m.handleForwardedLocalTraffic(packetData) } @@ -1274,3 +1282,86 @@ func (m *Manager) DisableRouting() error { return nil } + +// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port +func (m *Manager) RegisterNetstackService(protocol nftypes.Protocol, port uint16) { + m.netstackServiceMutex.Lock() + defer m.netstackServiceMutex.Unlock() + layerType := m.protocolToLayerType(protocol) + key := serviceKey{protocol: layerType, port: port} + m.netstackServices[key] = struct{}{} + m.logger.Debug3("RegisterNetstackService: registered %s:%d (layerType=%s)", protocol, port, layerType) + m.logger.Debug1("RegisterNetstackService: current registry size: %d", len(m.netstackServices)) +} + +// UnregisterNetstackService removes a service from the netstack registry +func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint16) { + m.netstackServiceMutex.Lock() + defer m.netstackServiceMutex.Unlock() + layerType := m.protocolToLayerType(protocol) + key := serviceKey{protocol: layerType, port: port} + delete(m.netstackServices, key) + m.logger.Debug2("Unregistered netstack service on protocol %s port %d", protocol, port) +} + +// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use +func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType { + switch protocol { + case nftypes.TCP: + return layers.LayerTypeTCP + case nftypes.UDP: + return layers.LayerTypeUDP + case nftypes.ICMP: + return layers.LayerTypeICMPv4 + default: + return gopacket.LayerType(0) // Invalid/unknown + } +} + +// shouldForward determines if a packet should be forwarded to the forwarder. +// The forwarder handles routing packets to the native OS network stack. +// Returns true if packet should go to the forwarder, false if it should go to netstack listeners or the native stack directly. +func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool { + // not enabled, never forward + if !m.localForwarding { + return false + } + + // netstack always needs to forward because it's lacking a native interface + // exception for registered netstack services, those should go to netstack listeners + if m.netstack { + return !m.hasMatchingNetstackService(d) + } + + // traffic to our other local interfaces (not NetBird IP) - always forward + if dstIP != m.wgIface.Address().IP { + return true + } + + // traffic to our NetBird IP, not netstack mode - send to netstack listeners + return false +} + +// hasMatchingNetstackService checks if there's a registered netstack service for this packet +func (m *Manager) hasMatchingNetstackService(d *decoder) bool { + if len(d.decoded) < 2 { + return false + } + + var dstPort uint16 + switch d.decoded[1] { + case layers.LayerTypeTCP: + dstPort = uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + dstPort = uint16(d.udp.DstPort) + default: + return false + } + + key := serviceKey{protocol: d.decoded[1], port: dstPort} + m.netstackServiceMutex.RLock() + _, exists := m.netstackServices[key] + m.netstackServiceMutex.RUnlock() + + return exists +} diff --git a/client/internal/dnsfwd/cache_test.go b/client/internal/dnsfwd/cache_test.go index c23f0f31d6d..44ebe290bbd 100644 --- a/client/internal/dnsfwd/cache_test.go +++ b/client/internal/dnsfwd/cache_test.go @@ -83,4 +83,3 @@ func TestCacheMiss(t *testing.T) { t.Fatalf("expected cache miss, got=%v ok=%v", got, ok) } } - diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 7a262fa4c9e..aef16a8cfea 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun/netstack" nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -33,7 +34,7 @@ type firewaller interface { } type DNSForwarder struct { - listenAddress string + listenAddress netip.AddrPort ttl uint32 statusRecorder *peer.Status @@ -47,9 +48,11 @@ type DNSForwarder struct { firewall firewaller resolver resolver cache *cache + + wgIface wgIface } -func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { +func NewDNSForwarder(listenAddress netip.AddrPort, ttl uint32, firewall firewaller, statusRecorder *peer.Status, wgIface wgIface) *DNSForwarder { log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) return &DNSForwarder{ listenAddress: listenAddress, @@ -58,30 +61,46 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat statusRecorder: statusRecorder, resolver: net.DefaultResolver, cache: newCache(), + wgIface: wgIface, } } func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { - log.Infof("starting DNS forwarder on address=%s", f.listenAddress) + var netstackNet *netstack.Net + if f.wgIface != nil { + netstackNet = f.wgIface.GetNet() + } + + addrDesc := f.listenAddress.String() + if netstackNet != nil { + addrDesc = fmt.Sprintf("netstack %s", f.listenAddress) + } + log.Infof("starting DNS forwarder on address=%s", addrDesc) + + udpLn, err := f.createUDPListener(netstackNet) + if err != nil { + return fmt.Errorf("create UDP listener: %w", err) + } + + tcpLn, err := f.createTCPListener(netstackNet) + if err != nil { + return fmt.Errorf("create TCP listener: %w", err) + } - // UDP server mux := dns.NewServeMux() f.mux = mux mux.HandleFunc(".", f.handleDNSQueryUDP) f.dnsServer = &dns.Server{ - Addr: f.listenAddress, - Net: "udp", - Handler: mux, + PacketConn: udpLn, + Handler: mux, } - // TCP server tcpMux := dns.NewServeMux() f.tcpMux = tcpMux tcpMux.HandleFunc(".", f.handleDNSQueryTCP) f.tcpServer = &dns.Server{ - Addr: f.listenAddress, - Net: "tcp", - Handler: tcpMux, + Listener: tcpLn, + Handler: tcpMux, } f.UpdateDomains(entries) @@ -89,18 +108,33 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { errCh := make(chan error, 2) go func() { - log.Infof("DNS UDP listener running on %s", f.listenAddress) - errCh <- f.dnsServer.ListenAndServe() + log.Infof("DNS UDP listener running on %s", addrDesc) + errCh <- f.dnsServer.ActivateAndServe() }() go func() { - log.Infof("DNS TCP listener running on %s", f.listenAddress) - errCh <- f.tcpServer.ListenAndServe() + log.Infof("DNS TCP listener running on %s", addrDesc) + errCh <- f.tcpServer.ActivateAndServe() }() - // return the first error we get (e.g. bind failure or shutdown) return <-errCh } +func (f *DNSForwarder) createUDPListener(netstackNet *netstack.Net) (net.PacketConn, error) { + if netstackNet != nil { + return netstackNet.ListenUDPAddrPort(f.listenAddress) + } + + return net.ListenUDP("udp", net.UDPAddrFromAddrPort(f.listenAddress)) +} + +func (f *DNSForwarder) createTCPListener(netstackNet *netstack.Net) (net.Listener, error) { + if netstackNet != nil { + return netstackNet.ListenTCPAddrPort(f.listenAddress) + } + + return net.ListenTCP("tcp", net.TCPAddrFromAddrPort(f.listenAddress)) +} + func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { f.mutex.Lock() defer f.mutex.Unlock() diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index c1c95a2c1c6..4d0b96a758f 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -297,7 +297,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil) } - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString(tt.configuredDomain) @@ -402,7 +402,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) { mockResolver := &MockResolver{} // Set up forwarder - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Create entries and track sets @@ -489,7 +489,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Configure a single domain @@ -584,7 +584,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) d, err := domain.FromString(tt.configured) require.NoError(t, err) @@ -616,7 +616,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { func TestDNSForwarder_TCPTruncation(t *testing.T) { // Test that large UDP responses are truncated with TC bit set mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, _ := domain.FromString("example.com") @@ -652,7 +652,7 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) { // a subsequent upstream failure still returns a successful response from cache. func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("example.com") @@ -696,7 +696,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { // Verifies that cache normalization works across casing and trailing dot variations. func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("ExAmPlE.CoM") @@ -742,7 +742,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Set up complex overlapping patterns @@ -804,7 +804,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("example.com") @@ -925,7 +925,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { func TestDNSForwarder_EmptyQuery(t *testing.T) { // Test handling of malformed query with no questions - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) query := &dns.Msg{} // Don't set any question diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index a1c0dff98e8..b26836d17e6 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -10,9 +10,11 @@ import ( "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun/netstack" nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/peer" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" @@ -24,6 +26,12 @@ const ( envServerPort = "NB_DNS_FORWARDER_PORT" ) +// wgIface defines the interface for WireGuard interface operations needed by the DNS forwarder. +type wgIface interface { + GetNet() *netstack.Net + Address() wgaddr.Address +} + // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. type ForwarderEntry struct { Domain domain.Domain @@ -34,7 +42,7 @@ type ForwarderEntry struct { type Manager struct { firewall firewall.Manager statusRecorder *peer.Status - localAddr netip.Addr + wgIface wgIface serverPort uint16 fwRules []firewall.Rule @@ -42,7 +50,7 @@ type Manager struct { dnsForwarder *DNSForwarder } -func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager { +func NewManager(fw firewall.Manager, statusRecorder *peer.Status, wgIface wgIface) *Manager { serverPort := nbdns.ForwarderServerPort if envPort := os.Getenv(envServerPort); envPort != "" { if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 { @@ -56,7 +64,7 @@ func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr neti return &Manager{ firewall: fw, statusRecorder: statusRecorder, - localAddr: localAddr, + wgIface: wgIface, serverPort: serverPort, } } @@ -71,21 +79,25 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - if m.localAddr.IsValid() && m.firewall != nil { - if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + localAddr := m.wgIface.Address().IP + + if localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { log.Warnf("failed to add DNS UDP DNAT rule: %v", err) } else { - log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort) + log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort) } - if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { log.Warnf("failed to add DNS TCP DNAT rule: %v", err) } else { - log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort) + log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort) } } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder) + listenAddress := netip.AddrPortFrom(localAddr, m.serverPort) + m.dnsForwarder = NewDNSForwarder(listenAddress, dnsTTL, m.firewall, m.statusRecorder, m.wgIface) + go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists @@ -111,12 +123,13 @@ func (m *Manager) Stop(ctx context.Context) error { var mErr *multierror.Error - if m.localAddr.IsValid() && m.firewall != nil { - if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + localAddr := m.wgIface.Address().IP + if localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err)) } - if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err)) } } diff --git a/client/internal/engine.go b/client/internal/engine.go index 19d37eee17a..ad69bcf435e 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1855,35 +1855,69 @@ func (e *Engine) updateDNSForwarder( } if !enabled { - if e.dnsForwardMgr == nil { - return - } - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } + e.stopDNSForwarder() return } if len(fwdEntries) > 0 { if e.dnsForwardMgr == nil { - localAddr := e.wgInterface.Address().IP - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr) - - if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { - log.Errorf("failed to start DNS forward: %v", err) - e.dnsForwardMgr = nil - } - - log.Infof("started domain router service with %d entries", len(fwdEntries)) + e.startDNSForwarder(fwdEntries) } else { e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { log.Infof("disable domain router service") - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } + e.stopDNSForwarder() + } +} + +func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) { + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface) + e.registerDNSServices() + + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { + log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil + return + } + + log.Infof("started domain router service with %d entries", len(fwdEntries)) +} + +func (e *Engine) stopDNSForwarder() { + if e.dnsForwardMgr == nil { + return + } + + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + + e.unregisterDNSServices() + e.dnsForwardMgr = nil +} + +func (e *Engine) registerDNSServices() { + if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { + if registrar, ok := e.firewall.(interface { + RegisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.RegisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort) + registrar.RegisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort) + log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort) + } + } +} + +func (e *Engine) unregisterDNSServices() { + if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { + if registrar, ok := e.firewall.(interface { + UnregisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.UnregisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort) + registrar.UnregisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort) + log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort) + } } }