Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions client/grpc/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,16 @@ func Backoff(ctx context.Context) backoff.BackOff {
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
// for js, the outer websocket layer takes care of tls
if tlsEnabled && runtime.GOOS != "js" {
certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
certPool = embeddedroots.Get()
}

transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
// for js, outer websocket layer takes care of tls verification via WithCustomDialer
InsecureSkipVerify: runtime.GOOS == "js",
RootCAs: certPool,
RootCAs: certPool,
}))
}

Expand Down
3 changes: 1 addition & 2 deletions management/internals/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"net"
"net/http"
"net/netip"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -252,7 +251,7 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config)
}

func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter))
wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter))

return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
switch {
Expand Down
9 changes: 5 additions & 4 deletions signal/cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"net/http"
// nolint:gosec
_ "net/http/pprof"
"net/netip"
"time"

"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
Expand Down Expand Up @@ -63,10 +62,10 @@ var (
Use: "run",
Short: "start NetBird Signal Server daemon",
SilenceUsage: true,
PreRun: func(cmd *cobra.Command, args []string) {
PreRunE: func(cmd *cobra.Command, args []string) error {
err := util.InitLog(logLevel, logFile)
if err != nil {
log.Fatalf("failed initializing log %v", err)
return fmt.Errorf("failed initializing log: %w", err)
}

flag.Parse()
Expand All @@ -87,6 +86,8 @@ var (
signalPort = 80
}
}

return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
flag.Parse()
Expand Down Expand Up @@ -254,7 +255,7 @@ func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler h
}

func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler {
wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), legacyGRPCPort), wsproxyserver.WithOTelMeter(meter))
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
Expand Down
140 changes: 53 additions & 87 deletions util/wsproxy/server/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,41 @@ package server

import (
"context"
"errors"
"io"
"net"
"net/http"
"net/netip"
"sync"
"time"

"github.com/coder/websocket"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"

"github.com/netbirdio/netbird/util/wsproxy"
)

const (
dialTimeout = 10 * time.Second
bufferSize = 32 * 1024
bufferSize = 32 * 1024
ioTimeout = 5 * time.Second
)

// Config contains the configuration for the WebSocket proxy.
type Config struct {
LocalGRPCAddr netip.AddrPort
Handler http.Handler
Path string
MetricsRecorder MetricsRecorder
}

// Proxy handles WebSocket to TCP proxying for gRPC connections.
// Proxy handles WebSocket to gRPC handler proxying.
type Proxy struct {
config Config
metrics MetricsRecorder
}

// New creates a new WebSocket proxy instance with optional configuration
func New(localGRPCAddr netip.AddrPort, opts ...Option) *Proxy {
func New(handler http.Handler, opts ...Option) *Proxy {
config := Config{
LocalGRPCAddr: localGRPCAddr,
Handler: handler,
Path: wsproxy.ProxyPath,
MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op
}
Expand All @@ -63,7 +62,7 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) {
p.metrics.RecordConnection(ctx)
defer p.metrics.RecordDisconnection(ctx)

log.Debugf("WebSocket proxy handling connection from %s, forwarding to %s", r.RemoteAddr, p.config.LocalGRPCAddr)
log.Debugf("WebSocket proxy handling connection from %s, forwarding to internal gRPC handler", r.RemoteAddr)
acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
}
Expand All @@ -75,71 +74,41 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) {
return
}
defer func() {
if err := wsConn.Close(websocket.StatusNormalClosure, ""); err != nil {
log.Debugf("Failed to close WebSocket: %v", err)
}
_ = wsConn.Close(websocket.StatusNormalClosure, "")
}()

log.Debugf("WebSocket proxy attempting to connect to local gRPC at %s", p.config.LocalGRPCAddr)
tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr.String(), dialTimeout)
if err != nil {
p.metrics.RecordError(ctx, "tcp_dial_failed")
log.Warnf("Failed to connect to local gRPC server at %s: %v", p.config.LocalGRPCAddr, err)
if err := wsConn.Close(websocket.StatusInternalError, "Backend unavailable"); err != nil {
log.Debugf("Failed to close WebSocket after connection failure: %v", err)
}
return
}
clientConn, serverConn := net.Pipe()
defer func() {
if err := tcpConn.Close(); err != nil {
log.Debugf("Failed to close TCP connection: %v", err)
}
_ = clientConn.Close()
_ = serverConn.Close()
}()

log.Debugf("WebSocket proxy established: client %s -> local gRPC %s", r.RemoteAddr, p.config.LocalGRPCAddr)
log.Debugf("WebSocket proxy established: %s -> gRPC handler", r.RemoteAddr)

p.proxyData(ctx, wsConn, tcpConn)
go func() {
(&http2.Server{}).ServeConn(serverConn, &http2.ServeConnOpts{
Context: ctx,
Handler: p.config.Handler,
})
}()

