diff --git a/internal/envconfig/xds.go b/internal/envconfig/xds.go index b1f883bcac1e..7685d08b54de 100644 --- a/internal/envconfig/xds.go +++ b/internal/envconfig/xds.go @@ -74,4 +74,9 @@ var ( // For more details, see: // https://github.com/grpc/proposal/blob/master/A86-xds-http-connect.md XDSHTTPConnectEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_HTTP_CONNECT", false) + + // XDSBootstrapCallCredsEnabled controls if call credentials can be used in + // xDS bootstrap configuration via the `call_creds` field. For more details, + // see: https://github.com/grpc/proposal/blob/master/A97-xds-jwt-call-creds.md + XDSBootstrapCallCredsEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_BOOTSTRAP_CALL_CREDS", false) ) diff --git a/internal/xds/bootstrap/bootstrap.go b/internal/xds/bootstrap/bootstrap.go index 4278702ec0c7..e718877bde44 100644 --- a/internal/xds/bootstrap/bootstrap.go +++ b/internal/xds/bootstrap/bootstrap.go @@ -31,6 +31,7 @@ import ( "strings" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/envconfig" @@ -83,6 +84,40 @@ func (cc ChannelCreds) String() string { return cc.Type + "-" + string(b) } +// CallCredsConfig contains the call credentials configuration to be used on +// RPCs to the management server. +type CallCredsConfig struct { + // Type contains a name identifying the call credentials type. + Type string `json:"type,omitempty"` + // Config contains the JSON configuration for this call credentials. + Config json.RawMessage `json:"config,omitempty"` +} + +// Equal reports whether cc and other are considered equal. +func (cc CallCredsConfig) Equal(other CallCredsConfig) bool { + return cc.Type == other.Type && bytes.Equal(cc.Config, other.Config) +} + +func (cc CallCredsConfig) String() string { + if cc.Config == nil { + return cc.Type + } + // We do not expect the Marshal call to fail since we wrote to cc.Config. + b, _ := json.Marshal(cc.Config) + return cc.Type + "-" + string(b) +} + +// CallCredsConfigs represents a collection of call credentials configurations. +type CallCredsConfigs []CallCredsConfig + +func (ccs CallCredsConfigs) String() string { + var creds []string + for _, cc := range ccs { + creds = append(creds, cc.String()) + } + return strings.Join(creds, ",") +} + // ServerConfigs represents a collection of server configurations. type ServerConfigs []*ServerConfig @@ -163,16 +198,20 @@ func (a *Authority) Equal(other *Authority) bool { // ServerConfig contains the configuration to connect to a server. type ServerConfig struct { - serverURI string - channelCreds []ChannelCreds - serverFeatures []string + serverURI string + // TODO: rename ChannelCreds to ChannelCredsConfigs for consistency with + // CallCredsConfigs. + channelCreds []ChannelCreds + callCredsConfigs []CallCredsConfig + serverFeatures []string // As part of unmarshalling the JSON config into this struct, we ensure that // the credentials config is valid by building an instance of the specified // credentials and store it here for easy access. - selectedCreds ChannelCreds - credsDialOption grpc.DialOption - extraDialOptions []grpc.DialOption + selectedChannelCreds ChannelCreds + selectedCallCreds []credentials.PerRPCCredentials + credsDialOption grpc.DialOption + extraDialOptions []grpc.DialOption cleanups []func() } @@ -194,6 +233,11 @@ func (sc *ServerConfig) ServerFeatures() []string { return sc.serverFeatures } +// CallCredsConfigs returns the call credentials configuration for this server. +func (sc *ServerConfig) CallCredsConfigs() CallCredsConfigs { + return sc.callCredsConfigs +} + // ServerFeaturesIgnoreResourceDeletion returns true if this server supports a // feature where the xDS client can ignore resource deletions from this server, // as described in gRFC A53. @@ -211,10 +255,10 @@ func (sc *ServerConfig) ServerFeaturesIgnoreResourceDeletion() bool { return false } -// SelectedCreds returns the selected credentials configuration for +// SelectedChannelCreds returns the selected credentials configuration for // communicating with this server. -func (sc *ServerConfig) SelectedCreds() ChannelCreds { - return sc.selectedCreds +func (sc *ServerConfig) SelectedChannelCreds() ChannelCreds { + return sc.selectedChannelCreds } // DialOptions returns a slice of all the configured dial options for this @@ -245,9 +289,11 @@ func (sc *ServerConfig) Equal(other *ServerConfig) bool { return false case !slices.EqualFunc(sc.channelCreds, other.channelCreds, func(a, b ChannelCreds) bool { return a.Equal(b) }): return false + case !slices.EqualFunc(sc.callCredsConfigs, other.callCredsConfigs, func(a, b CallCredsConfig) bool { return a.Equal(b) }): + return false case !slices.Equal(sc.serverFeatures, other.serverFeatures): return false - case !sc.selectedCreds.Equal(other.selectedCreds): + case !sc.selectedChannelCreds.Equal(other.selectedChannelCreds): return false } return true @@ -256,25 +302,27 @@ func (sc *ServerConfig) Equal(other *ServerConfig) bool { // String returns the string representation of the ServerConfig. func (sc *ServerConfig) String() string { if len(sc.serverFeatures) == 0 { - return fmt.Sprintf("%s-%s", sc.serverURI, sc.selectedCreds.String()) + return strings.Join([]string{sc.serverURI, sc.selectedChannelCreds.String(), sc.CallCredsConfigs().String()}, "-") } features := strings.Join(sc.serverFeatures, "-") - return strings.Join([]string{sc.serverURI, sc.selectedCreds.String(), features}, "-") + return strings.Join([]string{sc.serverURI, sc.selectedChannelCreds.String(), features, sc.CallCredsConfigs().String()}, "-") } // The following fields correspond 1:1 with the JSON schema for ServerConfig. type serverConfigJSON struct { - ServerURI string `json:"server_uri,omitempty"` - ChannelCreds []ChannelCreds `json:"channel_creds,omitempty"` - ServerFeatures []string `json:"server_features,omitempty"` + ServerURI string `json:"server_uri,omitempty"` + ChannelCreds []ChannelCreds `json:"channel_creds,omitempty"` + CallCredsConfigs []CallCredsConfig `json:"call_creds,omitempty"` + ServerFeatures []string `json:"server_features,omitempty"` } // MarshalJSON returns marshaled JSON bytes corresponding to this server config. func (sc *ServerConfig) MarshalJSON() ([]byte, error) { server := &serverConfigJSON{ - ServerURI: sc.serverURI, - ChannelCreds: sc.channelCreds, - ServerFeatures: sc.serverFeatures, + ServerURI: sc.serverURI, + ChannelCreds: sc.channelCreds, + CallCredsConfigs: sc.callCredsConfigs, + ServerFeatures: sc.serverFeatures, } return json.Marshal(server) } @@ -294,11 +342,12 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error { sc.serverURI = server.ServerURI sc.channelCreds = server.ChannelCreds + sc.callCredsConfigs = server.CallCredsConfigs sc.serverFeatures = server.ServerFeatures for _, cc := range server.ChannelCreds { // We stop at the first credential type that we support. - c := bootstrap.GetCredentials(cc.Type) + c := bootstrap.GetChannelCredentials(cc.Type) if c == nil { continue } @@ -306,7 +355,7 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error { if err != nil { return fmt.Errorf("failed to build credentials bundle from bootstrap for %q: %v", cc.Type, err) } - sc.selectedCreds = cc + sc.selectedChannelCreds = cc sc.credsDialOption = grpc.WithCredentialsBundle(bundle) if d, ok := bundle.(extraDialOptions); ok { sc.extraDialOptions = d.DialOptions() @@ -314,6 +363,30 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error { sc.cleanups = append(sc.cleanups, cancel) break } + + if envconfig.XDSBootstrapCallCredsEnabled { + // Process call credentials - unlike channel creds, we use ALL supported + // types. Also, call credentials are optional as per gRFC A97. + for _, cfg := range server.CallCredsConfigs { + c := bootstrap.GetCallCredentials(cfg.Type) + if c == nil { + // Skip unsupported call credential types (don't fail bootstrap). + continue + } + callCreds, cancel, err := c.Build(cfg.Config) + if err != nil { + // Call credential validation failed - this should fail bootstrap. + return fmt.Errorf("failed to build call credentials from bootstrap for %q: %v", cfg.Type, err) + } + if callCreds == nil { + continue + } + sc.selectedCallCreds = append(sc.selectedCallCreds, callCreds) + sc.extraDialOptions = append(sc.extraDialOptions, grpc.WithPerRPCCredentials(callCreds)) + sc.cleanups = append(sc.cleanups, cancel) + } + } + if sc.serverURI == "" { return fmt.Errorf("xds: `server_uri` field in server config cannot be empty: %s", string(data)) } @@ -333,6 +406,9 @@ type ServerConfigTestingOptions struct { // ChannelCreds contains a list of channel credentials to use when talking // to this server. If unspecified, `insecure` credentials will be used. ChannelCreds []ChannelCreds + // CallCredsConfigs contains a list of call credentials to use for individual RPCs + // to this server. Optional. + CallCredsConfigs []CallCredsConfig // ServerFeatures represents the list of features supported by this server. ServerFeatures []string } @@ -347,9 +423,10 @@ func ServerConfigForTesting(opts ServerConfigTestingOptions) (*ServerConfig, err cc = []ChannelCreds{{Type: "insecure"}} } scInternal := &serverConfigJSON{ - ServerURI: opts.URI, - ChannelCreds: cc, - ServerFeatures: opts.ServerFeatures, + ServerURI: opts.URI, + ChannelCreds: cc, + CallCredsConfigs: opts.CallCredsConfigs, + ServerFeatures: opts.ServerFeatures, } scJSON, err := json.Marshal(scInternal) if err != nil { diff --git a/internal/xds/bootstrap/bootstrap_test.go b/internal/xds/bootstrap/bootstrap_test.go index 5d5ed90f03de..773cc870005d 100644 --- a/internal/xds/bootstrap/bootstrap_test.go +++ b/internal/xds/bootstrap/bootstrap_test.go @@ -28,10 +28,13 @@ import ( v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" "github.com/google/go-cmp/cmp" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/jwt" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/xds/bootstrap" "google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/known/structpb" @@ -196,6 +199,74 @@ var ( "server_features" : ["ignore_resource_deletion", "xds_v3"] }] }`, + // example data seeded from + // https://github.com/istio/istio/blob/master/pkg/istio-agent/testdata/grpc-bootstrap.json + "istioStyleInsecureWithJWTCallCreds": ` + { + "node": { + "id": "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + "metadata": { + "GENERATOR": "grpc", + "INSTANCE_IPS": "127.0.0.1", + "ISTIO_VERSION": "1.26.2", + "WORKLOAD_IDENTITY_SOCKET_FILE": "socket" + }, + "locality": {} + }, + "xds_servers" : [{ + "server_uri": "unix:///etc/istio/XDS", + "channel_creds": [ + { "type": "insecure" } + ], + "call_creds": [ + { "type": "jwt_token_file", "config": {"jwt_token_file": "/var/run/secrets/tokens/istio-token"} } + ], + "server_features" : ["xds_v3"] + }] + }`, + "istioStyleInsecureWithoutCallCreds": ` + { + "node": { + "id": "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + "metadata": { + "GENERATOR": "grpc", + "INSTANCE_IPS": "127.0.0.1", + "ISTIO_VERSION": "1.26.2", + "WORKLOAD_IDENTITY_SOCKET_FILE": "socket" + }, + "locality": {} + }, + "xds_servers" : [{ + "server_uri": "unix:///etc/istio/XDS", + "channel_creds": [ + { "type": "insecure" } + ], + "server_features" : ["xds_v3"] + }] + }`, + "istioStyleWithTLSAndJWT": ` + { + "node": { + "id": "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + "metadata": { + "GENERATOR": "grpc", + "INSTANCE_IPS": "127.0.0.1", + "ISTIO_VERSION": "1.26.2", + "WORKLOAD_IDENTITY_SOCKET_FILE": "socket" + }, + "locality": {} + }, + "xds_servers" : [{ + "server_uri": "unix:///etc/istio/XDS", + "channel_creds": [ + { "type": "tls", "config": {} } + ], + "call_creds": [ + { "type": "jwt_token_file", "config": {"jwt_token_file": "/var/run/secrets/tokens/istio-token"} } + ], + "server_features" : ["xds_v3"] + }] + }`, } metadata = &structpb.Struct{ Fields: map[string]*structpb.Value{ @@ -213,29 +284,29 @@ var ( } configWithInsecureCreds = &Config{ xDSServers: []*ServerConfig{{ - serverURI: "trafficdirector.googleapis.com:443", - channelCreds: []ChannelCreds{{Type: "insecure"}}, - selectedCreds: ChannelCreds{Type: "insecure"}, + serverURI: "trafficdirector.googleapis.com:443", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + selectedChannelCreds: ChannelCreds{Type: "insecure"}, }}, node: v3Node, clientDefaultListenerResourceNameTemplate: "%s", } configWithMultipleChannelCredsAndV3 = &Config{ xDSServers: []*ServerConfig{{ - serverURI: "trafficdirector.googleapis.com:443", - channelCreds: []ChannelCreds{{Type: "not-google-default"}, {Type: "google_default"}}, - serverFeatures: []string{"xds_v3"}, - selectedCreds: ChannelCreds{Type: "google_default"}, + serverURI: "trafficdirector.googleapis.com:443", + channelCreds: []ChannelCreds{{Type: "not-google-default"}, {Type: "google_default"}}, + serverFeatures: []string{"xds_v3"}, + selectedChannelCreds: ChannelCreds{Type: "google_default"}, }}, node: v3Node, clientDefaultListenerResourceNameTemplate: "%s", } configWithGoogleDefaultCredsAndV3 = &Config{ xDSServers: []*ServerConfig{{ - serverURI: "trafficdirector.googleapis.com:443", - channelCreds: []ChannelCreds{{Type: "google_default"}}, - serverFeatures: []string{"xds_v3"}, - selectedCreds: ChannelCreds{Type: "google_default"}, + serverURI: "trafficdirector.googleapis.com:443", + channelCreds: []ChannelCreds{{Type: "google_default"}}, + serverFeatures: []string{"xds_v3"}, + selectedChannelCreds: ChannelCreds{Type: "google_default"}, }}, node: v3Node, clientDefaultListenerResourceNameTemplate: "%s", @@ -243,15 +314,15 @@ var ( configWithMultipleServers = &Config{ xDSServers: []*ServerConfig{ { - serverURI: "trafficdirector.googleapis.com:443", - channelCreds: []ChannelCreds{{Type: "google_default"}}, - serverFeatures: []string{"xds_v3"}, - selectedCreds: ChannelCreds{Type: "google_default"}, + serverURI: "trafficdirector.googleapis.com:443", + channelCreds: []ChannelCreds{{Type: "google_default"}}, + serverFeatures: []string{"xds_v3"}, + selectedChannelCreds: ChannelCreds{Type: "google_default"}, }, { - serverURI: "backup.never.use.com:1234", - channelCreds: []ChannelCreds{{Type: "google_default"}}, - selectedCreds: ChannelCreds{Type: "google_default"}, + serverURI: "backup.never.use.com:1234", + channelCreds: []ChannelCreds{{Type: "google_default"}}, + selectedChannelCreds: ChannelCreds{Type: "google_default"}, }, }, node: v3Node, @@ -259,23 +330,99 @@ var ( } configWithGoogleDefaultCredsAndIgnoreResourceDeletion = &Config{ xDSServers: []*ServerConfig{{ - serverURI: "trafficdirector.googleapis.com:443", - channelCreds: []ChannelCreds{{Type: "google_default"}}, - serverFeatures: []string{"ignore_resource_deletion", "xds_v3"}, - selectedCreds: ChannelCreds{Type: "google_default"}, + serverURI: "trafficdirector.googleapis.com:443", + channelCreds: []ChannelCreds{{Type: "google_default"}}, + serverFeatures: []string{"ignore_resource_deletion", "xds_v3"}, + selectedChannelCreds: ChannelCreds{Type: "google_default"}, }}, node: v3Node, clientDefaultListenerResourceNameTemplate: "%s", } configWithGoogleDefaultCredsAndNoServerFeatures = &Config{ xDSServers: []*ServerConfig{{ - serverURI: "trafficdirector.googleapis.com:443", - channelCreds: []ChannelCreds{{Type: "google_default"}}, - selectedCreds: ChannelCreds{Type: "google_default"}, + serverURI: "trafficdirector.googleapis.com:443", + channelCreds: []ChannelCreds{{Type: "google_default"}}, + selectedChannelCreds: ChannelCreds{Type: "google_default"}, }}, node: v3Node, clientDefaultListenerResourceNameTemplate: "%s", } + + istioNodeMetadata = &structpb.Struct{ + Fields: map[string]*structpb.Value{ + "GENERATOR": { + Kind: &structpb.Value_StringValue{StringValue: "grpc"}, + }, + "INSTANCE_IPS": { + Kind: &structpb.Value_StringValue{StringValue: "127.0.0.1"}, + }, + "ISTIO_VERSION": { + Kind: &structpb.Value_StringValue{StringValue: "1.26.2"}, + }, + "WORKLOAD_IDENTITY_SOCKET_FILE": { + Kind: &structpb.Value_StringValue{StringValue: "socket"}, + }, + }, + } + jwtCallCreds, _ = jwt.NewTokenFileCallCredentials("/var/run/secrets/tokens/istio-token") + selectedJWTCallCreds = []credentials.PerRPCCredentials{jwtCallCreds} + configWithIstioJWTCallCreds = &Config{ + xDSServers: []*ServerConfig{{ + serverURI: "unix:///etc/istio/XDS", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + callCredsConfigs: []CallCredsConfig{{Type: "jwt_token_file", Config: json.RawMessage("{\n\"jwt_token_file\": \"/var/run/secrets/tokens/istio-token\"\n}")}}, + serverFeatures: []string{"xds_v3"}, + selectedChannelCreds: ChannelCreds{Type: "insecure"}, + selectedCallCreds: selectedJWTCallCreds, + }}, + node: node{ + ID: "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + Metadata: istioNodeMetadata, + userAgentName: gRPCUserAgentName, + userAgentVersionType: userAgentVersion{UserAgentVersion: grpc.Version}, + clientFeatures: []string{clientFeatureNoOverprovisioning, clientFeatureResourceWrapper}, + }, + certProviderConfigs: map[string]*certprovider.BuildableConfig{}, + clientDefaultListenerResourceNameTemplate: "%s", + } + + configWithIstioStyleNoCallCreds = &Config{ + xDSServers: []*ServerConfig{{ + serverURI: "unix:///etc/istio/XDS", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + serverFeatures: []string{"xds_v3"}, + selectedChannelCreds: ChannelCreds{Type: "insecure"}, + }}, + node: node{ + ID: "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + Metadata: istioNodeMetadata, + userAgentName: gRPCUserAgentName, + userAgentVersionType: userAgentVersion{UserAgentVersion: grpc.Version}, + clientFeatures: []string{clientFeatureNoOverprovisioning, clientFeatureResourceWrapper}, + }, + certProviderConfigs: map[string]*certprovider.BuildableConfig{}, + clientDefaultListenerResourceNameTemplate: "%s", + } + + configWithIstioStyleWithTLSAndJWT = &Config{ + xDSServers: []*ServerConfig{{ + serverURI: "unix:///etc/istio/XDS", + channelCreds: []ChannelCreds{{Type: "tls", Config: json.RawMessage("{}")}}, + callCredsConfigs: []CallCredsConfig{{Type: "jwt_token_file", Config: json.RawMessage("{\n\"jwt_token_file\": \"/var/run/secrets/tokens/istio-token\"\n}")}}, + serverFeatures: []string{"xds_v3"}, + selectedChannelCreds: ChannelCreds{Type: "tls", Config: json.RawMessage("{}")}, + selectedCallCreds: selectedJWTCallCreds, + }}, + node: node{ + ID: "sidecar~127.0.0.1~pod1.fake-namespace~fake-namespace.svc.cluster.local", + Metadata: istioNodeMetadata, + userAgentName: gRPCUserAgentName, + userAgentVersionType: userAgentVersion{UserAgentVersion: grpc.Version}, + clientFeatures: []string{clientFeatureNoOverprovisioning, clientFeatureResourceWrapper}, + }, + certProviderConfigs: map[string]*certprovider.BuildableConfig{}, + clientDefaultListenerResourceNameTemplate: "%s", + } ) func fileReadFromFileMap(bootstrapFileMap map[string]string, name string) ([]byte, error) { @@ -413,9 +560,9 @@ func (s) TestGetConfiguration_Success(t *testing.T) { name: "emptyNodeProto", wantConfig: &Config{ xDSServers: []*ServerConfig{{ - serverURI: "trafficdirector.googleapis.com:443", - channelCreds: []ChannelCreds{{Type: "insecure"}}, - selectedCreds: ChannelCreds{Type: "insecure"}, + serverURI: "trafficdirector.googleapis.com:443", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + selectedChannelCreds: ChannelCreds{Type: "insecure"}, }}, node: node{ userAgentName: gRPCUserAgentName, @@ -432,6 +579,29 @@ func (s) TestGetConfiguration_Success(t *testing.T) { {"goodBootstrap", configWithGoogleDefaultCredsAndV3}, {"multipleXDSServers", configWithMultipleServers}, {"serverSupportsIgnoreResourceDeletion", configWithGoogleDefaultCredsAndIgnoreResourceDeletion}, + {"istioStyleInsecureWithoutCallCreds", configWithIstioStyleNoCallCreds}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testGetConfigurationWithFileNameEnv(t, test.name, false, test.wantConfig) + testGetConfigurationWithFileContentEnv(t, test.name, false, test.wantConfig) + }) + } +} + +// Tests Istio-style bootstrap configurations with JWT call credentials. +func (s) TestGetConfiguration_IstioStyleWithCallCreds(t *testing.T) { + testutils.SetEnvConfig(t, &envconfig.XDSBootstrapCallCredsEnabled, true) + cancel := setupBootstrapOverride(v3BootstrapFileMap) + defer cancel() + + tests := []struct { + name string + wantConfig *Config + }{ + {"istioStyleInsecureWithJWTCallCreds", configWithIstioJWTCallCreds}, + {"istioStyleWithTLSAndJWT", configWithIstioStyleWithTLSAndJWT}, } for _, test := range tests { @@ -672,10 +842,10 @@ func (s) TestGetConfiguration_CertificateProviders(t *testing.T) { goodConfig := &Config{ xDSServers: []*ServerConfig{{ - serverURI: "trafficdirector.googleapis.com:443", - channelCreds: []ChannelCreds{{Type: "insecure"}}, - serverFeatures: []string{"xds_v3"}, - selectedCreds: ChannelCreds{Type: "insecure"}, + serverURI: "trafficdirector.googleapis.com:443", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + serverFeatures: []string{"xds_v3"}, + selectedChannelCreds: ChannelCreds{Type: "insecure"}, }}, certProviderConfigs: map[string]*certprovider.BuildableConfig{ "fakeProviderInstance": wantCfg, @@ -766,9 +936,9 @@ func (s) TestGetConfiguration_ServerListenerResourceNameTemplate(t *testing.T) { name: "goodServerListenerResourceNameTemplate", wantConfig: &Config{ xDSServers: []*ServerConfig{{ - serverURI: "trafficdirector.googleapis.com:443", - channelCreds: []ChannelCreds{{Type: "google_default"}}, - selectedCreds: ChannelCreds{Type: "google_default"}, + serverURI: "trafficdirector.googleapis.com:443", + channelCreds: []ChannelCreds{{Type: "google_default"}}, + selectedChannelCreds: ChannelCreds{Type: "google_default"}, }}, node: v3Node, serverListenerResourceNameTemplate: "grpc/server?xds.resource.listening_address=%s", @@ -915,9 +1085,9 @@ func (s) TestGetConfiguration_Federation(t *testing.T) { name: "good", wantConfig: &Config{ xDSServers: []*ServerConfig{{ - serverURI: "trafficdirector.googleapis.com:443", - channelCreds: []ChannelCreds{{Type: "google_default"}}, - selectedCreds: ChannelCreds{Type: "google_default"}, + serverURI: "trafficdirector.googleapis.com:443", + channelCreds: []ChannelCreds{{Type: "google_default"}}, + selectedChannelCreds: ChannelCreds{Type: "google_default"}, }}, node: v3Node, serverListenerResourceNameTemplate: "xdstp://xds.example.com/envoy.config.listener.v3.Listener/grpc/server?listening_address=%s", @@ -926,10 +1096,10 @@ func (s) TestGetConfiguration_Federation(t *testing.T) { "xds.td.com": { ClientListenerResourceNameTemplate: "xdstp://xds.td.com/envoy.config.listener.v3.Listener/%s", XDSServers: []*ServerConfig{{ - serverURI: "td.com", - channelCreds: []ChannelCreds{{Type: "google_default"}}, - serverFeatures: []string{"xds_v3"}, - selectedCreds: ChannelCreds{Type: "google_default"}, + serverURI: "td.com", + channelCreds: []ChannelCreds{{Type: "google_default"}}, + serverFeatures: []string{"xds_v3"}, + selectedChannelCreds: ChannelCreds{Type: "google_default"}, }}, }, }, @@ -939,9 +1109,9 @@ func (s) TestGetConfiguration_Federation(t *testing.T) { name: "goodWithDefaultDefaultClientListenerTemplate", wantConfig: &Config{ xDSServers: []*ServerConfig{{ - serverURI: "trafficdirector.googleapis.com:443", - channelCreds: []ChannelCreds{{Type: "google_default"}}, - selectedCreds: ChannelCreds{Type: "google_default"}, + serverURI: "trafficdirector.googleapis.com:443", + channelCreds: []ChannelCreds{{Type: "google_default"}}, + selectedChannelCreds: ChannelCreds{Type: "google_default"}, }}, node: v3Node, clientDefaultListenerResourceNameTemplate: "%s", @@ -951,9 +1121,9 @@ func (s) TestGetConfiguration_Federation(t *testing.T) { name: "goodWithDefaultClientListenerTemplatePerAuthority", wantConfig: &Config{ xDSServers: []*ServerConfig{{ - serverURI: "trafficdirector.googleapis.com:443", - channelCreds: []ChannelCreds{{Type: "google_default"}}, - selectedCreds: ChannelCreds{Type: "google_default"}, + serverURI: "trafficdirector.googleapis.com:443", + channelCreds: []ChannelCreds{{Type: "google_default"}}, + selectedChannelCreds: ChannelCreds{Type: "google_default"}, }}, node: v3Node, clientDefaultListenerResourceNameTemplate: "xdstp://xds.example.com/envoy.config.listener.v3.Listener/%s", @@ -971,9 +1141,9 @@ func (s) TestGetConfiguration_Federation(t *testing.T) { name: "goodWithNoServerPerAuthority", wantConfig: &Config{ xDSServers: []*ServerConfig{{ - serverURI: "trafficdirector.googleapis.com:443", - channelCreds: []ChannelCreds{{Type: "google_default"}}, - selectedCreds: ChannelCreds{Type: "google_default"}, + serverURI: "trafficdirector.googleapis.com:443", + channelCreds: []ChannelCreds{{Type: "google_default"}}, + selectedChannelCreds: ChannelCreds{Type: "google_default"}, }}, node: v3Node, clientDefaultListenerResourceNameTemplate: "xdstp://xds.example.com/envoy.config.listener.v3.Listener/%s", @@ -1018,7 +1188,7 @@ func (s) TestDefaultBundles(t *testing.T) { for _, typename := range tests { t.Run(typename, func(t *testing.T) { - if c := bootstrap.GetCredentials(typename); c == nil { + if c := bootstrap.GetChannelCredentials(typename); c == nil { t.Errorf(`bootstrap.GetCredentials(%s) credential is nil, want non-nil`, typename) } }) @@ -1033,6 +1203,185 @@ func Test(t *testing.T) { grpctest.RunSubTests(t, s{}) } +func (s) TestCallCreds_Equal(t *testing.T) { + tests := []struct { + name string + cc1 CallCredsConfig + cc2 CallCredsConfig + want bool + }{ + { + name: "identical_configs", + cc1: CallCredsConfig{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + cc2: CallCredsConfig{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + want: true, + }, + { + name: "different_types", + cc1: CallCredsConfig{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + cc2: CallCredsConfig{Type: "other_type", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + want: false, + }, + { + name: "different_configs", + cc1: CallCredsConfig{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + cc2: CallCredsConfig{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/different/path"}`)}, + want: false, + }, + { + name: "nil_vs_non-nil_configs", + cc1: CallCredsConfig{Type: "jwt_token_file", Config: nil}, + cc2: CallCredsConfig{Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/path/to/token"}`)}, + want: false, + }, + { + name: "both_nil_configs", + cc1: CallCredsConfig{Type: "jwt_token_file", Config: nil}, + cc2: CallCredsConfig{Type: "jwt_token_file", Config: nil}, + want: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := test.cc1.Equal(test.cc2); got != test.want { + t.Errorf("CallCreds.Equal() = %v, want %v", got, test.want) + } + }) + } +} + +func (s) TestServerConfig_UnmarshalJSON_WithCallCreds(t *testing.T) { + testutils.SetEnvConfig(t, &envconfig.XDSBootstrapCallCredsEnabled, true) + tests := []struct { + name string + json string + wantCallCreds CallCredsConfigs + }{ + { + name: "valid_call_creds_with_jwt_token_file", + json: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}], + "call_creds": [ + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "/path/to/token.jwt"} + } + ] + }`, + wantCallCreds: []CallCredsConfig{{ + Type: "jwt_token_file", + Config: json.RawMessage(`{"jwt_token_file": "/path/to/token.jwt"}`), + }}, + }, + { + name: "multiple_call_creds_types", + json: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}], + "call_creds": [ + {"type": "jwt_token_file", "config": {"jwt_token_file": "/token1.jwt"}}, + {"type": "unsupported_type", "config": {}} + ] + }`, + wantCallCreds: []CallCredsConfig{ + {Type: "jwt_token_file", Config: json.RawMessage(`{"jwt_token_file": "/token1.jwt"}`)}, + {Type: "unsupported_type", Config: json.RawMessage(`{}`)}, + }, + }, + { + name: "empty_call_creds_array", + json: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}], + "call_creds": [] + }`, + wantCallCreds: []CallCredsConfig{}, + }, + { + name: "unspecified_call_creds_field", + json: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}] + }`, + wantCallCreds: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var sc ServerConfig + err := sc.UnmarshalJSON([]byte(test.json)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if diff := cmp.Diff(test.wantCallCreds, sc.CallCredsConfigs()); diff != "" { + t.Errorf("CallCreds mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func (s) TestServerConfig_Equal_WithCallCreds(t *testing.T) { + callCreds := []CallCredsConfig{{ + Type: "jwt_token_file", + Config: json.RawMessage(`{"jwt_token_file": "/test/token.jwt"}`), + }} + sc1 := &ServerConfig{ + serverURI: "server1", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + callCredsConfigs: callCreds, + serverFeatures: []string{"feature1"}, + } + sc2 := &ServerConfig{ + serverURI: "server1", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + callCredsConfigs: callCreds, + serverFeatures: []string{"feature1"}, + } + sc3 := &ServerConfig{ + serverURI: "server1", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + callCredsConfigs: []CallCredsConfig{{Type: "different"}}, + serverFeatures: []string{"feature1"}, + } + + if !sc1.Equal(sc2) { + t.Error("Equal ServerConfigs with same call creds should be equal") + } + if sc1.Equal(sc3) { + t.Error("ServerConfigs with different call creds should not be equal") + } +} + +func (s) TestServerConfig_MarshalJSON_WithCallCreds(t *testing.T) { + testutils.SetEnvConfig(t, &envconfig.XDSBootstrapCallCredsEnabled, true) + sc := &ServerConfig{ + serverURI: "test-server:443", + channelCreds: []ChannelCreds{{Type: "insecure"}}, + callCredsConfigs: []CallCredsConfig{{ + Type: "jwt_token_file", + Config: json.RawMessage(`{"jwt_token_file":"/test/token.jwt"}`), + }}, + serverFeatures: []string{"test_feature"}, + } + + data, err := sc.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON failed: %v", err) + } + + // Check Marshal/Unmarshal symmetry. + var unmarshaled ServerConfig + if err := json.Unmarshal(data, &unmarshaled); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if diff := cmp.Diff(sc.CallCredsConfigs(), unmarshaled.CallCredsConfigs()); diff != "" { + t.Errorf("Marshal/Unmarshal call credentials produces differences:\n%s", diff) + } +} + func newStructProtoFromMap(t *testing.T, input map[string]any) *structpb.Struct { t.Helper() @@ -1201,3 +1550,105 @@ func (s) TestNode_ToProto(t *testing.T) { }) } } + +func (s) TestBootstrap_SelectedChannelCredsAndCallCreds(t *testing.T) { + testutils.SetEnvConfig(t, &envconfig.XDSBootstrapCallCredsEnabled, true) + tests := []struct { + name string + bootstrapConfig string + wantDialOpts int + wantTransportType string + }{ + { + name: "JWT_call_creds_with_TLS_channel_creds", + bootstrapConfig: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "tls", "config": {}}], + "call_creds": [ + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "/token.jwt"} + } + ] + }`, + wantDialOpts: 1, + wantTransportType: "tls", + }, + { + name: "JWT_call_creds_with_multiple_channel_creds", + bootstrapConfig: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "tls", "config": {}}, {"type": "insecure"}], + "call_creds": [ + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "/token.jwt"} + }, + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "/token2.jwt"} + } + ] + }`, + wantDialOpts: 2, + wantTransportType: "tls", // The first channel creds is selected. + }, + { + name: "JWT_call_creds_with_insecure_channel_creds", + bootstrapConfig: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}], + "call_creds": [ + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "/token.jwt"} + } + ] + }`, + wantDialOpts: 1, + wantTransportType: "insecure", + }, + { + name: "No_call_creds", + bootstrapConfig: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}] + }`, + wantDialOpts: 0, + wantTransportType: "insecure", + }, + { + name: "No_call_creds_multiple_channel_creds", + bootstrapConfig: `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "insecure"}, {"type": "tls", "config": {}}] + }`, + wantDialOpts: 0, + wantTransportType: "insecure", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var sc ServerConfig + err := sc.UnmarshalJSON([]byte(test.bootstrapConfig)) + if err != nil { + t.Fatalf("Failed to unmarshal bootstrap config: %v", err) + } + + // Verify call credentials processing. + callCredsConfig := sc.CallCredsConfigs() + dialOpts := sc.DialOptions() + if len(callCredsConfig) != test.wantDialOpts { + t.Errorf("Call creds configs count = %d, want %d", len(callCredsConfig), test.wantDialOpts) + } + if len(dialOpts) != test.wantDialOpts { + t.Errorf("Call creds count = %d, want %d", len(dialOpts), test.wantDialOpts) + } + // Verify transport credentials are properly selected. + if sc.SelectedChannelCreds().Type != test.wantTransportType { + t.Errorf("Selected transport creds type = %q, want %q", sc.SelectedChannelCreds().Type, test.wantTransportType) + } + }) + } +} diff --git a/internal/xds/bootstrap/jwtcreds/call_creds.go b/internal/xds/bootstrap/jwtcreds/call_creds.go new file mode 100644 index 000000000000..60f00acedad3 --- /dev/null +++ b/internal/xds/bootstrap/jwtcreds/call_creds.go @@ -0,0 +1,53 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package jwtcreds implements JWT CallCredentials for XDS, configured via xDS +// Bootstrap File. For more details, see gRFC A97: +// https://github.com/grpc/proposal/blob/master/A97-xds-jwt-call-creds.md +package jwtcreds + +import ( + "encoding/json" + "fmt" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/jwt" +) + +// NewCallCredentials returns a new JWT token based call credentials. The input +// config must match the structure specified in gRFC A97. The caller is expected +// to invoke the second return value when they are done using the returned call creds. +// The cancel function is idempotent. +func NewCallCredentials(configJSON json.RawMessage) (credentials.PerRPCCredentials, func(), error) { + var cfg struct { + JWTTokenFile string `json:"jwt_token_file"` + } + emptyFn := func() {} + + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return nil, emptyFn, fmt.Errorf("failed to unmarshal JWT call credentials config: %v", err) + } + if cfg.JWTTokenFile == "" { + return nil, emptyFn, fmt.Errorf("jwt_token_file is required in JWT call credentials config") + } + callCreds, err := jwt.NewTokenFileCallCredentials(cfg.JWTTokenFile) + if err != nil { + return nil, emptyFn, fmt.Errorf("failed to create JWT call credentials: %v", err) + } + return callCreds, emptyFn, nil +} diff --git a/internal/xds/bootstrap/jwtcreds/call_creds_test.go b/internal/xds/bootstrap/jwtcreds/call_creds_test.go new file mode 100644 index 000000000000..90f61efd7e0c --- /dev/null +++ b/internal/xds/bootstrap/jwtcreds/call_creds_test.go @@ -0,0 +1,168 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package jwtcreds + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/grpctest" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +func (s) TestNewCallCredentialsWithInvalidConfig(t *testing.T) { + tests := []struct { + name string + config string + }{ + { + name: "empty_file", + config: `""`, + }, + { + name: "empty_config", + config: `{}`, + }, + { + name: "empty_path", + config: `{"jwt_token_file": ""}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + callCreds, cleanup, err := NewCallCredentials(json.RawMessage(tt.config)) + if err == nil { + t.Fatalf("NewCallCredentials(%s): got nil, want error", tt.config) + } + if callCreds != nil { + t.Errorf("NewCallCredentials(%s): Expected nil call credentials to be returned", tt.config) + } + if cleanup == nil { + t.Errorf("NewCallCredentials(%s): Expected non-nil cleanup function to be returned", tt.config) + } + }) + } +} + +func (s) TestNewCallCredentialsWithValidConfig(t *testing.T) { + token := createTestJWT(t) + tokenFile := writeTempFile(t, token) + config := `{"jwt_token_file": "` + tokenFile + `"}` + + callCreds, cleanup, err := NewCallCredentials(json.RawMessage(config)) + if err != nil { + t.Fatalf("NewCallCredentials(%s) failed: %v", config, err) + } + if callCreds == nil { + t.Fatalf("NewCallCredentials(%s): Expected non-nil credentials to be returned", config) + } + if cleanup == nil { + t.Errorf("NewCallCredentials(%s): Expected non-nil cleanup function to be returned", config) + } else { + defer cleanup() + } + + // Test that call credentials get used. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{ + AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity}, + }) + metadata, err := callCreds.GetRequestMetadata(ctx) + if err != nil { + t.Fatalf("GetRequestMetadata failed: %v", err) + } + if len(metadata) == 0 { + t.Fatal("GetRequestMetadata: Expected metadata to be returned") + } + authHeader, ok := metadata["authorization"] + if !ok { + t.Fatal("GetRequestMetadata: Expected authorization header in metadata") + } + if !strings.HasPrefix(authHeader, "Bearer ") { + t.Errorf("GetRequestMetadata: Authorization header should start with 'Bearer ', got %q", authHeader) + } +} + +func (s) TestCallCredentials_Cleanup(t *testing.T) { + token := createTestJWT(t) + tokenFile := writeTempFile(t, token) + config := `{"jwt_token_file": "` + tokenFile + `"}` + _, cleanup, err := NewCallCredentials(json.RawMessage(config)) + if err != nil { + t.Fatalf("NewCallCredentials failed: %v", err) + } + if cleanup == nil { + t.Fatal("NewCallCredentials: Expected non-nil cleanup function") + } + // Cleanup should not panic + cleanup() + // Multiple cleanup calls should be safe + cleanup() +} + +// testAuthInfo implements credentials.AuthInfo for testing. +type testAuthInfo struct { + secLevel credentials.SecurityLevel +} + +func (t *testAuthInfo) AuthType() string { + return "test" +} + +func (t *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo { + return credentials.CommonAuthInfo{SecurityLevel: t.secLevel} +} + +// createTestJWT creates a test JWT token for testing. +func createTestJWT(t *testing.T) string { + t.Helper() + + // Header: {"typ":"JWT","alg":"HS256"} + header := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9" + // Claims: {"aud":"https://example.com","exp":future_timestamp} + claims := "eyJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tIiwiZXhwIjoyMDAwMDAwMDAwfQ" + signature := "fake_signature_for_testing" + + return header + "." + claims + "." + signature +} + +func writeTempFile(t *testing.T, content string) string { + t.Helper() + tempDir := t.TempDir() + filePath := filepath.Join(tempDir, "jwt_token") + if err := os.WriteFile(filePath, []byte(content), 0600); err != nil { + t.Fatalf("Failed to write temp file: %v", err) + } + return filePath +} diff --git a/internal/xds/xdsclient/clientimpl.go b/internal/xds/xdsclient/clientimpl.go index b1f797993fd7..f2325e0e552f 100644 --- a/internal/xds/xdsclient/clientimpl.go +++ b/internal/xds/xdsclient/clientimpl.go @@ -181,7 +181,7 @@ func buildXDSClientConfig(config *bootstrap.Config, metricsRecorder estats.Metri return xdsclient.Config{}, err } gsc := xdsclient.ServerConfig{ - ServerIdentifier: clients.ServerIdentifier{ServerURI: sc.ServerURI(), Extensions: grpctransport.ServerIdentifierExtension{ConfigName: sc.SelectedCreds().Type}}, + ServerIdentifier: clients.ServerIdentifier{ServerURI: sc.ServerURI(), Extensions: grpctransport.ServerIdentifierExtension{ConfigName: sc.SelectedChannelCreds().Type}}, IgnoreResourceDeletion: sc.ServerFeaturesIgnoreResourceDeletion()} gServerCfg = append(gServerCfg, gsc) gServerCfgMap[gsc] = sc @@ -195,7 +195,7 @@ func buildXDSClientConfig(config *bootstrap.Config, metricsRecorder estats.Metri return xdsclient.Config{}, err } gsc := xdsclient.ServerConfig{ - ServerIdentifier: clients.ServerIdentifier{ServerURI: sc.ServerURI(), Extensions: grpctransport.ServerIdentifierExtension{ConfigName: sc.SelectedCreds().Type}}, + ServerIdentifier: clients.ServerIdentifier{ServerURI: sc.ServerURI(), Extensions: grpctransport.ServerIdentifierExtension{ConfigName: sc.SelectedChannelCreds().Type}}, IgnoreResourceDeletion: sc.ServerFeaturesIgnoreResourceDeletion()} gServerCfgs = append(gServerCfgs, gsc) gServerCfgMap[gsc] = sc @@ -233,7 +233,7 @@ func buildXDSClientConfig(config *bootstrap.Config, metricsRecorder estats.Metri // and populates the grpctransport.Config map. func populateGRPCTransportConfigsFromServerConfig(sc *bootstrap.ServerConfig, grpcTransportConfigs map[string]grpctransport.Config) error { for _, cc := range sc.ChannelCreds() { - c := xdsbootstrap.GetCredentials(cc.Type) + c := xdsbootstrap.GetChannelCredentials(cc.Type) if c == nil { continue } diff --git a/internal/xds/xdsclient/clientimpl_loadreport.go b/internal/xds/xdsclient/clientimpl_loadreport.go index ffd0c90b8f54..023f302136cb 100644 --- a/internal/xds/xdsclient/clientimpl_loadreport.go +++ b/internal/xds/xdsclient/clientimpl_loadreport.go @@ -35,7 +35,7 @@ func (c *clientImpl) ReportLoad(server *bootstrap.ServerConfig) (*lrsclient.Load load, err := c.lrsClient.ReportLoad(clients.ServerIdentifier{ ServerURI: server.ServerURI(), Extensions: grpctransport.ServerIdentifierExtension{ - ConfigName: server.SelectedCreds().Type, + ConfigName: server.SelectedChannelCreds().Type, }, }) if err != nil { diff --git a/internal/xds/xdsclient/clientimpl_test.go b/internal/xds/xdsclient/clientimpl_test.go index d2fc8f7f9332..830e1ef00066 100644 --- a/internal/xds/xdsclient/clientimpl_test.go +++ b/internal/xds/xdsclient/clientimpl_test.go @@ -29,6 +29,8 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/envconfig" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/internal/testutils/stats" "google.golang.org/grpc/internal/xds/bootstrap" "google.golang.org/grpc/internal/xds/clients" @@ -259,3 +261,32 @@ func (s) TestBuildXDSClientConfig_Success(t *testing.T) { }) } } + +func (s) TestServerConfigCallCredsIntegration(t *testing.T) { + testutils.SetEnvConfig(t, &envconfig.XDSBootstrapCallCredsEnabled, true) + // Test server config with both channel and call credentials. + serverConfigJSON := `{ + "server_uri": "xds-server:443", + "channel_creds": [{"type": "tls", "config": {}}], + "call_creds": [ + { + "type": "jwt_token_file", + "config": {"jwt_token_file": "/token.jwt"} + } + ] + }` + + var sc bootstrap.ServerConfig + if err := sc.UnmarshalJSON([]byte(serverConfigJSON)); err != nil { + t.Fatalf("Failed to unmarshal server config: %v", err) + } + // Verify call credentials are processed. + callCreds := sc.CallCredsConfigs() + if len(callCreds) != 1 { + t.Errorf("Got %d call credential configs, want 1", len(callCreds)) + } + dialOpts := sc.DialOptions() + if len(dialOpts) != 1 { + t.Errorf("Got %d dial options, want 1", len(dialOpts)) + } +} diff --git a/internal/xds/xdsclient/tests/client_custom_dialopts_test.go b/internal/xds/xdsclient/tests/client_custom_dialopts_test.go index 39deb1240fe0..fd70f8556eaa 100644 --- a/internal/xds/xdsclient/tests/client_custom_dialopts_test.go +++ b/internal/xds/xdsclient/tests/client_custom_dialopts_test.go @@ -81,7 +81,7 @@ func (s) TestClientCustomDialOptsFromCredentialsBundle(t *testing.T) { credsBuilder := &testCredsBuilder{ testDialOptNames: []string{"opt1", "opt2", "opt3"}, } - bootstrap.RegisterCredentials(credsBuilder) + bootstrap.RegisterChannelCredentials(credsBuilder) // Start an xDS management server. mgmtServer := e2e.StartManagementServer(t, e2e.ManagementServerOptions{}) diff --git a/xds/bootstrap/bootstrap.go b/xds/bootstrap/bootstrap.go index ef55ff0c02db..3b7d1e9b2c5a 100644 --- a/xds/bootstrap/bootstrap.go +++ b/xds/bootstrap/bootstrap.go @@ -31,34 +31,70 @@ import ( "google.golang.org/grpc/credentials" ) -// registry is a map from credential type name to Credential builder. -var registry = make(map[string]Credentials) +// channelCredsRegistry is a map from channel credential type name to +// ChannelCredential builder. +var channelCredsRegistry = make(map[string]ChannelCredentials) -// Credentials interface encapsulates a credentials.Bundle builder +// callCredsRegistry is a map from call credential type name to +// ChannelCredential builder. +var callCredsRegistry = make(map[string]CallCredentials) + +// ChannelCredentials interface encapsulates a credentials.Bundle builder // that can be used for communicating with the xDS Management server. -type Credentials interface { - // Build returns a credential bundle associated with this credential, and - // a function to cleans up additional resources associated with this bundle +type ChannelCredentials interface { + // Build returns a credential bundle associated with this credential, and a + // function to clean up any additional resources associated with this bundle // when it is no longer needed. Build(config json.RawMessage) (credentials.Bundle, func(), error) // Name returns the credential name associated with this credential. Name() string } -// RegisterCredentials registers Credentials used for connecting to the xds -// management server. +// RegisterChannelCredentials registers ChannelCredentials used for connecting +// to the xDS management server. +// +// NOTE: this function must only be called during initialization time (i.e. in +// an init() function), and is not thread-safe. If multiple credentials are +// registered with the same name, the one registered last will take effect. +func RegisterChannelCredentials(c ChannelCredentials) { + channelCredsRegistry[c.Name()] = c +} + +// GetChannelCredentials returns the credentials associated with a given name. +// If no credentials are registered with the name, nil will be returned. +func GetChannelCredentials(name string) ChannelCredentials { + if c, ok := channelCredsRegistry[name]; ok { + return c + } + + return nil +} + +// CallCredentials interface encapsulates a credentials.PerRPCCredentials +// builder that can be used for communicating with the xDS Management server. +type CallCredentials interface { + // Build returns a PerRPCCredentials created from the provided + // configuration, and a function to clean up any additional resources + // associated with them when they are no longer needed. + Build(config json.RawMessage) (credentials.PerRPCCredentials, func(), error) + // Name returns the credential name associated with this credential. + Name() string +} + +// RegisterCallCredentials registers CallCredentials used for connecting +// to the xDS management server. // // NOTE: this function must only be called during initialization time (i.e. in // an init() function), and is not thread-safe. If multiple credentials are // registered with the same name, the one registered last will take effect. -func RegisterCredentials(c Credentials) { - registry[c.Name()] = c +func RegisterCallCredentials(c CallCredentials) { + callCredsRegistry[c.Name()] = c } -// GetCredentials returns the credentials associated with a given name. +// GetCallCredentials returns the credentials associated with a given name. // If no credentials are registered with the name, nil will be returned. -func GetCredentials(name string) Credentials { - if c, ok := registry[name]; ok { +func GetCallCredentials(name string) CallCredentials { + if c, ok := callCredsRegistry[name]; ok { return c } diff --git a/xds/bootstrap/bootstrap_test.go b/xds/bootstrap/bootstrap_test.go index d1f7a1b64ee5..d8b423cabbf0 100644 --- a/xds/bootstrap/bootstrap_test.go +++ b/xds/bootstrap/bootstrap_test.go @@ -22,6 +22,8 @@ import ( "testing" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/envconfig" + "google.golang.org/grpc/internal/testutils" ) const testCredsBuilderName = "test_creds" @@ -29,7 +31,7 @@ const testCredsBuilderName = "test_creds" var builder = &testCredsBuilder{} func init() { - RegisterCredentials(builder) + RegisterChannelCredentials(builder) } type testCredsBuilder struct { @@ -46,7 +48,7 @@ func (t *testCredsBuilder) Name() string { } func TestRegisterNew(t *testing.T) { - c := GetCredentials(testCredsBuilderName) + c := GetChannelCredentials(testCredsBuilderName) if c == nil { t.Fatalf("GetCredentials(%q) credential = nil", testCredsBuilderName) } @@ -62,10 +64,10 @@ func TestRegisterNew(t *testing.T) { } } -func TestCredsBuilders(t *testing.T) { +func TestChannelCredsBuilders(t *testing.T) { tests := []struct { typename string - builder Credentials + builder ChannelCredentials }{ {"google_default", &googleDefaultCredsBuilder{}}, {"insecure", &insecureCredsBuilder{}}, @@ -78,25 +80,73 @@ func TestCredsBuilders(t *testing.T) { t.Errorf("%T.Name = %v, want %v", test.builder, got, want) } - _, stop, err := test.builder.Build(nil) + bundle, stop, err := test.builder.Build(nil) if err != nil { t.Fatalf("%T.Build failed: %v", test.builder, err) } + if bundle == nil { + t.Errorf("%T.Build returned nil bundle, expected non-nil", test.builder) + } stop() }) } } +func TestCallCredsBuilders(t *testing.T) { + tests := []struct { + typename string + builder CallCredentials + minimumRequiredConfig json.RawMessage + }{ + {"jwt_token_file", &jwtCallCredsBuilder{}, json.RawMessage(`{"jwt_token_file":"/path/to/token.jwt"}`)}, + } + + for _, test := range tests { + t.Run(test.typename, func(t *testing.T) { + if got, want := test.builder.Name(), test.typename; got != want { + t.Errorf("%T.Name = %v, want %v", test.builder, got, want) + } + + bundle, stop, err := test.builder.Build(test.minimumRequiredConfig) + if err != nil { + t.Fatalf("%T.Build failed: %v", test.builder, err) + } + defer stop() + if bundle == nil { + t.Errorf("%T.Build returned nil bundle, expected non-nil", test.builder) + } + }) + } +} + func TestTlsCredsBuilder(t *testing.T) { tls := &tlsCredsBuilder{} _, stop, err := tls.Build(json.RawMessage(`{}`)) if err != nil { t.Fatalf("tls.Build() failed with error %s when expected to succeed", err) } - stop() + defer stop() if _, stop, err := tls.Build(json.RawMessage(`{"ca_certificate_file":"/ca_certificates.pem","refresh_interval": "asdf"}`)); err == nil { + defer stop() t.Errorf("tls.Build() succeeded with an invalid refresh interval, when expected to fail") - stop() + } +} + +func TestJwtCallCredentials_DisabledIfFeatureNotEnabled(t *testing.T) { + builder := GetCallCredentials("jwt_call_creds") + if builder != nil { + t.Fatal("Expected nil Credentials for jwt_call_creds when the feature is disabled.") + } + + testutils.SetEnvConfig(t, &envconfig.XDSBootstrapCallCredsEnabled, true) + + // Test that GetCredentials returns the JWT builder. + builder = GetCallCredentials("jwt_token_file") + if builder == nil { + t.Fatal("GetCallCredentials(\"jwt_token_file\") returned nil") + } + if got, want := builder.Name(), "jwt_token_file"; got != want { + t.Errorf("Retrieved builder name = %q, want %q", got, want) } } diff --git a/xds/bootstrap/credentials.go b/xds/bootstrap/credentials.go index 578e1278970d..85fdd4516e01 100644 --- a/xds/bootstrap/credentials.go +++ b/xds/bootstrap/credentials.go @@ -24,16 +24,19 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/google" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/xds/bootstrap/jwtcreds" "google.golang.org/grpc/internal/xds/bootstrap/tlscreds" ) func init() { - RegisterCredentials(&insecureCredsBuilder{}) - RegisterCredentials(&googleDefaultCredsBuilder{}) - RegisterCredentials(&tlsCredsBuilder{}) + RegisterChannelCredentials(&insecureCredsBuilder{}) + RegisterChannelCredentials(&googleDefaultCredsBuilder{}) + RegisterChannelCredentials(&tlsCredsBuilder{}) + + RegisterCallCredentials(&jwtCallCredsBuilder{}) } -// insecureCredsBuilder implements the `Credentials` interface defined in +// insecureCredsBuilder implements the `ChannelCredentials` interface defined in // package `xds/bootstrap` and encapsulates an insecure credential. type insecureCredsBuilder struct{} @@ -45,7 +48,7 @@ func (i *insecureCredsBuilder) Name() string { return "insecure" } -// tlsCredsBuilder implements the `Credentials` interface defined in +// tlsCredsBuilder implements the `ChannelCredentials` interface defined in // package `xds/bootstrap` and encapsulates a TLS credential. type tlsCredsBuilder struct{} @@ -57,7 +60,7 @@ func (t *tlsCredsBuilder) Name() string { return "tls" } -// googleDefaultCredsBuilder implements the `Credentials` interface defined in +// googleDefaultCredsBuilder implements the `ChannelCredentials` interface defined in // package `xds/bootstrap` and encapsulates a Google Default credential. type googleDefaultCredsBuilder struct{} @@ -68,3 +71,15 @@ func (d *googleDefaultCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func (d *googleDefaultCredsBuilder) Name() string { return "google_default" } + +// jwtCallCredsBuilder implements the `Credentials` interface defined in +// package `xds/bootstrap` and encapsulates JWT call credentials. +type jwtCallCredsBuilder struct{} + +func (j *jwtCallCredsBuilder) Build(configJSON json.RawMessage) (credentials.PerRPCCredentials, func(), error) { + return jwtcreds.NewCallCredentials(configJSON) +} + +func (j *jwtCallCredsBuilder) Name() string { + return "jwt_token_file" +}