diff --git a/cmd/drone-acr/main.go b/cmd/drone-acr/main.go index 3205912a..4688f8dd 100644 --- a/cmd/drone-acr/main.go +++ b/cmd/drone-acr/main.go @@ -20,6 +20,7 @@ import ( "github.com/sirupsen/logrus" docker "github.com/drone-plugins/drone-docker" + azureutil "github.com/drone-plugins/drone-docker/internal/azure" ) type subscriptionUrlResponse struct { @@ -62,12 +63,14 @@ func main() { password = getenv("SERVICE_PRINCIPAL_CLIENT_SECRET") // Service principal credentials - clientId = getenv("CLIENT_ID") - clientSecret = getenv("CLIENT_SECRET") - clientCert = getenv("CLIENT_CERTIFICATE") - tenantId = getenv("TENANT_ID") - subscriptionId = getenv("SUBSCRIPTION_ID") - publicUrl = getenv("DAEMON_REGISTRY") + clientId = getenv("CLIENT_ID", "AZURE_CLIENT_ID", "AZURE_APP_ID", "PLUGIN_CLIENT_ID") + clientSecret = getenv("CLIENT_SECRET", "PLUGIN_CLIENT_SECRET") + clientCert = getenv("CLIENT_CERTIFICATE", "PLUGIN_CLIENT_CERTIFICATE") + tenantId = getenv("TENANT_ID", "AZURE_TENANT_ID", "PLUGIN_TENANT_ID") + subscriptionId = getenv("SUBSCRIPTION_ID", "PLUGIN_SUBSCRIPTION_ID") + publicUrl = getenv("DAEMON_REGISTRY", "PLUGIN_DAEMON_REGISTRY") + authorityHost = getenv("AZURE_AUTHORITY_HOST", "PLUGIN_AZURE_AUTHORITY_HOST") + idToken = getenv("PLUGIN_OIDC_TOKEN_ID") ) // default registry value @@ -80,9 +83,29 @@ func main() { // docker login credentials are not provided var err error username = defaultUsername - password, publicUrl, err = getAuth(clientId, clientSecret, clientCert, tenantId, subscriptionId, registry) - if err != nil { - logrus.Fatal(err) + if idToken != "" && clientId != "" && tenantId != "" { + logrus.Debug("Using OIDC authentication flow") + var aadToken string + aadToken, err = azureutil.GetAADAccessTokenViaClientAssertion(context.Background(), tenantId, clientId, idToken, authorityHost) + if err != nil { + logrus.Fatal(err) + } + var p string + p, err = getPublicUrl(aadToken, registry, subscriptionId) + if err == nil { + publicUrl = p + } else { + fmt.Fprintf(os.Stderr, "failed to get public url with error: %s\n", err) + } + password, err = fetchACRToken(tenantId, aadToken, registry) + if err != nil { + logrus.Fatal(err) + } + } else { + password, publicUrl, err = getAuth(clientId, clientSecret, clientCert, tenantId, subscriptionId, registry) + if err != nil { + logrus.Fatal(err) + } } } diff --git a/cmd/drone-acr/main_test.go b/cmd/drone-acr/main_test.go new file mode 100644 index 00000000..7fa62cf8 --- /dev/null +++ b/cmd/drone-acr/main_test.go @@ -0,0 +1,32 @@ +package main + +import ( + "os" + "testing" +) + +func TestGetAuthInputValidation(t *testing.T) { + // missing tenant + if _, _, err := getAuth("client", "secret", "", "", "sub", "registry.azurecr.io"); err == nil { + t.Fatalf("expected error for missing tenantId") + } + // missing clientId + if _, _, err := getAuth("", "secret", "", "tenant", "sub", "registry.azurecr.io"); err == nil { + t.Fatalf("expected error for missing clientId") + } + // missing both secret and cert + if _, _, err := getAuth("client", "", "", "tenant", "sub", "registry.azurecr.io"); err == nil { + t.Fatalf("expected error for missing credentials") + } +} + +func TestGetenvAuthorityHost(t *testing.T) { + os.Setenv("AZURE_AUTHORITY_HOST", "https://login.microsoftonline.us") + defer os.Unsetenv("AZURE_AUTHORITY_HOST") + + got := getenv("AZURE_AUTHORITY_HOST") + if got != "https://login.microsoftonline.us" { + t.Fatalf("expected AZURE_AUTHORITY_HOST to be returned, got %q", got) + } +} + diff --git a/internal/azure/tokenutil.go b/internal/azure/tokenutil.go new file mode 100644 index 00000000..5d9ec822 --- /dev/null +++ b/internal/azure/tokenutil.go @@ -0,0 +1,75 @@ +package azure + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +const DefaultResource = "https://management.azure.com/" +const defaultAuthorityHost = "https://login.microsoftonline.com" +const defaultHTTPTimeout = 30 * time.Second + +// GetAADAccessTokenViaClientAssertion exchanges an external OIDC ID token for an Azure AD access token + +func GetAADAccessTokenViaClientAssertion(ctx context.Context, tenantID, clientID, oidcToken, authorityHost string) (string, error) { + resource := DefaultResource + + form := url.Values{ + "client_id": {clientID}, + "scope": {resource + ".default"}, + "grant_type": {"client_credentials"}, + "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"}, + "client_assertion": {oidcToken}, + } + + base := authorityHost + if strings.TrimSpace(base) == "" { + base = defaultAuthorityHost + } + base = strings.TrimRight(base, "/") + endpoint := fmt.Sprintf("%s/%s/oauth2/v2.0/token", base, tenantID) + + client := &http.Client{Timeout: defaultHTTPTimeout} + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode())) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + var aadErr struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + limited := io.LimitedReader{R: resp.Body, N: 4096} + _ = json.NewDecoder(&limited).Decode(&aadErr) + if aadErr.Error != "" { + return "", fmt.Errorf("AAD token request failed: status=%d, error=%s", resp.StatusCode, aadErr.Error) + } + return "", fmt.Errorf("AAD token request failed: status=%d", resp.StatusCode) + } + var payload struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + } + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return "", err + } + if payload.AccessToken == "" { + return "", fmt.Errorf("AAD token response missing access_token") + } + return payload.AccessToken, nil +} diff --git a/internal/azure/tokenutil_test.go b/internal/azure/tokenutil_test.go new file mode 100644 index 00000000..d87c79d7 --- /dev/null +++ b/internal/azure/tokenutil_test.go @@ -0,0 +1,104 @@ +package azure + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestGetAADAccessTokenViaClientAssertion_Success(t *testing.T) { + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("expected POST, got %s", r.Method) + } + if ct := r.Header.Get("Content-Type"); !strings.Contains(ct, "application/x-www-form-urlencoded") { + t.Fatalf("expected form content-type, got %s", ct) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("failed parsing form: %v", err) + } + assertEq(t, r.Form.Get("client_id"), "client") + assertEq(t, r.Form.Get("grant_type"), "client_credentials") + assertEq(t, r.Form.Get("client_assertion_type"), "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") + assertEq(t, r.Form.Get("client_assertion"), "idtoken") + assertEq(t, r.Form.Get("scope"), DefaultResource+".default") + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"access_token":"AT","token_type":"Bearer","expires_in":3600}`)) + })) + defer ts.Close() + + tok, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tok != "AT" { + t.Fatalf("expected access token AT, got %q", tok) + } +} + +func TestGetAADAccessTokenViaClientAssertion_400WithErrorField(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_client","error_description":"bad"}`)) + })) + defer ts.Close() + + _, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL) + if err == nil || !strings.Contains(err.Error(), "status=400") || !strings.Contains(err.Error(), "invalid_client") { + t.Fatalf("expected 400 with invalid_client error, got %v", err) + } +} + +func TestGetAADAccessTokenViaClientAssertion_400WithoutErrorField(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("{}")) + })) + defer ts.Close() + + _, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL) + if err == nil || !strings.Contains(err.Error(), "status=400") { + t.Fatalf("expected 400 error, got %v", err) + } +} + +func TestGetAADAccessTokenViaClientAssertion_MalformedJSON(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("not-json")) + })) + defer ts.Close() + + _, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL) + if err == nil { + t.Fatalf("expected JSON decode error, got nil") + } +} + +func TestGetAADAccessTokenViaClientAssertion_MissingAccessToken(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"token_type":"Bearer","expires_in":3600}`)) + })) + defer ts.Close() + + _, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL) + if err == nil || !strings.Contains(err.Error(), "missing access_token") { + t.Fatalf("expected missing access_token error, got %v", err) + } +} + +func assertEq(t *testing.T, got, want string) { + t.Helper() + if got != want { + t.Fatalf("mismatch: got=%q want=%q", got, want) + } +} +