Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions common/component/kafka/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,20 @@ import (
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/IBM/sarama"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)

func getAuthBaseMetadata() map[string]string {
Expand Down Expand Up @@ -122,4 +128,53 @@ func TestAuth(t *testing.T) {
require.False(t, mockConfig.Net.TLS.Enable)
require.Nil(t, mockConfig.Net.TLS.Config)
})

t.Run("oidc private key jwt uses flattened audience", func(t *testing.T) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
keyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
certPEM, _, err := createTestCert()
require.NoError(t, err)

var receivedAssertion string

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, r.ParseForm())
receivedAssertion = r.FormValue("client_assertion")

w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"access_token": "test-token",
"expires_in": 3600,
})
}))
defer server.Close()

ts := &OAuthTokenSourcePrivateKeyJWT{
TokenEndpoint: oauth2.Endpoint{TokenURL: server.URL},
ClientID: "test-client",
ClientAssertionCert: string(certPEM),
ClientAssertionKey: string(keyPEM),
}

_, err = ts.Token()
require.NoError(t, err)
require.NotEmpty(t, receivedAssertion)

parts := strings.Split(receivedAssertion, ".")
require.Len(t, parts, 3, "JWT should have 3 parts")

decodedPayload, err := base64.RawURLEncoding.DecodeString(parts[1])
require.NoError(t, err)

var rawClaims map[string]interface{}
err = json.Unmarshal(decodedPayload, &rawClaims)
require.NoError(t, err)

audValue := rawClaims["aud"]
require.IsType(t, "", audValue)
})
}
3 changes: 3 additions & 0 deletions common/component/kafka/sasl_oauthbearer_private_key_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ func (ts *OAuthTokenSourcePrivateKeyJWT) Token() (*sarama.AccessToken, error) {
return nil, fmt.Errorf("failed to build token: %w", err)
}

// Some IdPs require the audience to be set as a single string
token.Options().Enable(jwt.FlattenAudience)

var signOptions []jwt.Option
if ts.Kid != "" {
headers := jws.NewHeaders()
Expand Down
Loading