Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 37 additions & 16 deletions libvirt/uri/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (u *ConnectionURI) parseAuthMethods(target string, sshcfg *ssh_config.Confi
// construct the whole ssh connection, which can consist of multiple hops if using proxy jumps,
// the ssh configuration file is loaded once and passed along to each host connection.
func (u *ConnectionURI) dialSSH() (net.Conn, error) {
var sshcfg* ssh_config.Config = nil
var sshcfg *ssh_config.Config = nil

sshConfigFile, err := os.Open(os.ExpandEnv(defaultSSHConfigFile))
if err != nil {
Expand All @@ -132,7 +132,7 @@ func (u *ConnectionURI) dialSSH() (net.Conn, error) {
}

// configuration loaded, build tunnel
sshClient, err := u.dialHost(u.Host, sshcfg, 0)
sshClient, err := u.dialHost(parsedTarget{hostName: u.Host}, sshcfg, 0)
if err != nil {
return nil, err
}
Expand All @@ -152,7 +152,12 @@ func (u *ConnectionURI) dialSSH() (net.Conn, error) {
return c, nil
}

func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth int) (*ssh.Client, error) {
type parsedTarget struct {
hostName string
user string
}

func (u *ConnectionURI) dialHost(target parsedTarget, sshcfg *ssh_config.Config, depth int) (*ssh.Client, error) {

if depth > maxHostHops {
return nil, fmt.Errorf("[ERROR] dialHost failed: max tunnel depth of 10 reached")
Expand All @@ -169,9 +174,9 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth
log.Printf("[DEBUG] ssh Port is overridden to: '%s'", port)
}

hostName := target
hostName := target.hostName
if sshcfg != nil {
host, err := sshcfg.Get(target, "HostName")
host, err := sshcfg.Get(target.hostName, "HostName")
if err == nil && host != "" {
hostName = host
log.Printf("[DEBUG] HostName is overridden to: '%s'", hostName)
Expand All @@ -188,7 +193,7 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth
skipVerify = true
} else {
if sshcfg != nil {
strictCheck, err := sshcfg.Get(target, "StrictHostKeyChecking")
strictCheck, err := sshcfg.Get(target.hostName, "StrictHostKeyChecking")
if err != nil && strictCheck == "yes" {
skipVerify = false
}
Expand All @@ -199,7 +204,7 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth
knownHostsPath = defaultSSHKnownHostsPath

if sshcfg != nil {
knownHosts, err := sshcfg.Get(target, "UserKnownHostsFile")
knownHosts, err := sshcfg.Get(target.hostName, "UserKnownHostsFile")
if err == nil && knownHosts != "" {
knownHostsPath = knownHosts
}
Expand Down Expand Up @@ -236,7 +241,7 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth
}

if sshcfg != nil {
keyAlgs, err := sshcfg.Get(target, "HostKeyAlgorithms")
keyAlgs, err := sshcfg.Get(target.hostName, "HostKeyAlgorithms")
if err == nil && keyAlgs != "" {
log.Printf("[DEBUG] HostKeyAlgorithms is overridden to '%s'", keyAlgs)
hostKeyAlgorithms = strings.Split(keyAlgs, ",")
Expand All @@ -251,22 +256,25 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth
HostKeyAlgorithms: hostKeyAlgorithms,
Timeout: dialTimeout,
}
if target.user != "" {
cfg.User = target.user
}
var bastion *ssh.Client = nil
var bastion_proxy string = ""

if sshcfg != nil {
command, err := sshcfg.Get(target, "ProxyCommand")
command, err := sshcfg.Get(target.hostName, "ProxyCommand")
if err == nil && command != "" {
log.Printf("[WARNING] unsupported ssh ProxyCommand '%v' - ignoring", command)
}
}

if sshcfg != nil {
proxy, err := sshcfg.Get(target, "ProxyJump")
if err == nil && proxy != "" {
proxy, err := sshcfg.Get(target.hostName, "ProxyJump")
if err == nil && (proxy != "" && proxy != "none") {
log.Printf("[DEBUG] found ProxyJump '%v'", proxy)
// this is a proxy jump: we recurse into that proxy
bastion, err = u.dialHost(proxy, sshcfg, depth+1)
bastion, err = u.dialHost(proxyJumpStringToParsedTarget(proxy), sshcfg, depth+1)
bastion_proxy = proxy
if err != nil {
return nil, fmt.Errorf("failed to connect to bastion host '%v': %w", proxy, err)
Expand All @@ -276,15 +284,14 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth

// cfg.User value defaults to u.User.Username()
if sshcfg != nil {
sshu, err := sshcfg.Get(target, "User")
sshu, err := sshcfg.Get(target.hostName, "User")
if err != nil {
log.Printf("[DEBUG] ssh user for target '%v' is overridden to '%v'", target, sshu)
cfg.User = sshu
}
}


cfg.Auth = u.parseAuthMethods(target, sshcfg)
cfg.Auth = u.parseAuthMethods(target.hostName, sshcfg)
if len(cfg.Auth) < 1 {
return nil, fmt.Errorf("could not configure SSH authentication methods")
}
Expand All @@ -298,7 +305,7 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth
return nil, fmt.Errorf("failed to connect to remote host '%v': %w", target, err)
}

ncc, chans, reqs, err := ssh.NewClientConn(conn, target, &cfg)
ncc, chans, reqs, err := ssh.NewClientConn(conn, target.hostName, &cfg)
if err != nil {
return nil, fmt.Errorf("failed to connect to remote host '%v': %w", target, err)
}
Expand All @@ -317,3 +324,17 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth
return conn, nil
}
}

func proxyJumpStringToParsedTarget(s string) parsedTarget {
atIdx := strings.Index(s, "@")
if atIdx < 0 {
return parsedTarget{
hostName: s,
}
}

return parsedTarget{
hostName: s[atIdx+1:],
user: s[:atIdx],
}
}
29 changes: 29 additions & 0 deletions libvirt/uri/ssh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package uri

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestProxyJumpStringToParsedTarget(t *testing.T) {
in := []string{
"host.enterprise.com",
"[email protected]",
}
expectedOut := []parsedTarget{
{
hostName: "host.enterprise.com",
},
{
hostName: "host.enterprise.com",
user: "user",
},
}

out := []parsedTarget{}
for _, proxyJumpStr := range in {
out = append(out, proxyJumpStringToParsedTarget(proxyJumpStr))
}
assert.Equal(t, expectedOut, out)
}