Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client/android/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/client/net"
)

// ConnectionListener export internal Listener for mobile
Expand Down
166 changes: 125 additions & 41 deletions client/cmd/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,21 @@ import (
"errors"
"flag"
"fmt"
"net"
"os"
"os/signal"
"os/user"
"slices"
"strconv"
"strings"
"syscall"

"github.com/spf13/cobra"

"github.com/netbirdio/netbird/client/internal"
sshclient "github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
sshproxy "github.com/netbirdio/netbird/client/ssh/proxy"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/util"
)
Expand All @@ -29,6 +33,7 @@ const (
enableSSHSFTPFlag = "enable-ssh-sftp"
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
disableSSHAuthFlag = "disable-ssh-auth"
)

var (
Expand All @@ -41,6 +46,7 @@ var (
strictHostKeyChecking bool
knownHostsFile string
identityFile string
skipCachedToken bool
)

var (
Expand All @@ -49,6 +55,7 @@ var (
enableSSHSFTP bool
enableSSHLocalPortForward bool
enableSSHRemotePortForward bool
disableSSHAuth bool
)

func init() {
Expand All @@ -57,18 +64,22 @@ func init() {
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")

sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file")
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")

sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")

sshCmd.AddCommand(sshSftpCmd)
sshCmd.AddCommand(sshProxyCmd)
sshCmd.AddCommand(sshDetectCmd)
}

var sshCmd = &cobra.Command{
Expand Down Expand Up @@ -335,31 +346,51 @@ func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers
}

// createSSHFlagSet creates and configures the flag set for SSH command parsing
func createSSHFlagSet() (*flag.FlagSet, *int, *string, *string, *bool, *string, *string, *string, *string) {
// sshFlags contains all SSH-related flags and parameters
type sshFlags struct {
Port int
Username string
Login string
StrictHostKeyChecking bool
KnownHostsFile string
IdentityFile string
SkipCachedToken bool
ConfigPath string
LogLevel string
LocalForwards []string
RemoteForwards []string
Host string
Command string
}

func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)

fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
fs.SetOutput(nil)

portFlag := fs.Int("p", sshserver.DefaultSSHPort, "SSH port")
flags := &sshFlags{}

fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
fs.Int("port", sshserver.DefaultSSHPort, "SSH port")
userFlag := fs.String("u", "", sshUsernameDesc)
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
fs.String("user", "", sshUsernameDesc)
loginFlag := fs.String("login", "", sshUsernameDesc+" (alias for --user)")
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")

strictHostKeyCheckingFlag := fs.Bool("strict-host-key-checking", true, "Enable strict host key checking")
knownHostsFlag := fs.String("o", "", "Path to known_hosts file")
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
fs.String("known-hosts", "", "Path to known_hosts file")
identityFlag := fs.String("i", "", "Path to SSH private key file")
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
fs.String("identity", "", "Path to SSH private key file")
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")

configFlag := fs.String("c", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
fs.String("config", defaultConfigPath, "Netbird config file location")
logLevelFlag := fs.String("l", defaultLogLevel, "sets Netbird log level")
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
fs.String("log-level", defaultLogLevel, "sets Netbird log level")

return fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag
return fs, flags
}

func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
Expand All @@ -375,7 +406,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {

filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)

fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag := createSSHFlagSet()
fs, flags := createSSHFlagSet()

if err := fs.Parse(filteredArgs); err != nil {
return parseHostnameAndCommand(filteredArgs)
Expand All @@ -386,22 +417,23 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
return errors.New(hostArgumentRequired)
}

port = *portFlag
if *userFlag != "" {
username = *userFlag
} else if *loginFlag != "" {
username = *loginFlag
port = flags.Port
if flags.Username != "" {
username = flags.Username
} else if flags.Login != "" {
username = flags.Login
}

strictHostKeyChecking = *strictHostKeyCheckingFlag
knownHostsFile = *knownHostsFlag
identityFile = *identityFlag
strictHostKeyChecking = flags.StrictHostKeyChecking
knownHostsFile = flags.KnownHostsFile
identityFile = flags.IdentityFile
skipCachedToken = flags.SkipCachedToken

if *configFlag != getEnvOrDefault("CONFIG", configPath) {
configPath = *configFlag
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
configPath = flags.ConfigPath
}
if *logLevelFlag != getEnvOrDefault("LOG_LEVEL", logLevel) {
logLevel = *logLevelFlag
if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) {
logLevel = flags.LogLevel
}

localForwards = localForwardFlags
Expand Down Expand Up @@ -449,30 +481,20 @@ func parseHostnameAndCommand(args []string) error {

func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
target := fmt.Sprintf("%s:%d", addr, port)

var c *sshclient.Client
var err error

if strictHostKeyChecking {
c, err = sshclient.DialWithOptions(ctx, target, username, sshclient.DialOptions{
KnownHostsFile: knownHostsFile,
IdentityFile: identityFile,
DaemonAddr: daemonAddr,
})
} else {
c, err = sshclient.DialInsecure(ctx, target, username)
}
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
KnownHostsFile: knownHostsFile,
IdentityFile: identityFile,
DaemonAddr: daemonAddr,
SkipCachedToken: skipCachedToken,
InsecureSkipVerify: !strictHostKeyChecking,
})

if err != nil {
cmd.Printf("Failed to connect to %s@%s\n", username, target)
cmd.Printf("\nTroubleshooting steps:\n")
cmd.Printf(" 1. Check peer connectivity: netbird status\n")
cmd.Printf(" 1. Check peer connectivity: netbird status -d\n")
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
if strictHostKeyChecking {
cmd.Printf(" 4. Try --strict-host-key-checking=false to bypass host key verification\n")
}
cmd.Printf("\n")
return fmt.Errorf("dial %s: %w", target, err)
}

Expand Down Expand Up @@ -665,3 +687,65 @@ func normalizeLocalHost(host string) string {
}
return host
}

var sshProxyCmd = &cobra.Command{
Use: "proxy <host> <port>",
Short: "Internal SSH proxy for native SSH client integration",
Long: "Internal command used by SSH ProxyCommand to handle JWT authentication",
Hidden: true,
Args: cobra.ExactArgs(2),
RunE: sshProxyFn,
}

func sshProxyFn(cmd *cobra.Command, args []string) error {
host := args[0]
portStr := args[1]

port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("invalid port: %s", portStr)
}

proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
if err != nil {
return fmt.Errorf("create SSH proxy: %w", err)
}

if err := proxy.Connect(cmd.Context()); err != nil {
return fmt.Errorf("SSH proxy: %w", err)
}

return nil
}

var sshDetectCmd = &cobra.Command{
Use: "detect <host> <port>",
Short: "Detect if a host is running NetBird SSH",
Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH",
Hidden: true,
Args: cobra.ExactArgs(2),
RunE: sshDetectFn,
}

func sshDetectFn(cmd *cobra.Command, args []string) error {
if err := util.InitLog(logLevel, "console"); err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}

host := args[0]
portStr := args[1]

port, err := strconv.Atoi(portStr)
if err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}

dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
if err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}

os.Exit(serverType.ExitCode())
return nil
}
11 changes: 11 additions & 0 deletions client/cmd/up.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
req.EnableSSHRemotePortForward = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
log.Errorf("parse interface name: %v", err)
Expand Down Expand Up @@ -460,6 +463,10 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}

if cmd.Flag(disableSSHAuthFlag).Changed {
ic.DisableSSHAuth = &disableSSHAuth
}

if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return nil, err
Expand Down Expand Up @@ -576,6 +583,10 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}

if cmd.Flag(disableSSHAuthFlag).Changed {
loginRequest.DisableSSHAuth = &disableSSHAuth
}

if cmd.Flag(disableAutoConnectFlag).Changed {
loginRequest.DisableAutoConnect = &autoConnectDisabled
}
Expand Down
Loading
Loading