p.proxyData(ctx, wsConn, clientConn, r.RemoteAddr)
}

func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, tcpConn net.Conn) {
func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
proxyCtx, cancel := context.WithCancel(ctx)
defer cancel()

var wg sync.WaitGroup
wg.Add(2)

go p.wsToTCP(proxyCtx, cancel, &wg, wsConn, tcpConn)
go p.tcpToWS(proxyCtx, cancel, &wg, wsConn, tcpConn)

done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()

select {
case <-done:
log.Tracef("Proxy data transfer completed, both goroutines terminated")
case <-proxyCtx.Done():
log.Tracef("Proxy data transfer cancelled, forcing connection closure")
go p.wsToPipe(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr)
go p.pipeToWS(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr)

if err := wsConn.Close(websocket.StatusGoingAway, "proxy cancelled"); err != nil {
log.Tracef("Error closing WebSocket during cancellation: %v", err)
}
if err := tcpConn.Close(); err != nil {
log.Tracef("Error closing TCP connection during cancellation: %v", err)
}

select {
case <-done:
log.Tracef("Goroutines terminated after forced connection closure")
case <-time.After(2 * time.Second):
log.Tracef("Goroutines did not terminate within timeout after connection closure")
}
}
wg.Wait()
}

func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) {
func (p *Proxy) wsToPipe(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
defer wg.Done()
defer cancel()

Expand All @@ -148,80 +117,77 @@ func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync
if err != nil {
switch {
case ctx.Err() != nil:
log.Debugf("wsToTCP goroutine terminating due to context cancellation")
case websocket.CloseStatus(err) == websocket.StatusNormalClosure:
log.Debugf("WebSocket closed normally")
log.Debugf("WebSocket from %s terminating due to context cancellation", clientAddr)
case websocket.CloseStatus(err) != -1:
log.Debugf("WebSocket from %s disconnected", clientAddr)
default:
p.metrics.RecordError(ctx, "websocket_read_error")
log.Errorf("WebSocket read error: %v", err)
log.Debugf("WebSocket read error from %s: %v", clientAddr, err)
}
return
}

if msgType != websocket.MessageBinary {
log.Warnf("Unexpected WebSocket message type: %v", msgType)
log.Warnf("Unexpected WebSocket message type from %s: %v", clientAddr, msgType)
continue
}

if ctx.Err() != nil {
log.Tracef("wsToTCP goroutine terminating due to context cancellation before TCP write")
log.Tracef("wsToPipe goroutine terminating due to context cancellation before pipe write")
return
}

if err := tcpConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
log.Debugf("Failed to set TCP write deadline: %v", err)
if err := pipeConn.SetWriteDeadline(time.Now().Add(ioTimeout)); err != nil {
log.Debugf("Failed to set pipe write deadline: %v", err)
}

n, err := tcpConn.Write(data)
n, err := pipeConn.Write(data)
if err != nil {
p.metrics.RecordError(ctx, "tcp_write_error")
log.Errorf("TCP write error: %v", err)
p.metrics.RecordError(ctx, "pipe_write_error")
log.Warnf("Pipe write error for %s: %v", clientAddr, err)
return
}

p.metrics.RecordBytesTransferred(ctx, "ws_to_tcp", int64(n))
p.metrics.RecordBytesTransferred(ctx, "ws_to_grpc", int64(n))
}
}

func (p *Proxy) tcpToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) {
func (p *Proxy) pipeToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
defer wg.Done()
defer cancel()

buf := make([]byte, bufferSize)
for {
if err := tcpConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
log.Debugf("Failed to set TCP read deadline: %v", err)
if err := pipeConn.SetReadDeadline(time.Now().Add(ioTimeout)); err != nil {
log.Debugf("Failed to set pipe read deadline: %v", err)
}
n, err := tcpConn.Read(buf)

n, err := pipeConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
log.Tracef("tcpToWS goroutine terminating due to context cancellation")
log.Tracef("pipeToWS goroutine terminating due to context cancellation")
return
}

var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
continue
}

if err != io.EOF {
log.Errorf("TCP read error: %v", err)
log.Debugf("Pipe read error for %s: %v", clientAddr, err)
}
return
}

if ctx.Err() != nil {
log.Tracef("tcpToWS goroutine terminating due to context cancellation before WebSocket write")
log.Tracef("pipeToWS goroutine terminating due to context cancellation before WebSocket write")
return
}

if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil {
p.metrics.RecordError(ctx, "websocket_write_error")
log.Errorf("WebSocket write error: %v", err)
return
}
if n > 0 {
if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil {
p.metrics.RecordError(ctx, "websocket_write_error")
log.Warnf("WebSocket write error for %s: %v", clientAddr, err)
return
}

p.metrics.RecordBytesTransferred(ctx, "tcp_to_ws", int64(n))
p.metrics.RecordBytesTransferred(ctx, "grpc_to_ws", int64(n))
}
}
}
Loading