Skip to content

Commit cc36477

Browse files
committed
Adds callback mode that is direct to vault
Signed-off-by: Dave Dykstra <[email protected]>
1 parent b8833ce commit cc36477

File tree

7 files changed

+327
-85
lines changed

7 files changed

+327
-85
lines changed

backend.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ func backend() *jwtAuthBackend {
6363
"login",
6464
"oidc/auth_url",
6565
"oidc/callback",
66+
"oidc/poll",
6667

6768
// Uncomment to mount simple UI handler for local development
6869
// "ui",

cli.go

Lines changed: 108 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"net"
1010
"net/http"
11+
"net/url"
1112
"os"
1213
"os/signal"
1314
"path"
@@ -27,9 +28,11 @@ const (
2728
defaultPort = "8250"
2829
defaultCallbackHost = "localhost"
2930
defaultCallbackMethod = "http"
31+
defaultCallbackMode = "client"
3032

3133
FieldCallbackHost = "callbackhost"
3234
FieldCallbackMethod = "callbackmethod"
35+
FieldCallbackMode = "callbackmode"
3336
FieldListenAddress = "listenaddress"
3437
FieldPort = "port"
3538
FieldCallbackPort = "callbackport"
@@ -69,19 +72,42 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
6972
port = defaultPort
7073
}
7174

75+
var vaultURL *url.URL
76+
callbackMode, ok := m[FieldCallbackMode]
77+
if !ok {
78+
callbackMode = defaultCallbackMode
79+
} else if callbackMode == "direct" {
80+
vaultAddr := os.Getenv("VAULT_ADDR")
81+
if vaultAddr != "" {
82+
vaultURL, _ = url.Parse(vaultAddr)
83+
}
84+
}
85+
7286
callbackHost, ok := m[FieldCallbackHost]
7387
if !ok {
74-
callbackHost = defaultCallbackHost
88+
if vaultURL != nil {
89+
callbackHost = vaultURL.Hostname()
90+
} else {
91+
callbackHost = defaultCallbackHost
92+
}
7593
}
7694

7795
callbackMethod, ok := m[FieldCallbackMethod]
7896
if !ok {
79-
callbackMethod = defaultCallbackMethod
97+
if vaultURL != nil {
98+
callbackMethod = vaultURL.Scheme
99+
} else {
100+
callbackMethod = defaultCallbackMethod
101+
}
80102
}
81103

82104
callbackPort, ok := m[FieldCallbackPort]
83105
if !ok {
84-
callbackPort = port
106+
if vaultURL != nil {
107+
callbackPort = vaultURL.Port() + "/v1/auth/" + mount
108+
} else {
109+
callbackPort = port
110+
}
85111
}
86112

87113
parseBool := func(f string, d bool) (bool, error) {
@@ -115,20 +141,49 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
115141

116142
role := m["role"]
117143

118-
authURL, clientNonce, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost)
144+
authURL, clientNonce, secret, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost)
119145
if err != nil {
120146
return nil, err
121147
}
122148

123-
// Set up callback handler
124149
doneCh := make(chan loginResp)
125-
http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh))
126150

127-
listener, err := net.Listen("tcp", listenAddress+":"+port)
128-
if err != nil {
129-
return nil, err
151+
var pollInterval string
152+
var interval int
153+
var state string
154+
var listener net.Listener
155+
156+
if secret != nil {
157+
pollInterval, _ = secret.Data["poll_interval"].(string)
158+
state, _ = secret.Data["state"].(string)
159+
}
160+
if callbackMode == "direct" {
161+
if state == "" {
162+
return nil, errors.New("no state returned in direct callback mode")
163+
}
164+
if pollInterval == "" {
165+
return nil, errors.New("no poll_interval returned in direct callback mode")
166+
}
167+
interval, err = strconv.Atoi(pollInterval)
168+
if err != nil {
169+
return nil, errors.New("cannot convert poll_interval " + pollInterval + " to integer")
170+
}
171+
} else {
172+
if state != "" {
173+
return nil, errors.New("state returned in client callback mode, try direct")
174+
}
175+
if pollInterval != "" {
176+
return nil, errors.New("poll_interval returned in client callback mode")
177+
}
178+
// Set up callback handler
179+
http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh))
180+
181+
listener, err = net.Listen("tcp", listenAddress+":"+port)
182+
if err != nil {
183+
return nil, err
184+
}
185+
defer listener.Close()
130186
}
131-
defer listener.Close()
132187

133188
// Open the default browser to the callback URL.
134189
if !skipBrowserLaunch {
@@ -144,6 +199,26 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
144199
}
145200
fmt.Fprintf(os.Stderr, "Waiting for OIDC authentication to complete...\n")
146201

