Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions cmd/drone-acr/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/sirupsen/logrus"

docker "github.com/drone-plugins/drone-docker"
azureutil "github.com/drone-plugins/drone-docker/internal/azure"
)

type subscriptionUrlResponse struct {
Expand Down Expand Up @@ -62,12 +63,14 @@ func main() {
password = getenv("SERVICE_PRINCIPAL_CLIENT_SECRET")

// Service principal credentials
clientId = getenv("CLIENT_ID")
clientSecret = getenv("CLIENT_SECRET")
clientCert = getenv("CLIENT_CERTIFICATE")
tenantId = getenv("TENANT_ID")
subscriptionId = getenv("SUBSCRIPTION_ID")
publicUrl = getenv("DAEMON_REGISTRY")
clientId = getenv("CLIENT_ID", "AZURE_CLIENT_ID", "AZURE_APP_ID", "PLUGIN_CLIENT_ID")
clientSecret = getenv("CLIENT_SECRET", "PLUGIN_CLIENT_SECRET")
clientCert = getenv("CLIENT_CERTIFICATE", "PLUGIN_CLIENT_CERTIFICATE")
tenantId = getenv("TENANT_ID", "AZURE_TENANT_ID", "PLUGIN_TENANT_ID")
subscriptionId = getenv("SUBSCRIPTION_ID", "PLUGIN_SUBSCRIPTION_ID")
publicUrl = getenv("DAEMON_REGISTRY", "PLUGIN_DAEMON_REGISTRY")
authorityHost = getenv("AZURE_AUTHORITY_HOST", "PLUGIN_AZURE_AUTHORITY_HOST")
idToken = getenv("PLUGIN_OIDC_TOKEN_ID")
)

// default registry value
Expand All @@ -80,9 +83,29 @@ func main() {
// docker login credentials are not provided
var err error
username = defaultUsername
password, publicUrl, err = getAuth(clientId, clientSecret, clientCert, tenantId, subscriptionId, registry)
if err != nil {
logrus.Fatal(err)
if idToken != "" && clientId != "" && tenantId != "" {
logrus.Debug("Using OIDC authentication flow")
var aadToken string
aadToken, err = azureutil.GetAADAccessTokenViaClientAssertion(context.Background(), tenantId, clientId, idToken, authorityHost)
if err != nil {
logrus.Fatal(err)
}
var p string
p, err = getPublicUrl(aadToken, registry, subscriptionId)
if err == nil {
publicUrl = p
} else {
fmt.Fprintf(os.Stderr, "failed to get public url with error: %s\n", err)
}
password, err = fetchACRToken(tenantId, aadToken, registry)
if err != nil {
logrus.Fatal(err)
}
} else {
password, publicUrl, err = getAuth(clientId, clientSecret, clientCert, tenantId, subscriptionId, registry)
if err != nil {
logrus.Fatal(err)
}
}
}

Expand Down
32 changes: 32 additions & 0 deletions cmd/drone-acr/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package main

import (
"os"
"testing"
)

func TestGetAuthInputValidation(t *testing.T) {
// missing tenant
if _, _, err := getAuth("client", "secret", "", "", "sub", "registry.azurecr.io"); err == nil {
t.Fatalf("expected error for missing tenantId")
}
// missing clientId
if _, _, err := getAuth("", "secret", "", "tenant", "sub", "registry.azurecr.io"); err == nil {
t.Fatalf("expected error for missing clientId")
}
// missing both secret and cert
if _, _, err := getAuth("client", "", "", "tenant", "sub", "registry.azurecr.io"); err == nil {
t.Fatalf("expected error for missing credentials")
}
}

func TestGetenvAuthorityHost(t *testing.T) {
os.Setenv("AZURE_AUTHORITY_HOST", "https://login.microsoftonline.us")
defer os.Unsetenv("AZURE_AUTHORITY_HOST")

got := getenv("AZURE_AUTHORITY_HOST")
if got != "https://login.microsoftonline.us" {
t.Fatalf("expected AZURE_AUTHORITY_HOST to be returned, got %q", got)
}
}

75 changes: 75 additions & 0 deletions internal/azure/tokenutil.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package azure

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)

const DefaultResource = "https://management.azure.com/"
const defaultAuthorityHost = "https://login.microsoftonline.com"
const defaultHTTPTimeout = 30 * time.Second

// GetAADAccessTokenViaClientAssertion exchanges an external OIDC ID token for an Azure AD access token

func GetAADAccessTokenViaClientAssertion(ctx context.Context, tenantID, clientID, oidcToken, authorityHost string) (string, error) {
resource := DefaultResource

form := url.Values{
"client_id": {clientID},
"scope": {resource + ".default"},
"grant_type": {"client_credentials"},
"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
"client_assertion": {oidcToken},
}

base := authorityHost
if strings.TrimSpace(base) == "" {
base = defaultAuthorityHost
}
base = strings.TrimRight(base, "/")
endpoint := fmt.Sprintf("%s/%s/oauth2/v2.0/token", base, tenantID)

client := &http.Client{Timeout: defaultHTTPTimeout}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")

resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
var aadErr struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}
limited := io.LimitedReader{R: resp.Body, N: 4096}
_ = json.NewDecoder(&limited).Decode(&aadErr)
if aadErr.Error != "" {
return "", fmt.Errorf("AAD token request failed: status=%d, error=%s", resp.StatusCode, aadErr.Error)
}
return "", fmt.Errorf("AAD token request failed: status=%d", resp.StatusCode)
}
var payload struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
}
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return "", err
}
if payload.AccessToken == "" {
return "", fmt.Errorf("AAD token response missing access_token")
}
return payload.AccessToken, nil
}
104 changes: 104 additions & 0 deletions internal/azure/tokenutil_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package azure

