Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions pkg/runner/config_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/auth/tokenexchange"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/mcp"
"github.com/stacklok/toolhive/pkg/permissions"
"github.com/stacklok/toolhive/pkg/registry"
"github.com/stacklok/toolhive/pkg/transport/types"
)

func TestRunConfigBuilder_Build_WithPermissionProfile(t *testing.T) {
Expand Down Expand Up @@ -365,6 +369,74 @@ func createTempProfileFile(t *testing.T, content string) (string, func()) {
return tempFile.Name(), cleanup
}

// TestAddCoreMiddlewares_TokenExchangeIntegration verifies token-exchange middleware integration and parameter propagation.
func TestAddCoreMiddlewares_TokenExchangeIntegration(t *testing.T) {
t.Parallel()

// Prevent nil pointer dereference in the logger.
logger.Initialize()

t.Run("token-exchange NOT added when config is nil", func(t *testing.T) {
t.Parallel()

var mws []types.MiddlewareConfig
// OIDC config can be empty for this unit test since we're only testing token-exchange behavior.
mws = addCoreMiddlewares(mws, &auth.TokenValidatorConfig{}, nil)

// Expect only auth + mcp parser when token-exchange config == nil
require.Len(t, mws, 2, "expected only auth and mcp parser middlewares when token-exchange config is nil")
assert.Equal(t, auth.MiddlewareType, mws[0].Type, "first middleware should be auth")
assert.Equal(t, mcp.ParserMiddlewareType, mws[1].Type, "second middleware should be MCP parser")

// Ensure token-exchange type is not present in any middleware slot.
for i, mw := range mws {
assert.NotEqual(t, tokenexchange.MiddlewareType, mw.Type, "middleware[%d] should not be token-exchange", i)
}
})

t.Run("token-exchange IS added, correctly ordered and parameters populated when config provided", func(t *testing.T) {
t.Parallel()

var mws []types.MiddlewareConfig
// Provide a realistic config to ensure full parameter serialization and propagation.
teCfg := &tokenexchange.Config{
TokenURL: "https://example.com/token",
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
Audience: "test-audience",
Scopes: []string{"scope1", "scope2"},
// SubjectTokenType: "", // default is access_token if empty
HeaderStrategy: tokenexchange.HeaderStrategyReplace, // default behavior
// ExternalTokenHeaderName not required for replace strategy
}

mws = addCoreMiddlewares(mws, &auth.TokenValidatorConfig{}, teCfg)

// Expect auth, token-exchange, then mcp parser — verify correct order and count.
require.Len(t, mws, 3, "expected auth, token-exchange and mcp parser middlewares when token-exchange config is provided")
assert.Equal(t, auth.MiddlewareType, mws[0].Type, "first middleware should be auth")
assert.Equal(t, tokenexchange.MiddlewareType, mws[1].Type, "second middleware should be token-exchange")
assert.Equal(t, mcp.ParserMiddlewareType, mws[2].Type, "third middleware should be MCP parser")

// Verify the token-exchange middleware parameters are serialized and populated.
require.NotNil(t, mws[1].Parameters, "token-exchange middleware Parameters should not be nil")
require.NotZero(t, len(mws[1].Parameters), "token-exchange middleware Parameters should not be empty")

// Deserialize middleware parameters and validate field propagation.
var mwParams tokenexchange.MiddlewareParams
err := json.Unmarshal(mws[1].Parameters, &mwParams)
require.NoError(t, err, "unmarshal of middleware Parameters should not fail")

require.NotNil(t, mwParams.TokenExchangeConfig, "TokenExchangeConfig in middleware params should not be nil")
assert.Equal(t, teCfg.TokenURL, mwParams.TokenExchangeConfig.TokenURL, "TokenURL should propagate into middleware params")
assert.Equal(t, teCfg.ClientID, mwParams.TokenExchangeConfig.ClientID, "ClientID should propagate into middleware params")
assert.Equal(t, teCfg.ClientSecret, mwParams.TokenExchangeConfig.ClientSecret, "ClientSecret should propagate into middleware params")
assert.Equal(t, teCfg.Audience, mwParams.TokenExchangeConfig.Audience, "Audience should propagate into middleware params")
assert.Equal(t, teCfg.Scopes, mwParams.TokenExchangeConfig.Scopes, "Scopes should propagate into middleware params")
assert.Equal(t, teCfg.HeaderStrategy, mwParams.TokenExchangeConfig.HeaderStrategy, "HeaderStrategy should propagate into middleware params")
})
}

func TestRunConfigBuilder_WithToolOverride(t *testing.T) {
t.Parallel()

Expand Down