202+
if callbackMode == "direct" {
203+
data := map[string]interface{}{
204+
"state": state,
205+
"client_nonce": clientNonce,
206+
}
207+
pollUrl := fmt.Sprintf("auth/%s/oidc/poll", mount)
208+
for {
209+
time.Sleep(time.Duration(interval) * time.Second)
210+
211+
secret, err := c.Logical().Write(pollUrl, data)
212+
if err == nil {
213+
return secret, nil
214+
}
215+
if !strings.HasSuffix(err.Error(), "authorization_pending") {
216+
return nil, err
217+
}
218+
// authorization is pending, try again
219+
}
220+
}
221+
147222
// Start local server
148223
go func() {
149224
err := http.Serve(listener, nil)
@@ -210,12 +285,12 @@ func callbackHandler(c *api.Client, mount string, clientNonce string, doneCh cha
210285
}
211286
}
212287

213-
func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMethod string, callbackHost string) (string, string, error) {
288+
func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMethod string, callbackHost string) (string, string, *api.Secret, error) {
214289
var authURL string
215290

216291
clientNonce, err := base62.Random(20)
217292
if err != nil {
218-
return "", "", err
293+
return "", "", nil, err
219294
}
220295

221296
redirectURI := fmt.Sprintf("%s://%s:%s/oidc/callback", callbackMethod, callbackHost, callbackPort)
@@ -227,18 +302,18 @@ func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMetho
227302

228303
secret, err := c.Logical().Write(fmt.Sprintf("auth/%s/oidc/auth_url", mount), data)
229304
if err != nil {
230-
return "", "", err
305+
return "", "", nil, err
231306
}
232307

233308
if secret != nil {
234309
authURL = secret.Data["auth_url"].(string)
235310
}
236311

237312
if authURL == "" {
238-
return "", "", fmt.Errorf("Unable to authorize role %q with redirect_uri %q. Check Vault logs for more information.", role, redirectURI)
313+
return "", "", nil, fmt.Errorf("Unable to authorize role %q with redirect_uri %q. Check Vault logs for more information.", role, redirectURI)
239314
}
240315

241-
return authURL, clientNonce, nil
316+
return authURL, clientNonce, secret, nil
242317
}
243318

244319
// parseError converts error from the API into summary and detailed portions.
@@ -292,35 +367,46 @@ Usage: vault login -method=oidc [CONFIG K=V...]
292367
293368
https://accounts.google.com/o/oauth2/v2/...
294369
295-
The default browser will be opened for the user to complete the login. Alternatively,
296-
the user may visit the provided URL directly.
370+
The default browser will be opened for the user to complete the login.
371+
Alternatively, the user may visit the provided URL directly.
297372
298373
Configuration:
299374
300375
role=<string>
301376
Vault role of type "OIDC" to use for authentication.
302377
303378
%s=<string>
304-
Optional address to bind the OIDC callback listener to (default: localhost).
379+
Mode of callback: "direct" for direct connection to Vault or "client"
380+
for connection to command line client (default: client).
381+
382+
%s=<string>
383+
Optional address to bind the OIDC callback listener to in client callback
384+
mode (default: localhost).
305385
306386
%s=<string>
307-
Optional localhost port to use for OIDC callback (default: 8250).
387+
Optional localhost port to use for OIDC callback in client callback mode
388+
(default: 8250).
308389
309390
%s=<string>
310-
Optional method to to use in OIDC redirect_uri (default: http).
391+
Optional method to use in OIDC redirect_uri (default: the method from
392+
$VAULT_ADDR in direct callback mode, else http)
311393
312394
%s=<string>
313-
Optional callback host address to use in OIDC redirect_uri (default: localhost).
395+
Optional callback host address to use in OIDC redirect_uri (default:
396+
the host from $VAULT_ADDR in direct callback mode, else localhost).
314397
315398
%s=<string>
316-
Optional port to to use in OIDC redirect_uri (default: the value set for port).
399+
Optional port to use in OIDC redirect_uri (default: the value set for
400+
port in client callback mode, else the port from $VAULT_ADDR with an
401+
added /v1/auth/<path> where <path> is from the login -path option).
317402
318403
%s=<bool>
319404
Toggle the automatic launching of the default browser to the login URL. (default: false).
320405
321406
%s=<bool>
322407
Abort on any error. (default: false).
323408
`,
409+
FieldCallbackMode,
324410
FieldListenAddress, FieldPort, FieldCallbackMethod,
325411
FieldCallbackHost, FieldCallbackPort, FieldSkipBrowser,
326412
FieldAbortOnError,
File renamed without changes.

0 commit comments

Comments
 (0)