import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
)

func TestGetAADAccessTokenViaClientAssertion_Success(t *testing.T) {

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", r.Method)
}
if ct := r.Header.Get("Content-Type"); !strings.Contains(ct, "application/x-www-form-urlencoded") {
t.Fatalf("expected form content-type, got %s", ct)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("failed parsing form: %v", err)
}
assertEq(t, r.Form.Get("client_id"), "client")
assertEq(t, r.Form.Get("grant_type"), "client_credentials")
assertEq(t, r.Form.Get("client_assertion_type"), "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
assertEq(t, r.Form.Get("client_assertion"), "idtoken")
assertEq(t, r.Form.Get("scope"), DefaultResource+".default")

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"access_token":"AT","token_type":"Bearer","expires_in":3600}`))
}))
defer ts.Close()

tok, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if tok != "AT" {
t.Fatalf("expected access token AT, got %q", tok)
}
}

func TestGetAADAccessTokenViaClientAssertion_400WithErrorField(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(`{"error":"invalid_client","error_description":"bad"}`))
}))
defer ts.Close()

_, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
if err == nil || !strings.Contains(err.Error(), "status=400") || !strings.Contains(err.Error(), "invalid_client") {
t.Fatalf("expected 400 with invalid_client error, got %v", err)
}
}

func TestGetAADAccessTokenViaClientAssertion_400WithoutErrorField(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("{}"))
}))
defer ts.Close()

_, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
if err == nil || !strings.Contains(err.Error(), "status=400") {
t.Fatalf("expected 400 error, got %v", err)
}
}

func TestGetAADAccessTokenViaClientAssertion_MalformedJSON(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("not-json"))
}))
defer ts.Close()

_, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
if err == nil {
t.Fatalf("expected JSON decode error, got nil")
}
}

func TestGetAADAccessTokenViaClientAssertion_MissingAccessToken(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"token_type":"Bearer","expires_in":3600}`))
}))
defer ts.Close()

_, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
if err == nil || !strings.Contains(err.Error(), "missing access_token") {
t.Fatalf("expected missing access_token error, got %v", err)
}
}

func assertEq(t *testing.T, got, want string) {
t.Helper()
if got != want {
t.Fatalf("mismatch: got=%q want=%q", got, want)
}
}