From 9c29af19d9cfb5ca403cb3634733eeba82ff3ada Mon Sep 17 00:00:00 2001 From: Dave Dykstra <2129743+DrDaveD@users.noreply.github.com> Date: Fri, 7 Jun 2024 15:25:16 -0500 Subject: [PATCH 1/2] Adds callback mode that is direct to vault Signed-off-by: Dave Dykstra <2129743+DrDaveD@users.noreply.github.com> --- backend.go | 1 + cli.go | 130 +++++++++++++++---- cli_responses.go => html_responses.go | 0 path_oidc.go | 179 +++++++++++++++++++------- path_oidc_test.go | 68 +++++++--- path_role.go | 31 +++++ path_role_test.go | 3 + 7 files changed, 327 insertions(+), 85 deletions(-) rename cli_responses.go => html_responses.go (100%) diff --git a/backend.go b/backend.go index 85041de0..dc09966a 100644 --- a/backend.go +++ b/backend.go @@ -63,6 +63,7 @@ func backend() *jwtAuthBackend { "login", "oidc/auth_url", "oidc/callback", + "oidc/poll", // Uncomment to mount simple UI handler for local development // "ui", diff --git a/cli.go b/cli.go index 9c61f868..d84a6c0d 100644 --- a/cli.go +++ b/cli.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/http" + "net/url" "os" "os/signal" "path" @@ -27,9 +28,11 @@ const ( defaultPort = "8250" defaultCallbackHost = "localhost" defaultCallbackMethod = "http" + defaultCallbackMode = "client" FieldCallbackHost = "callbackhost" FieldCallbackMethod = "callbackmethod" + FieldCallbackMode = "callbackmode" FieldListenAddress = "listenaddress" FieldPort = "port" FieldCallbackPort = "callbackport" @@ -69,19 +72,42 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro port = defaultPort } + var vaultURL *url.URL + callbackMode, ok := m[FieldCallbackMode] + if !ok { + callbackMode = defaultCallbackMode + } else if callbackMode == "direct" { + vaultAddr := os.Getenv("VAULT_ADDR") + if vaultAddr != "" { + vaultURL, _ = url.Parse(vaultAddr) + } + } + callbackHost, ok := m[FieldCallbackHost] if !ok { - callbackHost = defaultCallbackHost + if vaultURL != nil { + callbackHost = vaultURL.Hostname() + } else { + callbackHost = defaultCallbackHost + } } callbackMethod, ok := m[FieldCallbackMethod] if !ok { - callbackMethod = defaultCallbackMethod + if vaultURL != nil { + callbackMethod = vaultURL.Scheme + } else { + callbackMethod = defaultCallbackMethod + } } callbackPort, ok := m[FieldCallbackPort] if !ok { - callbackPort = port + if vaultURL != nil { + callbackPort = vaultURL.Port() + "/v1/auth/" + mount + } else { + callbackPort = port + } } parseBool := func(f string, d bool) (bool, error) { @@ -115,20 +141,49 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro role := m["role"] - authURL, clientNonce, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost) + authURL, clientNonce, secret, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost) if err != nil { return nil, err } - // Set up callback handler doneCh := make(chan loginResp) - http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh)) - listener, err := net.Listen("tcp", listenAddress+":"+port) - if err != nil { - return nil, err + var pollInterval string + var interval int + var state string + var listener net.Listener + + if secret != nil { + pollInterval, _ = secret.Data["poll_interval"].(string) + state, _ = secret.Data["state"].(string) + } + if callbackMode == "direct" { + if state == "" { + return nil, errors.New("no state returned in direct callback mode") + } + if pollInterval == "" { + return nil, errors.New("no poll_interval returned in direct callback mode") + } + interval, err = strconv.Atoi(pollInterval) + if err != nil { + return nil, errors.New("cannot convert poll_interval " + pollInterval + " to integer") + } + } else { + if state != "" { + return nil, errors.New("state returned in client callback mode, try direct") + } + if pollInterval != "" { + return nil, errors.New("poll_interval returned in client callback mode") + } + // Set up callback handler + http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh)) + + listener, err = net.Listen("tcp", listenAddress+":"+port) + if err != nil { + return nil, err + } + defer listener.Close() } - defer listener.Close() // Open the default browser to the callback URL. if !skipBrowserLaunch { @@ -144,6 +199,26 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro } fmt.Fprintf(os.Stderr, "Waiting for OIDC authentication to complete...\n") + if callbackMode == "direct" { + data := map[string]interface{}{ + "state": state, + "client_nonce": clientNonce, + } + pollUrl := fmt.Sprintf("auth/%s/oidc/poll", mount) + for { + time.Sleep(time.Duration(interval) * time.Second) + + secret, err := c.Logical().Write(pollUrl, data) + if err == nil { + return secret, nil + } + if !strings.HasSuffix(err.Error(), "authorization_pending") { + return nil, err + } + // authorization is pending, try again + } + } + // Start local server go func() { err := http.Serve(listener, nil) @@ -210,12 +285,12 @@ func callbackHandler(c *api.Client, mount string, clientNonce string, doneCh cha } } -func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMethod string, callbackHost string) (string, string, error) { +func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMethod string, callbackHost string) (string, string, *api.Secret, error) { var authURL string clientNonce, err := base62.Random(20) if err != nil { - return "", "", err + return "", "", nil, err } redirectURI := fmt.Sprintf("%s://%s:%s/oidc/callback", callbackMethod, callbackHost, callbackPort) @@ -227,7 +302,7 @@ func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMetho secret, err := c.Logical().Write(fmt.Sprintf("auth/%s/oidc/auth_url", mount), data) if err != nil { - return "", "", err + return "", "", nil, err } if secret != nil { @@ -235,10 +310,10 @@ func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMetho } if authURL == "" { - return "", "", fmt.Errorf("Unable to authorize role %q with redirect_uri %q. Check Vault logs for more information.", role, redirectURI) + return "", "", nil, fmt.Errorf("Unable to authorize role %q with redirect_uri %q. Check Vault logs for more information.", role, redirectURI) } - return authURL, clientNonce, nil + return authURL, clientNonce, secret, nil } // parseError converts error from the API into summary and detailed portions. @@ -292,8 +367,8 @@ Usage: vault login -method=oidc [CONFIG K=V...] https://accounts.google.com/o/oauth2/v2/... - The default browser will be opened for the user to complete the login. Alternatively, - the user may visit the provided URL directly. + The default browser will be opened for the user to complete the login. + Alternatively, the user may visit the provided URL directly. Configuration: @@ -301,19 +376,29 @@ Configuration: Vault role of type "OIDC" to use for authentication. %s= - Optional address to bind the OIDC callback listener to (default: localhost). + Mode of callback: "direct" for direct connection to Vault or "client" + for connection to command line client (default: client). + + %s= + Optional address to bind the OIDC callback listener to in client callback + mode (default: localhost). %s= - Optional localhost port to use for OIDC callback (default: 8250). + Optional localhost port to use for OIDC callback in client callback mode + (default: 8250). %s= - Optional method to to use in OIDC redirect_uri (default: http). + Optional method to use in OIDC redirect_uri (default: the method from + $VAULT_ADDR in direct callback mode, else http) %s= - Optional callback host address to use in OIDC redirect_uri (default: localhost). + Optional callback host address to use in OIDC redirect_uri (default: + the host from $VAULT_ADDR in direct callback mode, else localhost). %s= - Optional port to to use in OIDC redirect_uri (default: the value set for port). + Optional port to use in OIDC redirect_uri (default: the value set for + port in client callback mode, else the port from $VAULT_ADDR with an + added /v1/auth/ where is from the login -path option). %s= Toggle the automatic launching of the default browser to the login URL. (default: false). @@ -321,6 +406,7 @@ Configuration: %s= Abort on any error. (default: false). `, + FieldCallbackMode, FieldListenAddress, FieldPort, FieldCallbackMethod, FieldCallbackHost, FieldCallbackPort, FieldSkipBrowser, FieldAbortOnError, diff --git a/cli_responses.go b/html_responses.go similarity index 100% rename from cli_responses.go rename to html_responses.go diff --git a/path_oidc.go b/path_oidc.go index a34a7fa9..6f7304d9 100644 --- a/path_oidc.go +++ b/path_oidc.go @@ -6,7 +6,6 @@ package jwtauth import ( "context" "encoding/json" - "errors" "fmt" "net" "net/http" @@ -52,6 +51,9 @@ type oidcRequest struct { // clientNonce is used between Vault and the client/application (e.g. CLI) making the request, // and is unrelated to the OIDC nonce above. It is optional. clientNonce string + + // this is for storing the response in direct callback mode + auth *logical.Auth } func pathOIDC(b *jwtAuthBackend) []*framework.Path { @@ -82,6 +84,9 @@ func pathOIDC(b *jwtAuthBackend) []*framework.Path { Type: framework.TypeString, Query: true, }, + "error_description": { + Type: framework.TypeString, + }, }, Operations: map[logical.Operation]framework.OperationHandler{ @@ -105,6 +110,26 @@ func pathOIDC(b *jwtAuthBackend) []*framework.Path { }, }, }, + { + Pattern: `oidc/poll`, + Fields: map[string]*framework.FieldSchema{ + "state": { + Type: framework.TypeString, + }, + "client_nonce": { + Type: framework.TypeString, + }, + }, + Operations: map[logical.Operation]framework.OperationHandler{ + logical.UpdateOperation: &framework.PathOperation{ + Callback: b.pathPoll, + Summary: "Poll endpoint to complete an OIDC login.", + + // state is cached so don't process OIDC logins on perf standbys + ForwardPerformanceStandby: true, + }, + }, + }, { Pattern: `oidc/auth_url`, @@ -125,7 +150,7 @@ func pathOIDC(b *jwtAuthBackend) []*framework.Path { }, "client_nonce": { Type: framework.TypeString, - Description: "Optional client-provided nonce that must match during callback, if present.", + Description: "Client-provided nonce that must match during callback, if present. Required only in direct callback mode.", }, }, @@ -167,11 +192,14 @@ func (b *jwtAuthBackend) pathCallbackPost(ctx context.Context, req *logical.Requ } // Store the provided code and/or token into its OIDC request, which must already exist. - oidcReq, err := b.amendOIDCRequest(stateID, code, idToken) - if err != nil { + oidcReq := b.getOIDCRequest(stateID) + if oidcReq == nil { resp.Data[logical.HTTPRawBody] = []byte(errorHTML(errLoginFailed, "Expired or missing OAuth state.")) resp.Data[logical.HTTPStatusCode] = http.StatusBadRequest } else { + oidcReq.code = code + oidcReq.idToken = idToken + b.setOIDCRequest(stateID, oidcReq) mount := parseMount(oidcReq.RedirectURL()) if mount == "" { resp.Data[logical.HTTPRawBody] = []byte(errorHTML(errLoginFailed, "Invalid redirect path.")) @@ -184,6 +212,19 @@ func (b *jwtAuthBackend) pathCallbackPost(ctx context.Context, req *logical.Requ return resp, nil } +func loginFailedResponse(useHttp bool, msg string) *logical.Response { + if !useHttp { + return logical.ErrorResponse(errLoginFailed + " " + msg) + } + return &logical.Response{ + Data: map[string]interface{}{ + logical.HTTPContentType: "text/html", + logical.HTTPStatusCode: http.StatusBadRequest, + logical.HTTPRawBody: []byte(errorHTML(errLoginFailed, msg)), + }, + } +} + func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { config, err := b.config(ctx, req.Storage) if err != nil { @@ -195,28 +236,45 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, stateID := d.Get("state").(string) - oidcReq := b.verifyOIDCRequest(stateID) - if oidcReq == nil { + oidcReq := b.getOIDCRequest(stateID) + if oidcReq == nil || oidcReq.auth != nil { return logical.ErrorResponse(errLoginFailed + " Expired or missing OAuth state."), nil } - clientNonce := d.Get("client_nonce").(string) - - // If a client_nonce was provided at the start of the auth process as part of the auth_url - // request, require that it is present and matching during the callback phase. - if oidcReq.clientNonce != "" && clientNonce != oidcReq.clientNonce { - return logical.ErrorResponse("invalid client_nonce"), nil - } - roleName := oidcReq.rolename role, err := b.role(ctx, req.Storage, roleName) if err != nil { + b.deleteOIDCRequest(stateID) return nil, err } if role == nil { + b.deleteOIDCRequest(stateID) return logical.ErrorResponse(errLoginFailed + " Role could not be found"), nil } + useHttp := false + if role.CallbackMode == callbackModeDirect { + useHttp = true + } + if !useHttp { + // state is only accessed once when not using direct callback + b.deleteOIDCRequest(stateID) + } + + errorDescription := d.Get("error_description").(string) + if errorDescription != "" { + return loginFailedResponse(useHttp, errorDescription), nil + } + + clientNonce := d.Get("client_nonce").(string) + + // If a client_nonce was provided at the start of the auth process as part of the auth_url + // request, require that it is present and matching during the callback phase + // unless using the direct callback mode (when we instead check in poll). + if oidcReq.clientNonce != "" && clientNonce != oidcReq.clientNonce && !useHttp { + return logical.ErrorResponse("invalid client_nonce"), nil + } + if len(role.TokenBoundCIDRs) > 0 { if req.Connection == nil { b.Logger().Warn("token bound CIDRs found but no connection information available for validation") @@ -242,7 +300,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, if code == "" { if oidcReq.idToken == "" { - return logical.ErrorResponse(errLoginFailed + " No code or id_token received."), nil + return loginFailedResponse(useHttp, "No code or id_token received."), nil } // Verify the ID token received from the authentication response. @@ -255,7 +313,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, // ID token verification takes place in provider.Exchange. token, err = provider.Exchange(ctx, oidcReq, stateID, code) if err != nil { - return logical.ErrorResponse(errLoginFailed+" Error exchanging oidc code: %q.", err.Error()), nil + return loginFailedResponse(useHttp, fmt.Sprintf("Error exchanging oidc code: %q.", err.Error())), nil } rawToken = token.IDToken() @@ -287,7 +345,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, } if role.BoundSubject != "" && role.BoundSubject != subject { - return nil, errors.New("sub claim does not match bound subject") + return loginFailedResponse(useHttp, "sub claim does not match bound subject"), nil } // Set the token source for the access token if it's available. It will only @@ -321,11 +379,11 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, alias, groupAliases, err := b.createIdentity(ctx, allClaims, roleName, role, tokenSource) if err != nil { - return logical.ErrorResponse(err.Error()), nil + return loginFailedResponse(useHttp, err.Error()), nil } if err := validateBoundClaims(b.Logger(), role.BoundClaimsType, role.BoundClaims, allClaims); err != nil { - return logical.ErrorResponse("error validating claims: %s", err.Error()), nil + return loginFailedResponse(useHttp, fmt.Sprintf("error validating claims: %s", err.Error())), nil } tokenMetadata := make(map[string]string) @@ -354,13 +412,49 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, role.PopulateTokenAuth(auth) - resp := &logical.Response{ - Auth: auth, + resp := &logical.Response{} + if useHttp { + oidcReq.auth = auth + b.setOIDCRequest(stateID, oidcReq) + resp.Data = map[string]interface{}{ + logical.HTTPContentType: "text/html", + logical.HTTPStatusCode: http.StatusOK, + logical.HTTPRawBody: []byte(successHTML), + } + } else { + resp.Auth = auth } return resp, nil } +func (b *jwtAuthBackend) pathPoll(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + stateID := d.Get("state").(string) + + oidcReq := b.getOIDCRequest(stateID) + if oidcReq == nil { + return logical.ErrorResponse(errLoginFailed + " Expired or missing OAuth state."), nil + } + + clientNonce := d.Get("client_nonce").(string) + + if oidcReq.clientNonce != "" && clientNonce != oidcReq.clientNonce { + b.deleteOIDCRequest(stateID) + return logical.ErrorResponse("invalid client_nonce"), nil + } + + if oidcReq.auth == nil { + // Return the same response as oauth 2.0 device flow in RFC8628 + return logical.ErrorResponse("authorization_pending"), nil + } + + b.deleteOIDCRequest(stateID) + resp := &logical.Response{ + Auth: oidcReq.auth, + } + return resp, nil +} + // authURL returns a URL used for redirection to receive an authorization code. // This path requires a role name, or that a default_role has been configured. // Because this endpoint is unauthenticated, the response to invalid or non-OIDC @@ -400,8 +494,6 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f return logical.ErrorResponse("missing redirect_uri"), nil } - clientNonce := d.Get("client_nonce").(string) - role, err := b.role(ctx, req.Storage, roleName) if err != nil { return nil, err @@ -410,9 +502,14 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f return logical.ErrorResponse("role %q could not be found", roleName), nil } - // If namespace will be passed around in state, and it has been provided as + clientNonce := d.Get("client_nonce").(string) + if clientNonce == "" && role.CallbackMode == callbackModeDirect { + return logical.ErrorResponse("missing client_nonce"), nil + } + + // If namespace will be passed around in oidcReq, and it has been provided as // a redirectURI query parameter, remove it from redirectURI, and append it - // to the state (later in this function) + // to the oidcReq (later in this function) namespace := "" if config.NamespaceInState { inputURI, err := url.Parse(redirectURI) @@ -460,13 +557,17 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f return resp, nil } - // embed namespace in state in the auth_url + // embed namespace in oidcReq in the auth_url if config.NamespaceInState && len(namespace) > 0 { stateWithNamespace := fmt.Sprintf("%s,ns=%s", oidcReq.State(), namespace) urlStr = strings.Replace(urlStr, oidcReq.State(), url.QueryEscape(stateWithNamespace), 1) } resp.Data["auth_url"] = urlStr + if role.CallbackMode == callbackModeDirect { + resp.Data["state"] = oidcReq.State() + resp.Data["poll_interval"] = "5" + } return resp, nil } @@ -509,32 +610,14 @@ func (b *jwtAuthBackend) createOIDCRequest(config *jwtConfig, role *jwtRole, rol return oidcReq, nil } -func (b *jwtAuthBackend) amendOIDCRequest(stateID, code, idToken string) (*oidcRequest, error) { - requestRaw, ok := b.oidcRequests.Get(stateID) - if !ok { - return nil, errors.New("OIDC state not found") - } - - oidcReq := requestRaw.(*oidcRequest) - oidcReq.code = code - oidcReq.idToken = idToken - +func (b *jwtAuthBackend) setOIDCRequest(stateID string, oidcReq *oidcRequest) { b.oidcRequests.SetDefault(stateID, oidcReq) - - return oidcReq, nil } -// verifyOIDCRequest tests whether the provided state ID is valid and returns the -// associated oidcRequest if so. A nil oidcRequest is returned if the ID is not found -// or expired. The oidcRequest should only ever be retrieved once and is deleted as -// part of this request. -func (b *jwtAuthBackend) verifyOIDCRequest(stateID string) *oidcRequest { - defer b.oidcRequests.Delete(stateID) - +func (b *jwtAuthBackend) getOIDCRequest(stateID string) *oidcRequest { if requestRaw, ok := b.oidcRequests.Get(stateID); ok { return requestRaw.(*oidcRequest) } - return nil } @@ -549,6 +632,10 @@ func isLocalAddr(hostname string) bool { return hostname == "localhost" } +func (b *jwtAuthBackend) deleteOIDCRequest(stateID string) { + b.oidcRequests.Delete(stateID) +} + // validRedirect checks whether uri is in allowed using special handling for loopback uris. // Ref: https://tools.ietf.org/html/rfc8252#section-7.3 func validRedirect(uri string, allowed []string) bool { diff --git a/path_oidc_test.go b/path_oidc_test.go index 776a4b5e..c78c7f49 100644 --- a/path_oidc_test.go +++ b/path_oidc_test.go @@ -527,7 +527,7 @@ func TestOIDC_AuthURL_max_age(t *testing.T) { // pointer syntax for the user_claim of roles. For claims used // in assertions, see the sampleClaims function. func TestOIDC_UserClaim_JSON_Pointer(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() type args struct { @@ -773,14 +773,27 @@ func TestOIDC_ResponseTypeIDToken(t *testing.T) { func TestOIDC_Callback(t *testing.T) { t.Run("successful login", func(t *testing.T) { // run test with and without bound_cidrs configured - for _, useBoundCIDRs := range []bool{false, true} { - b, storage, s := getBackendAndServer(t, useBoundCIDRs) + // and with and without direct callback mode + for i := 1; i <= 3; i++ { + var useBoundCIDRs bool + var callbackMode string + + if i == 2 { + useBoundCIDRs = true + } else if i == 3 { + callbackMode = "direct" + } + + b, storage, s := getBackendAndServer(t, useBoundCIDRs, callbackMode) defer s.server.Close() + clientNonce := "456" + // get auth_url data := map[string]interface{}{ "role": "test", "redirect_uri": "https://example.com", + "client_nonce": clientNonce, } req := &logical.Request{ Operation: logical.UpdateOperation, @@ -815,8 +828,9 @@ func TestOIDC_Callback(t *testing.T) { Path: "oidc/callback", Storage: storage, Data: map[string]interface{}{ - "state": state, - "code": "abc", + "state": state, + "code": "abc", + "client_nonce": clientNonce, }, Connection: &logical.Connection{ RemoteAddr: "127.0.0.42", @@ -828,6 +842,22 @@ func TestOIDC_Callback(t *testing.T) { t.Fatal(err) } + if callbackMode == "direct" { + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "oidc/poll", + Storage: storage, + Data: map[string]interface{}{ + "state": state, + "client_nonce": clientNonce, + }, + } + resp, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + } + expected := &logical.Auth{ LeaseOptions: logical.LeaseOptions{ Renewable: true, @@ -874,7 +904,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("failed login - bad nonce", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() // get auth_url @@ -928,7 +958,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("failed login - bound claim mismatch", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() // get auth_url @@ -984,7 +1014,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("missing state", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() req := &logical.Request{ @@ -1003,7 +1033,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("unknown state", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() req := &logical.Request{ @@ -1025,7 +1055,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("valid state, missing code", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() // get auth_url @@ -1067,7 +1097,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("failed code exchange", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() // get auth_url @@ -1117,7 +1147,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("failed code exchange (PKCE)", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() // get auth_url @@ -1169,7 +1199,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("no response from provider", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") // get auth_url data := map[string]interface{}{ @@ -1215,7 +1245,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("test bad address", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, true) + b, storage, s := getBackendAndServer(t, true, "") defer s.server.Close() s.code = "abc" @@ -1260,7 +1290,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("test invalid client_id", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() s.code = "abc" @@ -1316,7 +1346,7 @@ func TestOIDC_Callback(t *testing.T) { }) t.Run("client_nonce", func(t *testing.T) { - b, storage, s := getBackendAndServer(t, false) + b, storage, s := getBackendAndServer(t, false, "") defer s.server.Close() // General behavior is that if a client_nonce is provided during the authURL phase @@ -1587,7 +1617,7 @@ func TestOIDC_ValidRedirect(t *testing.T) { } } -func getBackendAndServer(t *testing.T, boundCIDRs bool) (logical.Backend, logical.Storage, *oidcProvider) { +func getBackendAndServer(t *testing.T, boundCIDRs bool, callbackMode string) (logical.Backend, logical.Storage, *oidcProvider) { b, storage := getBackend(t) s := newOIDCProvider(t) s.clientID = "abc" @@ -1646,6 +1676,10 @@ func getBackendAndServer(t *testing.T, boundCIDRs bool) (logical.Backend, logica data["bound_cidrs"] = "127.0.0.42" } + if callbackMode != "" { + data["callback_mode"] = callbackMode + } + req = &logical.Request{ Operation: logical.CreateOperation, Path: "role/test", diff --git a/path_role.go b/path_role.go index 188f74d6..6e1b33e6 100644 --- a/path_role.go +++ b/path_role.go @@ -24,6 +24,8 @@ const ( claimDefaultLeeway = 150 boundClaimsTypeString = "string" boundClaimsTypeGlob = "glob" + callbackModeDirect = "direct" + callbackModeClient = "client" ) func pathRoleList(b *jwtAuthBackend) *framework.Path { @@ -154,6 +156,11 @@ for referencing claims.`, Type: framework.TypeCommaStringSlice, Description: `Comma-separated list of allowed values for redirect_uri`, }, + "callback_mode": { + Type: framework.TypeString, + Description: `OIDC callback mode from Authorization Server: allowed values are 'direct' to Vault or 'client', default 'client'`, + Default: callbackModeClient, + }, "verbose_oidc_logging": { Type: framework.TypeBool, Description: `Log received OIDC tokens and claims when debug-level logging is active. @@ -222,6 +229,7 @@ type jwtRole struct { GroupsClaim string `json:"groups_claim"` OIDCScopes []string `json:"oidc_scopes"` AllowedRedirectURIs []string `json:"allowed_redirect_uris"` + CallbackMode string `json:"callback_mode"` VerboseOIDCLogging bool `json:"verbose_oidc_logging"` MaxAge time.Duration `json:"max_age"` UserClaimJSONPointer bool `json:"user_claim_json_pointer"` @@ -330,6 +338,7 @@ func (b *jwtAuthBackend) pathRoleRead(ctx context.Context, req *logical.Request, "user_claim_json_pointer": role.UserClaimJSONPointer, "groups_claim": role.GroupsClaim, "allowed_redirect_uris": role.AllowedRedirectURIs, + "callback_mode": role.CallbackMode, "oidc_scopes": role.OIDCScopes, "verbose_oidc_logging": role.VerboseOIDCLogging, "max_age": int64(role.MaxAge.Seconds()), @@ -356,6 +365,20 @@ func (b *jwtAuthBackend) pathRoleRead(ctx context.Context, req *logical.Request, d["num_uses"] = role.NumUses } + if role.CallbackMode == "" { + // Must have been after an upgrade. Store the default value. + role.CallbackMode = "client" + d["callback_mode"] = role.CallbackMode + + entry, err := logical.StorageEntryJSON(rolePrefix+roleName, role) + if err != nil { + return nil, err + } + if err = req.Storage.Put(ctx, entry); err != nil { + return nil, err + } + } + return &logical.Response{ Data: d, }, nil @@ -541,6 +564,14 @@ func (b *jwtAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical. role.AllowedRedirectURIs = allowedRedirectURIs.([]string) } + callbackMode := data.Get("callback_mode").(string) + switch callbackMode { + case callbackModeDirect, callbackModeClient: + role.CallbackMode = callbackMode + default: + return logical.ErrorResponse("invalid 'callback_mode': %s", callbackMode), nil + } + if role.RoleType == "oidc" && len(role.AllowedRedirectURIs) == 0 { return logical.ErrorResponse( "'allowed_redirect_uris' must be set if 'role_type' is 'oidc' or unspecified."), nil diff --git a/path_role_test.go b/path_role_test.go index 5628433f..96d56f60 100644 --- a/path_role_test.go +++ b/path_role_test.go @@ -91,6 +91,7 @@ func TestPath_Create(t *testing.T) { NumUses: 12, BoundCIDRs: []*sockaddr.SockAddrMarshaler{{SockAddr: expectedSockAddr}}, AllowedRedirectURIs: []string(nil), + CallbackMode: "client", MaxAge: 60 * time.Second, } @@ -564,6 +565,7 @@ func TestPath_OIDCCreate(t *testing.T) { "bar": "baz", }, AllowedRedirectURIs: []string{"https://example.com", "http://localhost:8250"}, + CallbackMode: "client", ClaimMappings: map[string]string{ "foo": "a", "bar": "b", @@ -770,6 +772,7 @@ func TestPath_Read(t *testing.T) { "bound_subject": "testsub", "bound_audiences": []string{"vault"}, "allowed_redirect_uris": []string{"http://127.0.0.1"}, + "callback_mode": "client", "oidc_scopes": []string{"email", "profile"}, "user_claim": "user", "user_claim_json_pointer": false, From 1db39e62c65296eeffb1eeb35f9dcd12613f27b2 Mon Sep 17 00:00:00 2001 From: Dave Dykstra <2129743+DrDaveD@users.noreply.github.com> Date: Tue, 13 Aug 2024 15:16:08 -0500 Subject: [PATCH 2/2] Add device flow Signed-off-by: Dave Dykstra <2129743+DrDaveD@users.noreply.github.com> --- cli.go | 23 +++-- path_config.go | 91 +++++++++++++++++ path_oidc.go | 252 ++++++++++++++++++++++++++++++++++++++++------ path_oidc_test.go | 102 ++++++++++++------- path_role.go | 19 +++- 5 files changed, 409 insertions(+), 78 deletions(-) diff --git a/cli.go b/cli.go index d84a6c0d..15bc04d0 100644 --- a/cli.go +++ b/cli.go @@ -151,18 +151,20 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro var pollInterval string var interval int var state string + var userCode string var listener net.Listener if secret != nil { pollInterval, _ = secret.Data["poll_interval"].(string) state, _ = secret.Data["state"].(string) + userCode, _ = secret.Data["user_code"].(string) } - if callbackMode == "direct" { + if callbackMode != "client" { if state == "" { - return nil, errors.New("no state returned in direct callback mode") + return nil, errors.New("no state returned in " + callbackMode + " callback mode") } if pollInterval == "" { - return nil, errors.New("no poll_interval returned in direct callback mode") + return nil, errors.New("no poll_interval returned in " + callbackMode + " callback mode") } interval, err = strconv.Atoi(pollInterval) if err != nil { @@ -199,7 +201,11 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro } fmt.Fprintf(os.Stderr, "Waiting for OIDC authentication to complete...\n") - if callbackMode == "direct" { + if userCode != "" { + fmt.Fprintf(os.Stderr, "When prompted, enter code %s\n\n", userCode) + } + + if callbackMode != "client" { data := map[string]interface{}{ "state": state, "client_nonce": clientNonce, @@ -212,7 +218,9 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro if err == nil { return secret, nil } - if !strings.HasSuffix(err.Error(), "authorization_pending") { + if strings.HasSuffix(err.Error(), "slow_down") { + interval *= 2 + } else if !strings.HasSuffix(err.Error(), "authorization_pending") { return nil, err } // authorization is pending, try again @@ -376,8 +384,9 @@ Configuration: Vault role of type "OIDC" to use for authentication. %s= - Mode of callback: "direct" for direct connection to Vault or "client" - for connection to command line client (default: client). + Mode of callback: "direct" for direct connection to Vault, "client" + for connection to command line client, or "device" for device flow + which has no callback (default: client). %s= Optional address to bind the OIDC callback listener to in client callback diff --git a/path_config.go b/path_config.go index 6d137bac..38ae1bfb 100644 --- a/path_config.go +++ b/path_config.go @@ -9,9 +9,12 @@ import ( "crypto/tls" "crypto/x509" "encoding/asn1" + "encoding/json" "errors" "fmt" + "io/ioutil" "net/http" + "net/url" "strings" "github.com/hashicorp/cap/jwt" @@ -174,6 +177,91 @@ func (b *jwtAuthBackend) config(ctx context.Context, s logical.Storage) (*jwtCon return config, nil } +func contactIssuer(ctx context.Context, uri string, data *url.Values, ignoreBad bool) ([]byte, error) { + var req *http.Request + var err error + if data == nil { + req, err = http.NewRequest("GET", uri, nil) + } else { + req, err = http.NewRequest("POST", uri, strings.NewReader(data.Encode())) + } + if err != nil { + return nil, err + } + if data != nil { + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + } + + client, ok := ctx.Value(oauth2.HTTPClient).(*http.Client) + if !ok { + client = http.DefaultClient + } + resp, err := client.Do(req.WithContext(ctx)) + if err != nil { + return nil, nil + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, nil + } + + if resp.StatusCode != http.StatusOK && (!ignoreBad || resp.StatusCode != http.StatusBadRequest) { + return nil, fmt.Errorf("%s: %s", resp.Status, body) + } + + return body, nil +} + +// Discover the device_authorization_endpoint URL and store it in the config +// This should be in coreos/go-oidc but they don't yet support device flow +// At the same time, look up token_endpoint and store it as well +// Returns nil on success, otherwise returns an error +func (b *jwtAuthBackend) configDeviceAuthURL(ctx context.Context, s logical.Storage) error { + config, err := b.config(ctx, s) + if err != nil { + return err + } + + b.l.Lock() + defer b.l.Unlock() + + if config.OIDCDeviceAuthURL != "" { + if config.OIDCDeviceAuthURL == "N/A" { + return fmt.Errorf("no device auth endpoint url discovered") + } + return nil + } + + caCtx, err := b.createCAContext(b.providerCtx, config.OIDCDiscoveryCAPEM) + if err != nil { + return errwrap.Wrapf("error creating context for device auth: {{err}}", err) + } + + issuer := config.OIDCDiscoveryURL + + wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration" + body, err := contactIssuer(caCtx, wellKnown, nil, false) + if err != nil { + return errwrap.Wrapf("error reading issuer config: {{err}}", err) + } + + var daj struct { + DeviceAuthURL string `json:"device_authorization_endpoint"` + TokenURL string `json:"token_endpoint"` + } + err = json.Unmarshal(body, &daj) + if err != nil || daj.DeviceAuthURL == "" { + b.cachedConfig.OIDCDeviceAuthURL = "N/A" + return fmt.Errorf("no device auth endpoint url discovered") + } + + b.cachedConfig.OIDCDeviceAuthURL = daj.DeviceAuthURL + b.cachedConfig.OIDCTokenURL = daj.TokenURL + return nil +} + func (b *jwtAuthBackend) pathConfigRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { config, err := b.config(ctx, req.Storage) if err != nil { @@ -502,6 +590,9 @@ type jwtConfig struct { UnsupportedCriticalCertExtensions []string `json:"unsupported_critical_cert_extensions"` ParsedJWTPubKeys []crypto.PublicKey `json:"-"` + // These are looked up from OIDCDiscoveryURL when needed + OIDCDeviceAuthURL string `json:"-"` + OIDCTokenURL string `json:"-"` } const ( diff --git a/path_oidc.go b/path_oidc.go index 6f7304d9..5f81ece8 100644 --- a/path_oidc.go +++ b/path_oidc.go @@ -54,6 +54,9 @@ type oidcRequest struct { // this is for storing the response in direct callback mode auth *logical.Auth + + // the device flow code + deviceCode string } func pathOIDC(b *jwtAuthBackend) []*framework.Path { @@ -146,7 +149,7 @@ func pathOIDC(b *jwtAuthBackend) []*framework.Path { }, "redirect_uri": { Type: framework.TypeString, - Description: "The OAuth redirect_uri to use in the authorization URL.", + Description: "The OAuth redirect_uri to use in the authorization URL. Not needed with device flow.", }, "client_nonce": { Type: framework.TypeString, @@ -241,6 +244,13 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, return logical.ErrorResponse(errLoginFailed + " Expired or missing OAuth state."), nil } + deleteRequest := true + defer func() { + if deleteRequest { + b.deleteOIDCRequest(stateID) + } + }() + roleName := oidcReq.rolename role, err := b.role(ctx, req.Storage, roleName) if err != nil { @@ -248,17 +258,14 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, return nil, err } if role == nil { - b.deleteOIDCRequest(stateID) return logical.ErrorResponse(errLoginFailed + " Role could not be found"), nil } useHttp := false if role.CallbackMode == callbackModeDirect { useHttp = true - } - if !useHttp { - // state is only accessed once when not using direct callback - b.deleteOIDCRequest(stateID) + // save request for poll + deleteRequest = false } errorDescription := d.Get("error_description").(string) @@ -290,8 +297,13 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, return nil, errwrap.Wrapf("error getting provider for login operation: {{err}}", err) } - var rawToken oidc.IDToken + oidcCtx, err := b.createCAContext(ctx, config.OIDCDiscoveryCAPEM) + if err != nil { + return nil, errwrap.Wrapf("error preparing context for login operation: {{err}}", err) + } + var token *oidc.Tk + var tokenSource oauth2.TokenSource code := d.Get("code").(string) if code == noCode { @@ -304,10 +316,15 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, } // Verify the ID token received from the authentication response. - rawToken = oidc.IDToken(oidcReq.idToken) + rawToken := oidc.IDToken(oidcReq.idToken) if _, err := provider.VerifyIDToken(ctx, rawToken, oidcReq); err != nil { return logical.ErrorResponse("%s %s", errTokenVerification, err.Error()), nil } + + token, err = oidc.NewToken(rawToken, nil) + if err != nil { + return nil, errwrap.Wrapf("error creating oidc token: {{err}}", err) + } } else { // Exchange the authorization code for an ID token and access token. // ID token verification takes place in provider.Exchange. @@ -316,13 +333,19 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, return loginFailedResponse(useHttp, fmt.Sprintf("Error exchanging oidc code: %q.", err.Error())), nil } - rawToken = token.IDToken() + tokenSource = token.StaticTokenSource() } + return b.processToken(ctx, config, oidcCtx, provider, roleName, role, token, tokenSource, stateID, oidcReq, useHttp) +} + +// Continue processing a token after it has been received from the +// OIDC provider from either code or device authorization flows +func (b *jwtAuthBackend) processToken(ctx context.Context, config *jwtConfig, oidcCtx context.Context, provider *oidc.Provider, roleName string, role *jwtRole, token *oidc.Tk, tokenSource oauth2.TokenSource, stateID string, oidcReq *oidcRequest, useHttp bool) (*logical.Response, error) { if role.VerboseOIDCLogging { loggedToken := "invalid token format" - parts := strings.Split(string(rawToken), ".") + parts := strings.Split(string(token.IDToken()), ".") if len(parts) == 3 { // strip signature from logged token loggedToken = fmt.Sprintf("%s.%s.xxxxxxxxxxx", parts[0], parts[1]) @@ -333,10 +356,16 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, // Parse claims from the ID token payload. var allClaims map[string]interface{} - if err := rawToken.Claims(&allClaims); err != nil { + if err := token.IDToken().Claims(&allClaims); err != nil { return nil, err } - delete(allClaims, "nonce") + + if claimNonce, ok := allClaims["nonce"]; ok { + if oidcReq != nil && claimNonce != oidcReq.Nonce() { + return loginFailedResponse(useHttp, "invalid ID token nonce."), nil + } + delete(allClaims, "nonce") + } // Get the subject claim for bound subject and user info validation var subject string @@ -348,15 +377,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, return loginFailedResponse(useHttp, "sub claim does not match bound subject"), nil } - // Set the token source for the access token if it's available. It will only - // be available for the authorization code flow (oidc_response_types=code). - // The access token will be used for fetching additional user and group info. - var tokenSource oauth2.TokenSource - if token != nil { - tokenSource = token.StaticTokenSource() - } - - // If we have a token, attempt to fetch information from the /userinfo endpoint + // If we have a tokenSource, attempt to fetch information from the /userinfo endpoint // and merge it with the existing claims data. A failure to fetch additional information // from this endpoint will not invalidate the authorization flow. if tokenSource != nil { @@ -428,27 +449,116 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, return resp, nil } +// second half of the client API for direct and device callback modes func (b *jwtAuthBackend) pathPoll(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { stateID := d.Get("state").(string) - oidcReq := b.getOIDCRequest(stateID) if oidcReq == nil { return logical.ErrorResponse(errLoginFailed + " Expired or missing OAuth state."), nil } + deleteRequest := true + defer func() { + if deleteRequest { + b.deleteOIDCRequest(stateID) + } + }() + clientNonce := d.Get("client_nonce").(string) if oidcReq.clientNonce != "" && clientNonce != oidcReq.clientNonce { - b.deleteOIDCRequest(stateID) return logical.ErrorResponse("invalid client_nonce"), nil } + roleName := oidcReq.rolename + role, err := b.role(ctx, req.Storage, roleName) + if err != nil { + return nil, err + } + if role == nil { + return logical.ErrorResponse(errLoginFailed + " Role could not be found"), nil + } + + if role.CallbackMode == callbackModeDevice { + config, err := b.config(ctx, req.Storage) + if err != nil { + return nil, err + } + if config == nil { + return logical.ErrorResponse(errLoginFailed + " Could not load configuration"), nil + } + + caCtx, err := b.createCAContext(ctx, config.OIDCDiscoveryCAPEM) + if err != nil { + return nil, err + } + provider, err := b.getProvider(config) + if err != nil { + return nil, errwrap.Wrapf("error getting provider for poll operation: {{err}}", err) + } + + values := url.Values{ + "client_id": {config.OIDCClientID}, + "client_secret": {config.OIDCClientSecret}, + "device_code": {oidcReq.deviceCode}, + "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, + } + body, err := contactIssuer(caCtx, config.OIDCTokenURL, &values, true) + if err != nil { + return nil, errwrap.Wrapf("error polling for device authorization: {{err}}", err) + } + + var tokenOrError struct { + *oauth2.Token + Error string `json:"error,omitempty"` + } + err = json.Unmarshal(body, &tokenOrError) + if err != nil { + return nil, fmt.Errorf("error decoding issuer response while polling for token: %v; response: %v", err, string(body)) + } + + if tokenOrError.Error != "" { + if tokenOrError.Error == "authorization_pending" || tokenOrError.Error == "slow_down" { + // save request for another poll + deleteRequest = false + return logical.ErrorResponse(tokenOrError.Error), nil + } + return logical.ErrorResponse("authorization failed: %v", tokenOrError.Error), nil + } + + extra := make(map[string]interface{}) + err = json.Unmarshal(body, &extra) + if err != nil { + // already been unmarshalled once, unlikely + return nil, err + } + oauth2Token := tokenOrError.Token.WithExtra(extra) + + // idToken, ok := oauth2Token.Extra("id_token").(oidc.IDToken) + rawToken, ok := oauth2Token.Extra("id_token").(string) + if !ok { + return logical.ErrorResponse(errTokenVerification + " No id_token found in response."), nil + } + idToken := oidc.IDToken(rawToken) + token, err := oidc.NewToken(idToken, tokenOrError.Token) + if err != nil { + return nil, errwrap.Wrapf("error creating oidc token: {{err}}", err) + } + + return b.processToken(ctx, config, caCtx, provider, roleName, role, token, oauth2.StaticTokenSource(oauth2Token), "", nil, false) + } + + // else it's the direct callback mode + if oidcReq.auth == nil { + // save request for another poll + deleteRequest = false + } + if oidcReq.auth == nil { // Return the same response as oauth 2.0 device flow in RFC8628 return logical.ErrorResponse("authorization_pending"), nil } - b.deleteOIDCRequest(stateID) resp := &logical.Response{ Auth: oidcReq.auth, } @@ -490,9 +600,6 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f } redirectURI := d.Get("redirect_uri").(string) - if redirectURI == "" { - return logical.ErrorResponse("missing redirect_uri"), nil - } role, err := b.role(ctx, req.Storage, roleName) if err != nil { @@ -503,10 +610,88 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f } clientNonce := d.Get("client_nonce").(string) - if clientNonce == "" && role.CallbackMode == callbackModeDirect { + if clientNonce == "" && + (role.CallbackMode == callbackModeDirect || + role.CallbackMode == callbackModeDevice) { return logical.ErrorResponse("missing client_nonce"), nil } + if role.CallbackMode == callbackModeDevice { + // start a device flow + caCtx, err := b.createCAContext(ctx, config.OIDCDiscoveryCAPEM) + if err != nil { + return nil, err + } + + // Discover the device url endpoint if not already known + // This adds it to the cached config + err = b.configDeviceAuthURL(ctx, req.Storage) + if err != nil { + return nil, err + } + + // "openid" is a required scope for OpenID Connect flows + scopes := append([]string{"openid"}, role.OIDCScopes...) + + values := url.Values{ + "client_id": {config.OIDCClientID}, + "client_secret": {config.OIDCClientSecret}, + "scope": {strings.Join(scopes, " ")}, + } + body, err := contactIssuer(caCtx, config.OIDCDeviceAuthURL, &values, false) + if err != nil { + return nil, errwrap.Wrapf("error authorizing device: {{err}}", err) + } + + var deviceCode struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete"` + // Google and other old implementations use url instead of uri + VerificationURL string `json:"verification_url"` + VerificationURLComplete string `json:"verification_url_complete"` + Interval int `json:"interval"` + } + err = json.Unmarshal(body, &deviceCode) + if err != nil { + return nil, fmt.Errorf("error decoding issuer response to device auth: %v; response: %v", err, string(body)) + } + // currently hashicorp/cap/oidc.NewRequest requires + // redirectURL to be non-empty so throw in place holder + oidcReq, err := b.createOIDCRequest(config, role, roleName, "-", deviceCode.DeviceCode, clientNonce) + if err != nil { + logger.Warn("error generating OAuth state", "error", err) + return resp, nil + } + + if deviceCode.VerificationURIComplete != "" { + resp.Data["auth_url"] = deviceCode.VerificationURIComplete + } else if deviceCode.VerificationURLComplete != "" { + resp.Data["auth_url"] = deviceCode.VerificationURLComplete + } else { + if deviceCode.VerificationURI != "" { + resp.Data["auth_url"] = deviceCode.VerificationURI + } else { + resp.Data["auth_url"] = deviceCode.VerificationURL + } + resp.Data["user_code"] = deviceCode.UserCode + } + resp.Data["state"] = oidcReq.State() + interval := 5 + if role.PollInterval != 0 { + interval = role.PollInterval + } else if deviceCode.Interval != 0 { + interval = deviceCode.Interval + } + resp.Data["poll_interval"] = fmt.Sprintf("%d", interval) + return resp, nil + } + + if redirectURI == "" { + return logical.ErrorResponse("missing redirect_uri"), nil + } + // If namespace will be passed around in oidcReq, and it has been provided as // a redirectURI query parameter, remove it from redirectURI, and append it // to the oidcReq (later in this function) @@ -545,7 +730,7 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f return resp, nil } - oidcReq, err := b.createOIDCRequest(config, role, roleName, redirectURI, clientNonce) + oidcReq, err := b.createOIDCRequest(config, role, roleName, redirectURI, "", clientNonce) if err != nil { logger.Warn("error generating OAuth state", "error", err) return resp, nil @@ -566,7 +751,11 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f resp.Data["auth_url"] = urlStr if role.CallbackMode == callbackModeDirect { resp.Data["state"] = oidcReq.State() - resp.Data["poll_interval"] = "5" + interval := 5 + if role.PollInterval != 0 { + interval = role.PollInterval + } + resp.Data["poll_interval"] = fmt.Sprintf("%d", interval) } return resp, nil @@ -574,7 +763,7 @@ func (b *jwtAuthBackend) authURL(ctx context.Context, req *logical.Request, d *f // createOIDCRequest makes an expiring request object, associated with a random state ID // that is passed throughout the OAuth process. A nonce is also included in the auth process. -func (b *jwtAuthBackend) createOIDCRequest(config *jwtConfig, role *jwtRole, rolename, redirectURI, clientNonce string) (*oidcRequest, error) { +func (b *jwtAuthBackend) createOIDCRequest(config *jwtConfig, role *jwtRole, rolename, redirectURI, deviceCode string, clientNonce string) (*oidcRequest, error) { options := []oidc.Option{ oidc.WithAudiences(role.BoundAudiences...), oidc.WithScopes(role.OIDCScopes...), @@ -604,6 +793,7 @@ func (b *jwtAuthBackend) createOIDCRequest(config *jwtConfig, role *jwtRole, rol Request: request, rolename: rolename, clientNonce: clientNonce, + deviceCode: deviceCode, } b.oidcRequests.SetDefault(request.State(), oidcReq) diff --git a/path_oidc_test.go b/path_oidc_test.go index c78c7f49..2ab7f12c 100644 --- a/path_oidc_test.go +++ b/path_oidc_test.go @@ -774,14 +774,16 @@ func TestOIDC_Callback(t *testing.T) { t.Run("successful login", func(t *testing.T) { // run test with and without bound_cidrs configured // and with and without direct callback mode - for i := 1; i <= 3; i++ { + for i := 1; i <= 4; i++ { var useBoundCIDRs bool - var callbackMode string + callbackMode := "client" if i == 2 { useBoundCIDRs = true } else if i == 3 { callbackMode = "direct" + } else if i == 4 { + callbackMode = "device" } b, storage, s := getBackendAndServer(t, useBoundCIDRs, callbackMode) @@ -789,6 +791,9 @@ func TestOIDC_Callback(t *testing.T) { clientNonce := "456" + // set mock provider's expected code + s.code = "abc" + // get auth_url data := map[string]interface{}{ "role": "test", @@ -807,42 +812,45 @@ func TestOIDC_Callback(t *testing.T) { t.Fatalf("err:%v resp:%#v\n", err, resp) } - authURL := resp.Data["auth_url"].(string) - - state := getQueryParam(t, authURL, "state") - nonce := getQueryParam(t, authURL, "nonce") + var state string - // set provider claims that will be returned by the mock server - s.customClaims = sampleClaims(nonce) + if callbackMode == "device" { + state = resp.Data["state"].(string) + s.customClaims = sampleClaims("") + } else { + authURL := resp.Data["auth_url"].(string) + state = getQueryParam(t, authURL, "state") + nonce := getQueryParam(t, authURL, "nonce") - // set mock provider's expected code - s.code = "abc" + // set provider claims that will be returned by the mock server + s.customClaims = sampleClaims(nonce) - // save PKCE challenge - s.codeChallenge = getQueryParam(t, authURL, "code_challenge") + // save PKCE challenge + s.codeChallenge = getQueryParam(t, authURL, "code_challenge") - // invoke the callback, which will try to exchange the code - // with the mock provider. - req = &logical.Request{ - Operation: logical.ReadOperation, - Path: "oidc/callback", - Storage: storage, - Data: map[string]interface{}{ - "state": state, - "code": "abc", - "client_nonce": clientNonce, - }, - Connection: &logical.Connection{ - RemoteAddr: "127.0.0.42", - }, - } + // invoke the callback, which will try to exchange the code + // with the mock provider. + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "oidc/callback", + Storage: storage, + Data: map[string]interface{}{ + "state": state, + "code": "abc", + "client_nonce": clientNonce, + }, + Connection: &logical.Connection{ + RemoteAddr: "127.0.0.42", + }, + } - resp, err = b.HandleRequest(context.Background(), req) - if err != nil { - t.Fatal(err) + resp, err = b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } } - if callbackMode == "direct" { + if callbackMode != "client" { req = &logical.Request{ Operation: logical.UpdateOperation, Path: "oidc/poll", @@ -1466,6 +1474,7 @@ func (o *oidcProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) { { "issuer": "%s", "authorization_endpoint": "%s/auth", + "device_authorization_endpoint": "%s/device", "token_endpoint": "%s/token", "jwks_uri": "%s/certs", "userinfo_endpoint": "%s/userinfo" @@ -1477,21 +1486,38 @@ func (o *oidcProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(404) case "/certs_invalid": w.Write([]byte("It's not a keyset!")) + case "/device": + values := map[string]interface{}{ + "device_code": o.code, + } + data, err := json.Marshal(values) + if err != nil { + o.t.Fatal(err) + } + w.Write(data) case "/token": - code := r.FormValue("code") - codeVerifier := r.FormValue("code_verifier") + var code string + grant_type := r.FormValue("grant_type") + if grant_type == "urn:ietf:params:oauth:grant-type:device_code" { + code = r.FormValue("device_code") + } else { + code = r.FormValue("code") + } if code != o.code { w.WriteHeader(401) break } - sum := sha256.Sum256([]byte(codeVerifier)) - computedChallenge := base64.RawURLEncoding.EncodeToString(sum[:]) + if o.codeChallenge != "" { + codeVerifier := r.FormValue("code_verifier") + sum := sha256.Sum256([]byte(codeVerifier)) + computedChallenge := base64.RawURLEncoding.EncodeToString(sum[:]) - if computedChallenge != o.codeChallenge { - w.WriteHeader(401) - break + if computedChallenge != o.codeChallenge { + w.WriteHeader(401) + break + } } stdClaims := jwt.Claims{ diff --git a/path_role.go b/path_role.go index 6e1b33e6..8c5f56c5 100644 --- a/path_role.go +++ b/path_role.go @@ -26,6 +26,7 @@ const ( boundClaimsTypeGlob = "glob" callbackModeDirect = "direct" callbackModeClient = "client" + callbackModeDevice = "device" ) func pathRoleList(b *jwtAuthBackend) *framework.Path { @@ -158,9 +159,14 @@ for referencing claims.`, }, "callback_mode": { Type: framework.TypeString, - Description: `OIDC callback mode from Authorization Server: allowed values are 'direct' to Vault or 'client', default 'client'`, + Description: `OIDC callback mode from Authorization Server: allowed values are 'device' for device flow, 'direct' to Vault, or 'client', default 'client'`, Default: callbackModeClient, }, + "poll_interval": { + Type: framework.TypeInt, + Description: `poll interval in seconds for device and direct flows, default value from Authorization Server for device flow, or '5'`, + // don't set Default here because server may set a default + }, "verbose_oidc_logging": { Type: framework.TypeBool, Description: `Log received OIDC tokens and claims when debug-level logging is active. @@ -230,6 +236,7 @@ type jwtRole struct { OIDCScopes []string `json:"oidc_scopes"` AllowedRedirectURIs []string `json:"allowed_redirect_uris"` CallbackMode string `json:"callback_mode"` + PollInterval int `json:"poll_interval"` VerboseOIDCLogging bool `json:"verbose_oidc_logging"` MaxAge time.Duration `json:"max_age"` UserClaimJSONPointer bool `json:"user_claim_json_pointer"` @@ -346,6 +353,10 @@ func (b *jwtAuthBackend) pathRoleRead(ctx context.Context, req *logical.Request, role.PopulateTokenData(d) + if role.PollInterval > 0 { + d["poll_interval"] = role.PollInterval + } + if len(role.Policies) > 0 { d["policies"] = d["token_policies"] } @@ -564,9 +575,13 @@ func (b *jwtAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical. role.AllowedRedirectURIs = allowedRedirectURIs.([]string) } + if pollInterval, ok := data.GetOk("poll_interval"); ok { + role.PollInterval = pollInterval.(int) + } + callbackMode := data.Get("callback_mode").(string) switch callbackMode { - case callbackModeDirect, callbackModeClient: + case callbackModeDevice, callbackModeDirect, callbackModeClient: role.CallbackMode = callbackMode default: return logical.ErrorResponse("invalid 'callback_mode': %s", callbackMode), nil