diff --git a/common/component/kafka/auth_test.go b/common/component/kafka/auth_test.go index 41afc5f634..10c0bc1a15 100644 --- a/common/component/kafka/auth_test.go +++ b/common/component/kafka/auth_test.go @@ -18,14 +18,20 @@ import ( "crypto/rsa" "crypto/x509" "crypto/x509/pkix" + "encoding/base64" + "encoding/json" "encoding/pem" "fmt" "math/big" + "net/http" + "net/http/httptest" + "strings" "testing" "time" "github.com/IBM/sarama" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" ) func getAuthBaseMetadata() map[string]string { @@ -122,4 +128,53 @@ func TestAuth(t *testing.T) { require.False(t, mockConfig.Net.TLS.Enable) require.Nil(t, mockConfig.Net.TLS.Config) }) + + t.Run("oidc private key jwt uses flattened audience", func(t *testing.T) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + certPEM, _, err := createTestCert() + require.NoError(t, err) + + var receivedAssertion string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.NoError(t, r.ParseForm()) + receivedAssertion = r.FormValue("client_assertion") + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "test-token", + "expires_in": 3600, + }) + })) + defer server.Close() + + ts := &OAuthTokenSourcePrivateKeyJWT{ + TokenEndpoint: oauth2.Endpoint{TokenURL: server.URL}, + ClientID: "test-client", + ClientAssertionCert: string(certPEM), + ClientAssertionKey: string(keyPEM), + } + + _, err = ts.Token() + require.NoError(t, err) + require.NotEmpty(t, receivedAssertion) + + parts := strings.Split(receivedAssertion, ".") + require.Len(t, parts, 3, "JWT should have 3 parts") + + decodedPayload, err := base64.RawURLEncoding.DecodeString(parts[1]) + require.NoError(t, err) + + var rawClaims map[string]interface{} + err = json.Unmarshal(decodedPayload, &rawClaims) + require.NoError(t, err) + + audValue := rawClaims["aud"] + require.IsType(t, "", audValue) + }) } diff --git a/common/component/kafka/sasl_oauthbearer_private_key_jwt.go b/common/component/kafka/sasl_oauthbearer_private_key_jwt.go index f7480c30d0..1c3cd4091e 100644 --- a/common/component/kafka/sasl_oauthbearer_private_key_jwt.go +++ b/common/component/kafka/sasl_oauthbearer_private_key_jwt.go @@ -169,6 +169,9 @@ func (ts *OAuthTokenSourcePrivateKeyJWT) Token() (*sarama.AccessToken, error) { return nil, fmt.Errorf("failed to build token: %w", err) } + // Some IdPs require the audience to be set as a single string + token.Options().Enable(jwt.FlattenAudience) + var signOptions []jwt.Option if ts.Kid != "" { headers := jws.NewHeaders()